1 // Licensed to the Apache Software Foundation (ASF) under one or more
2 // contributor license agreements. See the NOTICE file distributed with
3 // this work for additional information regarding copyright ownership.
4 // The ASF licenses this file to You under the Apache License, Version 2.0
5 // (the "License"); you may not use this file except in compliance with
6 // the License.  You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 
16 using Apache.Arrow.Memory;
17 using System;
18 using System.Buffers;
19 using System.IO;
20 using System.Linq;
21 using System.Threading;
22 using System.Threading.Tasks;
23 
namespace Apache.Arrow.Ipc
{
    /// <summary>
    /// Reader for the Arrow IPC file (random-access) format.
    /// </summary>
    /// <remarks>
    /// The layout this reader relies on is: a magic marker at the very start of the stream,
    /// the record batch messages, then a flatbuffer footer that is immediately followed by its
    /// 4-byte length and a trailing magic marker. The footer supplies the schema and the offset
    /// of every record batch, so batches can be read by index or sequentially.
    /// The underlying stream must support seeking and <see cref="Stream.Length"/>.
    /// </remarks>
    internal sealed class ArrowFileReaderImplementation : ArrowStreamReaderImplementation
    {
        /// <summary>Size in bytes of the int32 footer length that precedes the trailing magic.</summary>
        private const int FooterLengthSize = sizeof(int);

        /// <summary>
        /// <c>true</c> once both the leading and trailing magic markers have been verified.
        /// </summary>
        public bool IsFileValid { get; private set; }

        /// <summary>
        /// Index of the record batch that the next call to <c>ReadNextRecordBatch</c> /
        /// <c>ReadNextRecordBatchAsync</c> will return.
        /// </summary>
        private int _recordBatchIndex;

        /// <summary>Parsed footer; assigned when the schema is read.</summary>
        private ArrowFooter _footer;

        public ArrowFileReaderImplementation(Stream stream, MemoryAllocator allocator, bool leaveOpen)
            : base(stream, allocator, leaveOpen)
        {
        }

        /// <summary>
        /// Returns the number of record batches listed in the footer, reading the footer first
        /// if it has not been read yet.
        /// </summary>
        public async ValueTask<int> RecordBatchCountAsync()
        {
            if (!HasReadSchema)
            {
                await ReadSchemaAsync().ConfigureAwait(false);
            }

            return _footer.RecordBatchCount;
        }

        /// <summary>
        /// Validates the magic markers, then reads and parses the footer, which sets
        /// <c>Schema</c>. Does nothing if the schema has already been read.
        /// </summary>
        /// <exception cref="InvalidDataException">The magic markers or the footer length are invalid.</exception>
        protected override async ValueTask ReadSchemaAsync()
        {
            if (HasReadSchema)
            {
                return;
            }

            await ValidateFileAsync().ConfigureAwait(false);

            int footerLength = 0;
            await ArrayPool<byte>.Shared.RentReturnAsync(FooterLengthSize, async (buffer) =>
            {
                BaseStream.Position = GetFooterLengthPosition();

                int bytesRead = await BaseStream.ReadFullBufferAsync(buffer).ConfigureAwait(false);
                EnsureFullRead(buffer, bytesRead);

                footerLength = ReadFooterLength(buffer);
            }).ConfigureAwait(false);

            await ArrayPool<byte>.Shared.RentReturnAsync(footerLength, async (buffer) =>
            {
                // The footer ends exactly where the footer-length field begins.
                BaseStream.Position = GetFooterLengthPosition() - footerLength;

                int bytesRead = await BaseStream.ReadFullBufferAsync(buffer).ConfigureAwait(false);
                EnsureFullRead(buffer, bytesRead);

                ReadSchema(buffer);
            }).ConfigureAwait(false);
        }

        /// <summary>
        /// Synchronous counterpart of <see cref="ReadSchemaAsync"/>.
        /// </summary>
        /// <exception cref="InvalidDataException">The magic markers or the footer length are invalid.</exception>
        protected override void ReadSchema()
        {
            if (HasReadSchema)
            {
                return;
            }

            ValidateFile();

            int footerLength = 0;
            ArrayPool<byte>.Shared.RentReturn(FooterLengthSize, (buffer) =>
            {
                BaseStream.Position = GetFooterLengthPosition();

                int bytesRead = BaseStream.ReadFullBuffer(buffer);
                EnsureFullRead(buffer, bytesRead);

                footerLength = ReadFooterLength(buffer);
            });

            ArrayPool<byte>.Shared.RentReturn(footerLength, (buffer) =>
            {
                // The footer ends exactly where the footer-length field begins.
                BaseStream.Position = GetFooterLengthPosition() - footerLength;

                int bytesRead = BaseStream.ReadFullBuffer(buffer);
                EnsureFullRead(buffer, bytesRead);

                ReadSchema(buffer);
            });
        }

        /// <summary>
        /// Stream offset of the int32 footer length, which sits directly before the trailing magic.
        /// </summary>
        private long GetFooterLengthPosition()
        {
            return BaseStream.Length - ArrowFileConstants.Magic.Length - FooterLengthSize;
        }

        /// <summary>
        /// Decodes the footer length from <paramref name="buffer"/>. It also checks that a footer of
        /// that size fits between the leading magic and the footer-length field.
        /// </summary>
        /// <exception cref="InvalidDataException">The length is not positive or exceeds the available space.</exception>
        private int ReadFooterLength(Memory<byte> buffer)
        {
            int footerLength = BitUtility.ReadInt32(buffer);

            // The length comes from the file itself, so bound it by the stream size. This stops a
            // corrupt value from causing a negative seek or an oversized pooled-buffer rental.
            long maxFooterLength = GetFooterLengthPosition() - ArrowFileConstants.Magic.Length;

            if (footerLength <= 0 || footerLength > maxFooterLength)
            {
                throw new InvalidDataException(
                    $"Footer length has invalid size <{footerLength}>");
            }

            return footerLength;
        }

        /// <summary>
        /// Deserializes the footer flatbuffer in <paramref name="buffer"/> and takes the schema from it.
        /// </summary>
        private void ReadSchema(Memory<byte> buffer)
        {
            _footer = new ArrowFooter(Flatbuf.Footer.GetRootAsFooter(CreateByteBuffer(buffer)));

            Schema = _footer.Schema;
        }

        /// <summary>
        /// Reads the record batch at position <paramref name="index"/> in the footer's block list.
        /// </summary>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="index"/> is negative or not less than the batch count.</exception>
        public async ValueTask<RecordBatch> ReadRecordBatchAsync(int index, CancellationToken cancellationToken)
        {
            await ReadSchemaAsync().ConfigureAwait(false);

            SeekToRecordBatch(index);

            return await ReadRecordBatchAsync(cancellationToken).ConfigureAwait(false);
        }

        /// <summary>
        /// Synchronous counterpart of <see cref="ReadRecordBatchAsync(int, CancellationToken)"/>.
        /// </summary>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="index"/> is negative or not less than the batch count.</exception>
        public RecordBatch ReadRecordBatch(int index)
        {
            ReadSchema();

            SeekToRecordBatch(index);

            return ReadRecordBatch();
        }

        /// <summary>
        /// Validates <paramref name="index"/> against the footer and positions the stream at that batch's block.
        /// </summary>
        private void SeekToRecordBatch(int index)
        {
            if (index < 0 || index >= _footer.RecordBatchCount)
            {
                throw new ArgumentOutOfRangeException(nameof(index));
            }

            Block block = _footer.GetRecordBatchBlock(index);

            BaseStream.Position = block.Offset;
        }

        /// <summary>
        /// Returns the next record batch in footer order, or <c>null</c> once all batches have been read.
        /// </summary>
        public override async ValueTask<RecordBatch> ReadNextRecordBatchAsync(CancellationToken cancellationToken)
        {
            await ReadSchemaAsync().ConfigureAwait(false);

            if (_recordBatchIndex >= _footer.RecordBatchCount)
            {
                return null;
            }

            RecordBatch result = await ReadRecordBatchAsync(_recordBatchIndex, cancellationToken).ConfigureAwait(false);
            _recordBatchIndex++;

            return result;
        }

        /// <summary>
        /// Synchronous counterpart of <see cref="ReadNextRecordBatchAsync"/>.
        /// </summary>
        public override RecordBatch ReadNextRecordBatch()
        {
            ReadSchema();

            if (_recordBatchIndex >= _footer.RecordBatchCount)
            {
                return null;
            }

            RecordBatch result = ReadRecordBatch(_recordBatchIndex);
            _recordBatchIndex++;

            return result;
        }

        /// <summary>
        /// Check if file format is valid. If it's valid don't run the validation again.
        /// </summary>
        private async ValueTask ValidateFileAsync()
        {
            if (IsFileValid)
            {
                return;
            }

            await ValidateMagicAsync().ConfigureAwait(false);

            IsFileValid = true;
        }

        /// <summary>
        /// Check if file format is valid. If it's valid don't run the validation again.
        /// </summary>
        private void ValidateFile()
        {
            if (IsFileValid)
            {
                return;
            }

            ValidateMagic();

            IsFileValid = true;
        }

        /// <summary>
        /// Verifies the leading and trailing magic markers. The stream position is restored afterwards.
        /// </summary>
        /// <exception cref="InvalidDataException">The stream is too short or a magic marker does not match.</exception>
        private async ValueTask ValidateMagicAsync()
        {
            EnsureMinimumFileLength();

            long startingPosition = BaseStream.Position;
            int magicLength = ArrowFileConstants.Magic.Length;

            try
            {
                await ArrayPool<byte>.Shared.RentReturnAsync(magicLength, async (buffer) =>
                {
                    // Leading magic at the beginning of the stream.
                    BaseStream.Position = 0;

                    // Require a full read: the buffer is reused for both markers, so a short read
                    // would otherwise compare stale bytes.
                    int bytesRead = await BaseStream.ReadFullBufferAsync(buffer).ConfigureAwait(false);
                    EnsureFullRead(buffer, bytesRead);

                    VerifyMagic(buffer, 0);

                    // Trailing magic occupies the last magic-length bytes of the stream.
                    long trailingMagicPosition = BaseStream.Length - magicLength;
                    BaseStream.Position = trailingMagicPosition;

                    bytesRead = await BaseStream.ReadFullBufferAsync(buffer).ConfigureAwait(false);
                    EnsureFullRead(buffer, bytesRead);

                    VerifyMagic(buffer, trailingMagicPosition);
                }).ConfigureAwait(false);
            }
            finally
            {
                BaseStream.Position = startingPosition;
            }
        }

        /// <summary>
        /// Synchronous counterpart of <see cref="ValidateMagicAsync"/>.
        /// </summary>
        /// <exception cref="InvalidDataException">The stream is too short or a magic marker does not match.</exception>
        private void ValidateMagic()
        {
            EnsureMinimumFileLength();

            long startingPosition = BaseStream.Position;
            int magicLength = ArrowFileConstants.Magic.Length;

            try
            {
                ArrayPool<byte>.Shared.RentReturn(magicLength, buffer =>
                {
                    // Leading magic at the beginning of the stream.
                    BaseStream.Position = 0;

                    // Require a full read: the buffer is reused for both markers, so a short read
                    // would otherwise compare stale bytes.
                    int bytesRead = BaseStream.ReadFullBuffer(buffer);
                    EnsureFullRead(buffer, bytesRead);

                    VerifyMagic(buffer, 0);

                    // Trailing magic occupies the last magic-length bytes of the stream.
                    long trailingMagicPosition = BaseStream.Length - magicLength;
                    BaseStream.Position = trailingMagicPosition;

                    bytesRead = BaseStream.ReadFullBuffer(buffer);
                    EnsureFullRead(buffer, bytesRead);

                    VerifyMagic(buffer, trailingMagicPosition);
                });
            }
            finally
            {
                BaseStream.Position = startingPosition;
            }
        }

        /// <summary>
        /// Throws if the stream cannot hold both magic markers plus the footer-length field.
        /// This prevents negative seeks later on.
        /// </summary>
        private void EnsureMinimumFileLength()
        {
            long minimumLength = 2L * ArrowFileConstants.Magic.Length + FooterLengthSize;

            if (BaseStream.Length < minimumLength)
            {
                throw new InvalidDataException(
                    $"Stream length <{BaseStream.Length}> is too short for an Arrow file; expected at least <{minimumLength}> bytes");
            }
        }

        /// <summary>
        /// Throws if <paramref name="buffer"/> does not equal the Arrow magic marker.
        /// </summary>
        /// <param name="buffer">Bytes read from the stream.</param>
        /// <param name="offset">Stream offset the bytes were read from; reported in the error.</param>
        private static void VerifyMagic(Memory<byte> buffer, long offset)
        {
            if (!ArrowFileConstants.Magic.AsSpan().SequenceEqual(buffer.Span))
            {
                throw new InvalidDataException(
                    $"Invalid magic at offset <{offset}>");
            }
        }
    }
}
309