1 /*
2   KeePass Password Safe - The Open-Source Password Manager
3   Copyright (C) 2003-2021 Dominik Reichl <dominik.reichl@t-online.de>
4 
5   This program is free software; you can redistribute it and/or modify
6   it under the terms of the GNU General Public License as published by
7   the Free Software Foundation; either version 2 of the License, or
8   (at your option) any later version.
9 
10   This program is distributed in the hope that it will be useful,
11   but WITHOUT ANY WARRANTY; without even the implied warranty of
12   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13   GNU General Public License for more details.
14 
15   You should have received a copy of the GNU General Public License
16   along with this program; if not, write to the Free Software
17   Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
18 */
19 
20 using System;
21 using System.Collections.Generic;
22 using System.Diagnostics;
23 using System.IO;
24 using System.Text;
25 
26 using KeePassLib.Resources;
27 using KeePassLib.Utility;
28 
29 namespace KeePassLib.Collections
30 {
31 	public class VariantDictionary : ICloneable
32 	{
33 		private const ushort VdVersion = 0x0100;
34 		private const ushort VdmCritical = 0xFF00;
35 		private const ushort VdmInfo = 0x00FF;
36 
37 		private Dictionary<string, object> m_d = new Dictionary<string, object>();
38 
39 		private enum VdType : byte
40 		{
41 			None = 0,
42 
43 			// Byte = 0x02,
44 			// UInt16 = 0x03,
45 			UInt32 = 0x04,
46 			UInt64 = 0x05,
47 
48 			// Signed mask: 0x08
49 			Bool = 0x08,
50 			// SByte = 0x0A,
51 			// Int16 = 0x0B,
52 			Int32 = 0x0C,
53 			Int64 = 0x0D,
54 
55 			// Float = 0x10,
56 			// Double = 0x11,
57 			// Decimal = 0x12,
58 
59 			// Char = 0x17, // 16-bit Unicode character
60 			String = 0x18,
61 
62 			// Array mask: 0x40
63 			ByteArray = 0x42
64 		}
65 
66 		public int Count
67 		{
68 			get { return m_d.Count; }
69 		}
70 
VariantDictionary()71 		public VariantDictionary()
72 		{
73 			Debug.Assert((VdmCritical & VdmInfo) == ushort.MinValue);
74 			Debug.Assert((VdmCritical | VdmInfo) == ushort.MaxValue);
75 		}
76 
Get(string strName, out T t)77 		private bool Get<T>(string strName, out T t)
78 		{
79 			t = default(T);
80 
81 			if(string.IsNullOrEmpty(strName)) { Debug.Assert(false); return false; }
82 
83 			object o;
84 			if(!m_d.TryGetValue(strName, out o)) return false; // No assert
85 
86 			if(o == null) { Debug.Assert(false); return false; }
87 			if(o.GetType() != typeof(T)) { Debug.Assert(false); return false; }
88 
89 			t = (T)o;
90 			return true;
91 		}
92 
93 		private void SetStruct<T>(string strName, T t)
94 			where T : struct
95 		{
96 			if(string.IsNullOrEmpty(strName)) { Debug.Assert(false); return; }
97 
98 #if DEBUG
99 			T tEx;
GetKeePassLib.Collections.VariantDictionary.__anon1100 			Get<T>(strName, out tEx); // Assert same type
101 #endif
102 
103 			m_d[strName] = t;
104 		}
105 
106 		private void SetRef<T>(string strName, T t)
107 			where T : class
108 		{
109 			if(string.IsNullOrEmpty(strName)) { Debug.Assert(false); return; }
110 			if(t == null) { Debug.Assert(false); return; }
111 
112 #if DEBUG
113 			T tEx;
Get(strName, out tEx)114 			Get<T>(strName, out tEx); // Assert same type
115 #endif
116 
117 			m_d[strName] = t;
118 		}
119 
Remove(string strName)120 		public bool Remove(string strName)
121 		{
122 			if(string.IsNullOrEmpty(strName)) { Debug.Assert(false); return false; }
123 
124 			return m_d.Remove(strName);
125 		}
126 
CopyTo(VariantDictionary d)127 		public void CopyTo(VariantDictionary d)
128 		{
129 			if(d == null) { Debug.Assert(false); return; }
130 
131 			// Do not clear the target
132 			foreach(KeyValuePair<string, object> kvp in m_d)
133 			{
134 				d.m_d[kvp.Key] = kvp.Value;
135 			}
136 		}
137 
GetTypeOf(string strName)138 		public Type GetTypeOf(string strName)
139 		{
140 			if(string.IsNullOrEmpty(strName)) { Debug.Assert(false); return null; }
141 
142 			object o;
143 			m_d.TryGetValue(strName, out o);
144 			if(o == null) return null; // No assert
145 
146 			return o.GetType();
147 		}
148 
GetUInt32(string strName, uint uDefault)149 		public uint GetUInt32(string strName, uint uDefault)
150 		{
151 			uint u;
152 			if(Get<uint>(strName, out u)) return u;
153 			return uDefault;
154 		}
155 
SetUInt32(string strName, uint uValue)156 		public void SetUInt32(string strName, uint uValue)
157 		{
158 			SetStruct<uint>(strName, uValue);
159 		}
160 
GetUInt64(string strName, ulong uDefault)161 		public ulong GetUInt64(string strName, ulong uDefault)
162 		{
163 			ulong u;
164 			if(Get<ulong>(strName, out u)) return u;
165 			return uDefault;
166 		}
167 
SetUInt64(string strName, ulong uValue)168 		public void SetUInt64(string strName, ulong uValue)
169 		{
170 			SetStruct<ulong>(strName, uValue);
171 		}
172 
GetBool(string strName, bool bDefault)173 		public bool GetBool(string strName, bool bDefault)
174 		{
175 			bool b;
176 			if(Get<bool>(strName, out b)) return b;
177 			return bDefault;
178 		}
179 
SetBool(string strName, bool bValue)180 		public void SetBool(string strName, bool bValue)
181 		{
182 			SetStruct<bool>(strName, bValue);
183 		}
184 
GetInt32(string strName, int iDefault)185 		public int GetInt32(string strName, int iDefault)
186 		{
187 			int i;
188 			if(Get<int>(strName, out i)) return i;
189 			return iDefault;
190 		}
191 
SetInt32(string strName, int iValue)192 		public void SetInt32(string strName, int iValue)
193 		{
194 			SetStruct<int>(strName, iValue);
195 		}
196 
GetInt64(string strName, long lDefault)197 		public long GetInt64(string strName, long lDefault)
198 		{
199 			long l;
200 			if(Get<long>(strName, out l)) return l;
201 			return lDefault;
202 		}
203 
SetInt64(string strName, long lValue)204 		public void SetInt64(string strName, long lValue)
205 		{
206 			SetStruct<long>(strName, lValue);
207 		}
208 
GetString(string strName)209 		public string GetString(string strName)
210 		{
211 			string str;
212 			Get<string>(strName, out str);
213 			return str;
214 		}
215 
SetString(string strName, string strValue)216 		public void SetString(string strName, string strValue)
217 		{
218 			SetRef<string>(strName, strValue);
219 		}
220 
GetByteArray(string strName)221 		public byte[] GetByteArray(string strName)
222 		{
223 			byte[] pb;
224 			Get<byte[]>(strName, out pb);
225 			return pb;
226 		}
227 
SetByteArray(string strName, byte[] pbValue)228 		public void SetByteArray(string strName, byte[] pbValue)
229 		{
230 			SetRef<byte[]>(strName, pbValue);
231 		}
232 
233 		/// <summary>
234 		/// Create a deep copy.
235 		/// </summary>
Clone()236 		public virtual object Clone()
237 		{
238 			VariantDictionary vdNew = new VariantDictionary();
239 
240 			foreach(KeyValuePair<string, object> kvp in m_d)
241 			{
242 				object o = kvp.Value;
243 				if(o == null) { Debug.Assert(false); continue; }
244 
245 				Type t = o.GetType();
246 				if(t == typeof(byte[]))
247 				{
248 					byte[] p = (byte[])o;
249 					byte[] pNew = new byte[p.Length];
250 					if(p.Length > 0) Array.Copy(p, pNew, p.Length);
251 
252 					o = pNew;
253 				}
254 
255 				vdNew.m_d[kvp.Key] = o;
256 			}
257 
258 			return vdNew;
259 		}
260 
Serialize(VariantDictionary p)261 		public static byte[] Serialize(VariantDictionary p)
262 		{
263 			if(p == null) { Debug.Assert(false); return null; }
264 
265 			byte[] pbRet;
266 			using(MemoryStream ms = new MemoryStream())
267 			{
268 				MemUtil.Write(ms, MemUtil.UInt16ToBytes(VdVersion));
269 
270 				foreach(KeyValuePair<string, object> kvp in p.m_d)
271 				{
272 					string strName = kvp.Key;
273 					if(string.IsNullOrEmpty(strName)) { Debug.Assert(false); continue; }
274 					byte[] pbName = StrUtil.Utf8.GetBytes(strName);
275 
276 					object o = kvp.Value;
277 					if(o == null) { Debug.Assert(false); continue; }
278 
279 					Type t = o.GetType();
280 					VdType vt = VdType.None;
281 					byte[] pbValue = null;
282 					if(t == typeof(uint))
283 					{
284 						vt = VdType.UInt32;
285 						pbValue = MemUtil.UInt32ToBytes((uint)o);
286 					}
287 					else if(t == typeof(ulong))
288 					{
289 						vt = VdType.UInt64;
290 						pbValue = MemUtil.UInt64ToBytes((ulong)o);
291 					}
292 					else if(t == typeof(bool))
293 					{
294 						vt = VdType.Bool;
295 						pbValue = new byte[1];
296 						pbValue[0] = ((bool)o ? (byte)1 : (byte)0);
297 					}
298 					else if(t == typeof(int))
299 					{
300 						vt = VdType.Int32;
301 						pbValue = MemUtil.Int32ToBytes((int)o);
302 					}
303 					else if(t == typeof(long))
304 					{
305 						vt = VdType.Int64;
306 						pbValue = MemUtil.Int64ToBytes((long)o);
307 					}
308 					else if(t == typeof(string))
309 					{
310 						vt = VdType.String;
311 						pbValue = StrUtil.Utf8.GetBytes((string)o);
312 					}
313 					else if(t == typeof(byte[]))
314 					{
315 						vt = VdType.ByteArray;
316 						pbValue = (byte[])o;
317 					}
318 					else { Debug.Assert(false); continue; } // Unknown type
319 
320 					ms.WriteByte((byte)vt);
321 					MemUtil.Write(ms, MemUtil.Int32ToBytes(pbName.Length));
322 					MemUtil.Write(ms, pbName);
323 					MemUtil.Write(ms, MemUtil.Int32ToBytes(pbValue.Length));
324 					MemUtil.Write(ms, pbValue);
325 				}
326 
327 				ms.WriteByte((byte)VdType.None);
328 				pbRet = ms.ToArray();
329 			}
330 
331 			return pbRet;
332 		}
333 
Deserialize(byte[] pb)334 		public static VariantDictionary Deserialize(byte[] pb)
335 		{
336 			if(pb == null) { Debug.Assert(false); return null; }
337 
338 			VariantDictionary d = new VariantDictionary();
339 			using(MemoryStream ms = new MemoryStream(pb, false))
340 			{
341 				ushort uVersion = MemUtil.BytesToUInt16(MemUtil.Read(ms, 2));
342 				if((uVersion & VdmCritical) > (VdVersion & VdmCritical))
343 					throw new FormatException(KLRes.FileNewVerReq);
344 
345 				while(true)
346 				{
347 					int iType = ms.ReadByte();
348 					if(iType < 0) throw new EndOfStreamException(KLRes.FileCorrupted);
349 					byte btType = (byte)iType;
350 					if(btType == (byte)VdType.None) break;
351 
352 					int cbName = MemUtil.BytesToInt32(MemUtil.Read(ms, 4));
353 					byte[] pbName = MemUtil.Read(ms, cbName);
354 					if(pbName.Length != cbName)
355 						throw new EndOfStreamException(KLRes.FileCorrupted);
356 					string strName = StrUtil.Utf8.GetString(pbName);
357 
358 					int cbValue = MemUtil.BytesToInt32(MemUtil.Read(ms, 4));
359 					byte[] pbValue = MemUtil.Read(ms, cbValue);
360 					if(pbValue.Length != cbValue)
361 						throw new EndOfStreamException(KLRes.FileCorrupted);
362 
363 					switch(btType)
364 					{
365 						case (byte)VdType.UInt32:
366 							if(cbValue == 4)
367 								d.SetUInt32(strName, MemUtil.BytesToUInt32(pbValue));
368 							else { Debug.Assert(false); }
369 							break;
370 
371 						case (byte)VdType.UInt64:
372 							if(cbValue == 8)
373 								d.SetUInt64(strName, MemUtil.BytesToUInt64(pbValue));
374 							else { Debug.Assert(false); }
375 							break;
376 
377 						case (byte)VdType.Bool:
378 							if(cbValue == 1)
379 								d.SetBool(strName, (pbValue[0] != 0));
380 							else { Debug.Assert(false); }
381 							break;
382 
383 						case (byte)VdType.Int32:
384 							if(cbValue == 4)
385 								d.SetInt32(strName, MemUtil.BytesToInt32(pbValue));
386 							else { Debug.Assert(false); }
387 							break;
388 
389 						case (byte)VdType.Int64:
390 							if(cbValue == 8)
391 								d.SetInt64(strName, MemUtil.BytesToInt64(pbValue));
392 							else { Debug.Assert(false); }
393 							break;
394 
395 						case (byte)VdType.String:
396 							d.SetString(strName, StrUtil.Utf8.GetString(pbValue));
397 							break;
398 
399 						case (byte)VdType.ByteArray:
400 							d.SetByteArray(strName, pbValue);
401 							break;
402 
403 						default:
404 							Debug.Assert(false); // Unknown type
405 							break;
406 					}
407 				}
408 
409 				Debug.Assert(ms.ReadByte() < 0);
410 			}
411 
412 			return d;
413 		}
414 	}
415 }
416