1 // Licensed to the .NET Foundation under one or more agreements.
2 // The .NET Foundation licenses this file to you under the MIT license.
3 // See the LICENSE file in the project root for more information.
4 
5 using System.Runtime.CompilerServices;
6 using System.Runtime.InteropServices;
7 
8 #if !netstandard
9 using Internal.Runtime.CompilerServices;
10 #endif
11 
12 namespace System.Buffers.Text
13 {
14     public static partial class Base64
15     {
16         /// <summary>
17         /// Decode the span of UTF-8 encoded text represented as base 64 into binary data.
18         /// If the input is not a multiple of 4, it will decode as much as it can, to the closest multiple of 4.
19         ///
20         /// <param name="utf8">The input span which contains UTF-8 encoded text in base 64 that needs to be decoded.</param>
21         /// <param name="bytes">The output span which contains the result of the operation, i.e. the decoded binary data.</param>
22         /// <param name="consumed">The number of input bytes consumed during the operation. This can be used to slice the input for subsequent calls, if necessary.</param>
23         /// <param name="written">The number of bytes written into the output span. This can be used to slice the output for subsequent calls, if necessary.</param>
24         /// <param name="isFinalBlock">True (default) when the input span contains the entire data to decode.
25         /// Set to false only if it is known that the input span contains partial data with more data to follow.</param>
26         /// <returns>It returns the OperationStatus enum values:
27         /// - Done - on successful processing of the entire input span
28         /// - DestinationTooSmall - if there is not enough space in the output span to fit the decoded input
29         /// - NeedMoreData - only if isFinalBlock is false and the input is not a multiple of 4, otherwise the partial input would be considered as InvalidData
30         /// - InvalidData - if the input contains bytes outside of the expected base 64 range, or if it contains invalid/more than two padding characters,
31         ///   or if the input is incomplete (i.e. not a multiple of 4) and isFinalBlock is true.</returns>
32         /// </summary>
DecodeFromUtf8(ReadOnlySpan<byte> utf8, Span<byte> bytes, out int consumed, out int written, bool isFinalBlock = true)33         public static OperationStatus DecodeFromUtf8(ReadOnlySpan<byte> utf8, Span<byte> bytes, out int consumed, out int written, bool isFinalBlock = true)
34         {
35             ref byte srcBytes = ref MemoryMarshal.GetReference(utf8);
36             ref byte destBytes = ref MemoryMarshal.GetReference(bytes);
37 
38             int srcLength = utf8.Length & ~0x3;  // only decode input up to the closest multiple of 4.
39             int destLength = bytes.Length;
40 
41             int sourceIndex = 0;
42             int destIndex = 0;
43 
44             if (utf8.Length == 0)
45                 goto DoneExit;
46 
47             ref sbyte decodingMap = ref s_decodingMap[0];
48 
49             // Last bytes could have padding characters, so process them separately and treat them as valid only if isFinalBlock is true
50             // if isFinalBlock is false, padding characters are considered invalid
51             int skipLastChunk = isFinalBlock ? 4 : 0;
52 
53             int maxSrcLength = 0;
54             if (destLength >= GetMaxDecodedFromUtf8Length(srcLength))
55             {
56                 maxSrcLength = srcLength - skipLastChunk;
57             }
58             else
59             {
60                 // This should never overflow since destLength here is less than int.MaxValue / 4 * 3 (i.e. 1610612733)
61                 // Therefore, (destLength / 3) * 4 will always be less than 2147483641
62                 maxSrcLength = (destLength / 3) * 4;
63             }
64 
65             while (sourceIndex < maxSrcLength)
66             {
67                 int result = Decode(ref Unsafe.Add(ref srcBytes, sourceIndex), ref decodingMap);
68                 if (result < 0)
69                     goto InvalidExit;
70                 WriteThreeLowOrderBytes(ref Unsafe.Add(ref destBytes, destIndex), result);
71                 destIndex += 3;
72                 sourceIndex += 4;
73             }
74 
75             if (maxSrcLength != srcLength - skipLastChunk)
76                 goto DestinationSmallExit;
77 
78             // If input is less than 4 bytes, srcLength == sourceIndex == 0
79             // If input is not a multiple of 4, sourceIndex == srcLength != 0
80             if (sourceIndex == srcLength)
81             {
82                 if (isFinalBlock)
83                     goto InvalidExit;
84                 goto NeedMoreExit;
85             }
86 
87             // if isFinalBlock is false, we will never reach this point
88 
89             int i0 = Unsafe.Add(ref srcBytes, srcLength - 4);
90             int i1 = Unsafe.Add(ref srcBytes, srcLength - 3);
91             int i2 = Unsafe.Add(ref srcBytes, srcLength - 2);
92             int i3 = Unsafe.Add(ref srcBytes, srcLength - 1);
93 
94             i0 = Unsafe.Add(ref decodingMap, i0);
95             i1 = Unsafe.Add(ref decodingMap, i1);
96 
97             i0 <<= 18;
98             i1 <<= 12;
99 
100             i0 |= i1;
101 
102             if (i3 != EncodingPad)
103             {
104                 i2 = Unsafe.Add(ref decodingMap, i2);
105                 i3 = Unsafe.Add(ref decodingMap, i3);
106 
107                 i2 <<= 6;
108 
109                 i0 |= i3;
110                 i0 |= i2;
111 
112                 if (i0 < 0)
113                     goto InvalidExit;
114                 if (destIndex > destLength - 3)
115                     goto DestinationSmallExit;
116                 WriteThreeLowOrderBytes(ref Unsafe.Add(ref destBytes, destIndex), i0);
117                 destIndex += 3;
118             }
119             else if (i2 != EncodingPad)
120             {
121                 i2 = Unsafe.Add(ref decodingMap, i2);
122 
123                 i2 <<= 6;
124 
125                 i0 |= i2;
126 
127                 if (i0 < 0)
128                     goto InvalidExit;
129                 if (destIndex > destLength - 2)
130                     goto DestinationSmallExit;
131                 Unsafe.Add(ref destBytes, destIndex) = (byte)(i0 >> 16);
132                 Unsafe.Add(ref destBytes, destIndex + 1) = (byte)(i0 >> 8);
133                 destIndex += 2;
134             }
135             else
136             {
137                 if (i0 < 0)
138                     goto InvalidExit;
139                 if (destIndex > destLength - 1)
140                     goto DestinationSmallExit;
141                 Unsafe.Add(ref destBytes, destIndex) = (byte)(i0 >> 16);
142                 destIndex += 1;
143             }
144 
145             sourceIndex += 4;
146 
147             if (srcLength != utf8.Length)
148                 goto InvalidExit;
149 
150         DoneExit:
151             consumed = sourceIndex;
152             written = destIndex;
153             return OperationStatus.Done;
154 
155         DestinationSmallExit:
156             if (srcLength != utf8.Length && isFinalBlock)
157                 goto InvalidExit; // if input is not a multiple of 4, and there is no more data, return invalid data instead
158             consumed = sourceIndex;
159             written = destIndex;
160             return OperationStatus.DestinationTooSmall;
161 
162         NeedMoreExit:
163             consumed = sourceIndex;
164             written = destIndex;
165             return OperationStatus.NeedMoreData;
166 
167         InvalidExit:
168             consumed = sourceIndex;
169             written = destIndex;
170             return OperationStatus.InvalidData;
171         }
172 
173         /// <summary>
174         /// Returns the maximum length (in bytes) of the result if you were to deocde base 64 encoded text within a byte span of size "length".
175         /// </summary>
176         /// <exception cref="System.ArgumentOutOfRangeException">
177         /// Thrown when the specified <paramref name="length"/> is less than 0.
178         /// </exception>
179         [MethodImpl(MethodImplOptions.AggressiveInlining)]
GetMaxDecodedFromUtf8Length(int length)180         public static int GetMaxDecodedFromUtf8Length(int length)
181         {
182             if (length < 0)
183                 ThrowHelper.ThrowArgumentOutOfRangeException(ExceptionArgument.length);
184 
185             return (length >> 2) * 3;
186         }
187 
188         /// <summary>
189         /// Decode the span of UTF-8 encoded text in base 64 (in-place) into binary data.
190         /// The decoded binary output is smaller than the text data contained in the input (the operation deflates the data).
191         /// If the input is not a multiple of 4, it will not decode any.
192         ///
193         /// <param name="buffer">The input span which contains the base 64 text data that needs to be decoded.</param>
194         /// <param name="written">The number of bytes written into the buffer.</param>
195         /// <returns>It returns the OperationStatus enum values:
196         /// - Done - on successful processing of the entire input span
197         /// - InvalidData - if the input contains bytes outside of the expected base 64 range, or if it contains invalid/more than two padding characters,
198         ///   or if the input is incomplete (i.e. not a multiple of 4).
199         /// It does not return DestinationTooSmall since that is not possible for base 64 decoding.
200         /// It does not return NeedMoreData since this method tramples the data in the buffer and
201         /// hence can only be called once with all the data in the buffer.</returns>
202         /// </summary>
DecodeFromUtf8InPlace(Span<byte> buffer, out int written)203         public static OperationStatus DecodeFromUtf8InPlace(Span<byte> buffer, out int written)
204         {
205             int bufferLength = buffer.Length;
206             int sourceIndex = 0;
207             int destIndex = 0;
208 
209             // only decode input if it is a multiple of 4
210             if (bufferLength != ((bufferLength >> 2) * 4))
211                 goto InvalidExit;
212             if (bufferLength == 0)
213                 goto DoneExit;
214 
215             ref byte bufferBytes = ref MemoryMarshal.GetReference(buffer);
216 
217             ref sbyte decodingMap = ref s_decodingMap[0];
218 
219             while (sourceIndex < bufferLength - 4)
220             {
221                 int result = Decode(ref Unsafe.Add(ref bufferBytes, sourceIndex), ref decodingMap);
222                 if (result < 0)
223                     goto InvalidExit;
224                 WriteThreeLowOrderBytes(ref Unsafe.Add(ref bufferBytes, destIndex), result);
225                 destIndex += 3;
226                 sourceIndex += 4;
227             }
228 
229             int i0 = Unsafe.Add(ref bufferBytes, bufferLength - 4);
230             int i1 = Unsafe.Add(ref bufferBytes, bufferLength - 3);
231             int i2 = Unsafe.Add(ref bufferBytes, bufferLength - 2);
232             int i3 = Unsafe.Add(ref bufferBytes, bufferLength - 1);
233 
234             i0 = Unsafe.Add(ref decodingMap, i0);
235             i1 = Unsafe.Add(ref decodingMap, i1);
236 
237             i0 <<= 18;
238             i1 <<= 12;
239 
240             i0 |= i1;
241 
242             if (i3 != EncodingPad)
243             {
244                 i2 = Unsafe.Add(ref decodingMap, i2);
245                 i3 = Unsafe.Add(ref decodingMap, i3);
246 
247                 i2 <<= 6;
248 
249                 i0 |= i3;
250                 i0 |= i2;
251 
252                 if (i0 < 0)
253                     goto InvalidExit;
254                 WriteThreeLowOrderBytes(ref Unsafe.Add(ref bufferBytes, destIndex), i0);
255                 destIndex += 3;
256             }
257             else if (i2 != EncodingPad)
258             {
259                 i2 = Unsafe.Add(ref decodingMap, i2);
260 
261                 i2 <<= 6;
262 
263                 i0 |= i2;
264 
265                 if (i0 < 0)
266                     goto InvalidExit;
267                 Unsafe.Add(ref bufferBytes, destIndex) = (byte)(i0 >> 16);
268                 Unsafe.Add(ref bufferBytes, destIndex + 1) = (byte)(i0 >> 8);
269                 destIndex += 2;
270             }
271             else
272             {
273                 if (i0 < 0)
274                     goto InvalidExit;
275                 Unsafe.Add(ref bufferBytes, destIndex) = (byte)(i0 >> 16);
276                 destIndex += 1;
277             }
278 
279         DoneExit:
280             written = destIndex;
281             return OperationStatus.Done;
282 
283         InvalidExit:
284             written = destIndex;
285             return OperationStatus.InvalidData;
286         }
287 
288         [MethodImpl(MethodImplOptions.AggressiveInlining)]
Decode(ref byte encodedBytes, ref sbyte decodingMap)289         private static int Decode(ref byte encodedBytes, ref sbyte decodingMap)
290         {
291             int i0 = encodedBytes;
292             int i1 = Unsafe.Add(ref encodedBytes, 1);
293             int i2 = Unsafe.Add(ref encodedBytes, 2);
294             int i3 = Unsafe.Add(ref encodedBytes, 3);
295 
296             i0 = Unsafe.Add(ref decodingMap, i0);
297             i1 = Unsafe.Add(ref decodingMap, i1);
298             i2 = Unsafe.Add(ref decodingMap, i2);
299             i3 = Unsafe.Add(ref decodingMap, i3);
300 
301             i0 <<= 18;
302             i1 <<= 12;
303             i2 <<= 6;
304 
305             i0 |= i3;
306             i1 |= i2;
307 
308             i0 |= i1;
309             return i0;
310         }
311 
312         [MethodImpl(MethodImplOptions.AggressiveInlining)]
WriteThreeLowOrderBytes(ref byte destination, int value)313         private static void WriteThreeLowOrderBytes(ref byte destination, int value)
314         {
315             destination = (byte)(value >> 16);
316             Unsafe.Add(ref destination, 1) = (byte)(value >> 8);
317             Unsafe.Add(ref destination, 2) = (byte)value;
318         }
319 
        // Lookup table indexed by input byte value (256 entries): each base 64 character maps to its 6-bit value (0-63).
        // Every other byte, including the padding character '=', maps to -1, so combining the shifted entries of a
        // chunk yields a negative number as soon as one of its bytes is not a base 64 character.
        // Pre-computing this table using a custom string(s_characters) and GenerateDecodingMapAndVerify (found in tests)
        private static readonly sbyte[] s_decodingMap = {
            -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
            -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
            -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 62, -1, -1, -1, 63,         //62 is placed at index 43 (for +), 63 at index 47 (for /)
            52, 53, 54, 55, 56, 57, 58, 59, 60, 61, -1, -1, -1, -1, -1, -1,         //52-61 are placed at index 48-57 (for 0-9); index 61 (for =) holds -1
            -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
            15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, -1, -1, -1, -1, -1,         //0-25 are placed at index 65-90 (for A-Z)
            -1, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40,
            41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, -1, -1, -1, -1, -1,         //26-51 are placed at index 97-122 (for a-z)
            -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,         // Bytes over 122 ('z') are invalid and cannot be decoded
            -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,         // Hence, padding the map with -1, which indicates invalid input
            -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
            -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
            -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
            -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
            -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
            -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        };
339     }
340 }
341