1 // Licensed to the .NET Foundation under one or more agreements.
2 // The .NET Foundation licenses this file to you under the MIT license.
3 // See the LICENSE file in the project root for more information.
4 
5 using System.Diagnostics;
6 using System.Runtime;
7 using System.Runtime.InteropServices;
8 using System.Text;
9 
10 namespace System.Security
11 {
12     // SecureString attempts to provide a defense-in-depth solution.
13     //
14     // On Windows, this is done with several mechanisms:
15     // 1. keeping the data in unmanaged memory so that copies of it aren't implicitly made by the GC moving it around
    // 2. zeroing out that unmanaged memory so that the string is reliably removed from memory when done with it
17     // 3. encrypting the data while it's not being used (it's unencrypted to manipulate and use it)
18     //
19     // On Unix, we do 1 and 2, but we don't do 3 as there's no CryptProtectData equivalent.
20 
21     public sealed partial class SecureString
22     {
        // Unmanaged backing store for the characters (UTF-16, sizeof(char) bytes each).
        // DisposeCore sets this to null, and EnsureNotDisposed treats null as "disposed".
        private UnmanagedBuffer _buffer;
24 
SecureString(SecureString str)25         internal SecureString(SecureString str)
26         {
27             // Allocate enough space to store the provided string
28             EnsureCapacity(str._decryptedLength);
29             _decryptedLength = str._decryptedLength;
30 
31             // Copy the string into the newly allocated space
32             if (_decryptedLength > 0)
33             {
34                 UnmanagedBuffer.Copy(str._buffer, _buffer, (ulong)(str._decryptedLength * sizeof(char)));
35             }
36         }
37 
InitializeSecureString(char* value, int length)38         private unsafe void InitializeSecureString(char* value, int length)
39         {
40             // Allocate enough space to store the provided string
41             EnsureCapacity(length);
42             _decryptedLength = length;
43             if (length == 0)
44             {
45                 return;
46             }
47 
48             // Copy the string into the newly allocated space
49             byte* ptr = null;
50             try
51             {
52                 _buffer.AcquirePointer(ref ptr);
53                 Buffer.MemoryCopy(value, ptr, _buffer.ByteLength, (ulong)(length * sizeof(char)));
54             }
55             finally
56             {
57                 if (ptr != null)
58                 {
59                     _buffer.ReleasePointer();
60                 }
61             }
62         }
63 
DisposeCore()64         private void DisposeCore()
65         {
66             if (_buffer != null && !_buffer.IsInvalid)
67             {
68                 _buffer.Dispose();
69                 _buffer = null;
70             }
71         }
72 
EnsureNotDisposed()73         private void EnsureNotDisposed()
74         {
75             if (_buffer == null)
76             {
77                 throw new ObjectDisposedException(GetType().Name);
78             }
79         }
80 
ClearCore()81         private void ClearCore()
82         {
83             _decryptedLength = 0;
84             _buffer.Clear();
85         }
86 
AppendCharCore(char c)87         private unsafe void AppendCharCore(char c)
88         {
89             // Make sure we have enough space for the new character, then write it at the end.
90             EnsureCapacity(_decryptedLength + 1);
91             _buffer.Write((ulong)(_decryptedLength * sizeof(char)), c);
92             _decryptedLength++;
93         }
94 
InsertAtCore(int index, char c)95         private unsafe void InsertAtCore(int index, char c)
96         {
97             // Make sure we have enough space for the new character, then shift all of the characters above it and insert it.
98             EnsureCapacity(_decryptedLength + 1);
99             byte* ptr = null;
100             try
101             {
102                 _buffer.AcquirePointer(ref ptr);
103                 ptr += index * sizeof(char);
104                 long bytesToShift = (_decryptedLength - index) * sizeof(char);
105                 Buffer.MemoryCopy(ptr, ptr + sizeof(char), bytesToShift, bytesToShift);
106                 *((char*)ptr) = c;
107                 ++_decryptedLength;
108             }
109             finally
110             {
111                 if (ptr != null)
112                 {
113                     _buffer.ReleasePointer();
114                 }
115             }
116         }
117 
RemoveAtCore(int index)118         private unsafe void RemoveAtCore(int index)
119         {
120             // Shift down all values above the specified index, then null out the empty space at the end.
121             byte* ptr = null;
122             try
123             {
124                 _buffer.AcquirePointer(ref ptr);
125                 ptr += index * sizeof(char);
126                 long bytesToShift = (_decryptedLength - index - 1) * sizeof(char);
127                 Buffer.MemoryCopy(ptr + sizeof(char), ptr, bytesToShift, bytesToShift);
128                 *((char*)(ptr + bytesToShift)) = (char)0;
129                 --_decryptedLength;
130             }
131             finally
132             {
133                 if (ptr != null)
134                 {
135                     _buffer.ReleasePointer();
136                 }
137             }
138         }
139 
SetAtCore(int index, char c)140         private void SetAtCore(int index, char c)
141         {
142             // Overwrite the character at the specified index
143             _buffer.Write((ulong)(index * sizeof(char)), c);
144         }
145 
MarshalToBSTR()146         internal unsafe IntPtr MarshalToBSTR()
147         {
148             int length = _decryptedLength;
149             IntPtr ptr = IntPtr.Zero;
150             IntPtr result = IntPtr.Zero;
151             byte* bufferPtr = null;
152 
153             try
154             {
155                 _buffer.AcquirePointer(ref bufferPtr);
156                 int resultByteLength = (length + 1) * sizeof(char);
157 
158                 ptr = PInvokeMarshal.AllocBSTR(length);
159 
160                 Buffer.MemoryCopy(bufferPtr, (byte*)ptr, resultByteLength, length * sizeof(char));
161 
162                 result = ptr;
163             }
164             finally
165             {
166                 // If we failed for any reason, free the new buffer
167                 if (result == IntPtr.Zero && ptr != IntPtr.Zero)
168                 {
169                     RuntimeImports.RhZeroMemory(ptr, (UIntPtr)(length * sizeof(char)));
170                     PInvokeMarshal.FreeBSTR(ptr);
171                 }
172 
173                 if (bufferPtr != null)
174                 {
175                     _buffer.ReleasePointer();
176                 }
177             }
178             return result;
179         }
180 
MarshalToStringCore(bool globalAlloc, bool unicode)181         internal unsafe IntPtr MarshalToStringCore(bool globalAlloc, bool unicode)
182         {
183             int length = _decryptedLength;
184 
185             byte* bufferPtr = null;
186             IntPtr stringPtr = IntPtr.Zero, result = IntPtr.Zero;
187             try
188             {
189                 _buffer.AcquirePointer(ref bufferPtr);
190                 if (unicode)
191                 {
192                     int resultLength = (length + 1) * sizeof(char);
193                     stringPtr = globalAlloc ? Marshal.AllocHGlobal(resultLength) : Marshal.AllocCoTaskMem(resultLength);
194                     Buffer.MemoryCopy(
195                         source: bufferPtr,
196                         destination: (byte*)stringPtr.ToPointer(),
197                         destinationSizeInBytes: resultLength,
198                         sourceBytesToCopy: length * sizeof(char));
199                     *(length + (char*)stringPtr) = '\0';
200                 }
201                 else
202                 {
203                     int resultLength = Encoding.UTF8.GetByteCount((char*)bufferPtr, length) + 1;
204                     stringPtr = globalAlloc ? Marshal.AllocHGlobal(resultLength) : Marshal.AllocCoTaskMem(resultLength);
205                     int encodedLength = Encoding.UTF8.GetBytes((char*)bufferPtr, length, (byte*)stringPtr, resultLength);
206                     Debug.Assert(encodedLength + 1 == resultLength, $"Expected encoded length to match result, got {encodedLength} != {resultLength}");
207                     *(resultLength - 1 + (byte*)stringPtr) = 0;
208                 }
209 
210                 result = stringPtr;
211             }
212             finally
213             {
214                 // If there was a failure, such that result isn't initialized,
215                 // release the string if we had one.
216                 if (stringPtr != IntPtr.Zero && result == IntPtr.Zero)
217                 {
218                     RuntimeImports.RhZeroMemory(stringPtr, (UIntPtr)(length * sizeof(char)));
219                     MarshalFree(stringPtr, globalAlloc);
220                 }
221 
222                 if (bufferPtr != null)
223                 {
224                     _buffer.ReleasePointer();
225                 }
226             }
227 
228             return result;
229         }
230 
231         // -----------------------------
232         // ---- PAL layer ends here ----
233         // -----------------------------
234 
EnsureCapacity(int capacity)235         private void EnsureCapacity(int capacity)
236         {
237             // Make sure the requested capacity doesn't exceed SecureString's defined limit
238             if (capacity > MaxLength)
239             {
240                 throw new ArgumentOutOfRangeException(nameof(capacity), SR.ArgumentOutOfRange_Capacity);
241             }
242 
243             // If we already have enough space allocated, we're done
244             if (_buffer != null && (capacity * sizeof(char)) <= (int)_buffer.ByteLength)
245             {
246                 return;
247             }
248 
249             // We need more space, so allocate a new buffer, copy all our data into it,
250             // and then swap the new for the old.
251             UnmanagedBuffer newBuffer = UnmanagedBuffer.Allocate(capacity * sizeof(char));
252             if (_buffer != null)
253             {
254                 UnmanagedBuffer.Copy(_buffer, newBuffer, _buffer.ByteLength);
255                 _buffer.Dispose();
256             }
257             _buffer = newBuffer;
258         }
259 
260         /// <summary>SafeBuffer for managing memory meant to be kept confidential.</summary>
261         private sealed class UnmanagedBuffer : SafeBuffer
262         {
UnmanagedBuffer()263             internal UnmanagedBuffer() : base(true) { }
264 
Allocate(int bytes)265             internal static UnmanagedBuffer Allocate(int bytes)
266             {
267                 Debug.Assert(bytes >= 0);
268                 UnmanagedBuffer buffer = new UnmanagedBuffer();
269                 buffer.SetHandle(Marshal.AllocHGlobal(bytes));
270                 buffer.Initialize((ulong)bytes);
271                 return buffer;
272             }
273 
Clear()274             internal unsafe void Clear()
275             {
276                 byte* ptr = null;
277                 try
278                 {
279                     AcquirePointer(ref ptr);
280                     RuntimeImports.RhZeroMemory((IntPtr)ptr, (UIntPtr)ByteLength);
281                 }
282                 finally
283                 {
284                     if (ptr != null)
285                     {
286                         ReleasePointer();
287                     }
288                 }
289             }
290 
Copy(UnmanagedBuffer source, UnmanagedBuffer destination, ulong bytesLength)291             internal static unsafe void Copy(UnmanagedBuffer source, UnmanagedBuffer destination, ulong bytesLength)
292             {
293                 if (bytesLength == 0)
294                 {
295                     return;
296                 }
297 
298                 byte* srcPtr = null, dstPtr = null;
299                 try
300                 {
301                     source.AcquirePointer(ref srcPtr);
302                     destination.AcquirePointer(ref dstPtr);
303                     Buffer.MemoryCopy(srcPtr, dstPtr, destination.ByteLength, bytesLength);
304                 }
305                 finally
306                 {
307                     if (dstPtr != null)
308                     {
309                         destination.ReleasePointer();
310                     }
311                     if (srcPtr != null)
312                     {
313                         source.ReleasePointer();
314                     }
315                 }
316             }
317 
ReleaseHandle()318             protected override unsafe bool ReleaseHandle()
319             {
320                 Marshal.FreeHGlobal(handle);
321                 return true;
322             }
323         }
324     }
325 }
326