1 /*
2   KeePass Password Safe - The Open-Source Password Manager
3   Copyright (C) 2003-2021 Dominik Reichl <dominik.reichl@t-online.de>
4 
5   This program is free software; you can redistribute it and/or modify
6   it under the terms of the GNU General Public License as published by
7   the Free Software Foundation; either version 2 of the License, or
8   (at your option) any later version.
9 
10   This program is distributed in the hope that it will be useful,
11   but WITHOUT ANY WARRANTY; without even the implied warranty of
12   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13   GNU General Public License for more details.
14 
15   You should have received a copy of the GNU General Public License
16   along with this program; if not, write to the Free Software
17   Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
18 */
19 
20 using System;
21 using System.Collections.Generic;
22 using System.Text;
23 using System.Windows.Forms;
24 using System.Diagnostics;
25 
26 using KeePass.App.Configuration;
27 
28 using KeePassLib;
29 using KeePassLib.Delegates;
30 using KeePassLib.Serialization;
31 
32 namespace KeePass.UI
33 {
34 	public sealed class DocumentManagerEx
35 	{
36 		private List<PwDocument> m_vDocs = new List<PwDocument>();
37 		private PwDocument m_dsActive = new PwDocument();
38 
39 		public event EventHandler ActiveDocumentSelected;
40 
DocumentManagerEx()41 		public DocumentManagerEx()
42 		{
43 			Debug.Assert((m_vDocs != null) && (m_dsActive != null));
44 			m_vDocs.Add(m_dsActive);
45 		}
46 
47 		public PwDocument ActiveDocument
48 		{
49 			get { return m_dsActive; }
50 			set
51 			{
52 				if(value == null) { Debug.Assert(false); throw new ArgumentNullException("value"); }
53 
54 				for(int i = 0; i < m_vDocs.Count; ++i)
55 				{
56 					if(m_vDocs[i] == value)
57 					{
58 						m_dsActive = value;
59 
60 						NotifyActiveDocumentSelected();
61 						return;
62 					}
63 				}
64 
65 				throw new ArgumentException();
66 			}
67 		}
68 
69 		public PwDatabase ActiveDatabase
70 		{
71 			get { return m_dsActive.Database; }
72 			set
73 			{
74 				if(value == null) { Debug.Assert(false); throw new ArgumentNullException("value"); }
75 
76 				for(int i = 0; i < m_vDocs.Count; ++i)
77 				{
78 					if(m_vDocs[i].Database == value)
79 					{
80 						m_dsActive = m_vDocs[i];
81 
82 						NotifyActiveDocumentSelected();
83 						return;
84 					}
85 				}
86 
87 				throw new ArgumentException();
88 			}
89 		}
90 
91 		public uint DocumentCount
92 		{
93 			get { return (uint)m_vDocs.Count; }
94 		}
95 
96 		public List<PwDocument> Documents
97 		{
98 			get { return m_vDocs; }
99 		}
100 
CreateNewDocument(bool bMakeActive)101 		public PwDocument CreateNewDocument(bool bMakeActive)
102 		{
103 			PwDocument ds = new PwDocument();
104 
105 			if((m_vDocs.Count == 1) && (!m_vDocs[0].Database.IsOpen) &&
106 				(m_vDocs[0].LockedIoc.Path.Length == 0))
107 			{
108 				m_vDocs.RemoveAt(0);
109 				m_dsActive = ds;
110 			}
111 
112 			m_vDocs.Add(ds);
113 			if(bMakeActive) m_dsActive = ds;
114 
115 			NotifyActiveDocumentSelected();
116 			return ds;
117 		}
118 
CloseDatabase(PwDatabase pwDatabase)119 		public void CloseDatabase(PwDatabase pwDatabase)
120 		{
121 			int iFoundPos = -1;
122 			for(int i = 0; i < m_vDocs.Count; ++i)
123 			{
124 				if(m_vDocs[i].Database == pwDatabase)
125 				{
126 					iFoundPos = i;
127 					break;
128 				}
129 			}
130 			if(iFoundPos < 0) { Debug.Assert(false); return; }
131 
132 			bool bClosingActive = (m_vDocs[iFoundPos] == m_dsActive);
133 
134 			m_vDocs.RemoveAt(iFoundPos);
135 			if(m_vDocs.Count == 0)
136 				m_vDocs.Add(new PwDocument());
137 
138 			if(bClosingActive)
139 			{
140 				int iNewActive = Math.Min(iFoundPos, m_vDocs.Count - 1);
141 				m_dsActive = m_vDocs[iNewActive];
142 				NotifyActiveDocumentSelected();
143 			}
144 			else { Debug.Assert(m_vDocs.Contains(m_dsActive)); }
145 		}
146 
GetOpenDatabases()147 		public List<PwDatabase> GetOpenDatabases()
148 		{
149 			List<PwDatabase> list = new List<PwDatabase>();
150 
151 			foreach(PwDocument ds in m_vDocs)
152 			{
153 				if(ds.Database.IsOpen)
154 					list.Add(ds.Database);
155 			}
156 
157 			return list;
158 		}
159 
GetDocuments(int iMoveActive)160 		internal List<PwDocument> GetDocuments(int iMoveActive)
161 		{
162 			List<PwDocument> lDocs = new List<PwDocument>(m_vDocs);
163 
164 			if(iMoveActive != 0)
165 			{
166 				for(int i = 0; i < lDocs.Count; ++i)
167 				{
168 					if(lDocs[i] == m_dsActive)
169 					{
170 						lDocs.RemoveAt(i);
171 						if(iMoveActive < 0) lDocs.Insert(0, m_dsActive);
172 						else lDocs.Add(m_dsActive);
173 						break;
174 					}
175 				}
176 			}
177 
178 			return lDocs;
179 		}
180 
NotifyActiveDocumentSelected()181 		private void NotifyActiveDocumentSelected()
182 		{
183 			RememberActiveDocument();
184 
185 			if(this.ActiveDocumentSelected != null)
186 				this.ActiveDocumentSelected(null, EventArgs.Empty);
187 		}
188 
RememberActiveDocument()189 		internal void RememberActiveDocument()
190 		{
191 			if(m_dsActive == null) { Debug.Assert(false); return; }
192 
193 			if(m_dsActive.LockedIoc != null)
194 				SetLastUsedFile(m_dsActive.LockedIoc);
195 			if(m_dsActive.Database != null)
196 				SetLastUsedFile(m_dsActive.Database.IOConnectionInfo);
197 		}
198 
SetLastUsedFile(IOConnectionInfo ioc)199 		private static void SetLastUsedFile(IOConnectionInfo ioc)
200 		{
201 			if(ioc == null) { Debug.Assert(false); return; }
202 
203 			AceApplication aceApp = Program.Config.Application;
204 			if(aceApp.Start.OpenLastFile)
205 			{
206 				if(!string.IsNullOrEmpty(ioc.Path))
207 					aceApp.LastUsedFile = ioc.CloneDeep();
208 			}
209 			else aceApp.LastUsedFile = new IOConnectionInfo();
210 		}
211 
FindDocument(PwDatabase pwDatabase)212 		public PwDocument FindDocument(PwDatabase pwDatabase)
213 		{
214 			if(pwDatabase == null) throw new ArgumentNullException("pwDatabase");
215 
216 			foreach(PwDocument ds in m_vDocs)
217 			{
218 				if(ds.Database == pwDatabase) return ds;
219 			}
220 
221 			return null;
222 		}
223 
224 		/// <summary>
225 		/// Search for an entry in all opened databases. The
226 		/// entry is identified by its reference (not its UUID).
227 		/// </summary>
228 		/// <param name="peObj">Entry to search for.</param>
229 		/// <returns>Database containing the entry.</returns>
FindContainerOf(PwEntry peObj)230 		public PwDatabase FindContainerOf(PwEntry peObj)
231 		{
232 			if(peObj == null) return null; // No assert
233 
234 			PwGroup pg = peObj.ParentGroup;
235 			if(pg != null)
236 			{
237 				while(pg.ParentGroup != null) { pg = pg.ParentGroup; }
238 
239 				foreach(PwDocument ds in m_vDocs)
240 				{
241 					PwDatabase pd = ds.Database;
242 					if((pd == null) || !pd.IsOpen) continue;
243 
244 					if(object.ReferenceEquals(pd.RootGroup, pg))
245 						return pd;
246 				}
247 
248 				Debug.Assert(false);
249 			}
250 
251 			return SlowFindContainerOf(peObj);
252 		}
253 
SlowFindContainerOf(PwEntry peObj)254 		private PwDatabase SlowFindContainerOf(PwEntry peObj)
255 		{
256 			PwDatabase pdRet = null;
257 			foreach(PwDocument ds in m_vDocs)
258 			{
259 				PwDatabase pd = ds.Database;
260 				if((pd == null) || !pd.IsOpen) continue;
261 
262 				EntryHandler eh = delegate(PwEntry pe)
263 				{
264 					if(object.ReferenceEquals(pe, peObj))
265 					{
266 						pdRet = pd;
267 						return false; // Stop traversal
268 					}
269 
270 					return true;
271 				};
272 
273 				pd.RootGroup.TraverseTree(TraversalMethod.PreOrder, null, eh);
274 				if(pdRet != null) return pdRet;
275 			}
276 
277 			return null;
278 		}
279 
SafeFindContainerOf(PwEntry peObj)280 		public PwDatabase SafeFindContainerOf(PwEntry peObj)
281 		{
282 			// peObj may be null
283 			return (FindContainerOf(peObj) ?? m_dsActive.Database);
284 		}
285 	}
286 
287 	public sealed class PwDocument
288 	{
289 		private PwDatabase m_pwDb = new PwDatabase();
290 		private IOConnectionInfo m_ioLockedIoc = new IOConnectionInfo();
291 
292 		public PwDatabase Database
293 		{
294 			get { return m_pwDb; }
295 		}
296 
297 		public IOConnectionInfo LockedIoc
298 		{
299 			get { return m_ioLockedIoc; }
300 			set
301 			{
302 				if(value == null) { Debug.Assert(false); throw new ArgumentNullException("value"); }
303 				m_ioLockedIoc = value;
304 			}
305 		}
306 	}
307 }
308