1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
2 /* vim: set ts=8 sts=2 et sw=2 tw=80: */
3 /* This Source Code Form is subject to the terms of the Mozilla Public
4  * License, v. 2.0. If a copy of the MPL was not distributed with this
5  * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
6 
7 #include "Module.h"
8 
9 #include <stdlib.h>
10 
11 #include <ktmw32.h>
12 #include <memory.h>
13 #include <rpc.h>
14 
15 #include "mozilla/ArrayUtils.h"
16 #include "mozilla/Assertions.h"
17 #include "mozilla/mscom/Utils.h"
18 #include "mozilla/Range.h"
19 #include "nsWindowsHelpers.h"
20 
21 template <size_t N>
LiteralToRange(const wchar_t (& aArg)[N])22 static const mozilla::Range<const wchar_t> LiteralToRange(
23     const wchar_t (&aArg)[N]) {
24   return mozilla::Range(aArg, N);
25 }
26 
27 namespace mozilla {
28 namespace mscom {
29 
// Out-of-line definition of the static Module::sRefCount counter, which
// starts at zero.
// NOTE(review): presumably counts outstanding references that gate
// Module::CanUnload; confirm against Module.h.
ULONG Module::sRefCount = 0;
31 
SubkeyNameFromClassType(const Module::ClassType aClassType)32 static const wchar_t* SubkeyNameFromClassType(
33     const Module::ClassType aClassType) {
34   switch (aClassType) {
35     case Module::ClassType::InprocServer:
36       return L"InprocServer32";
37     case Module::ClassType::InprocHandler:
38       return L"InprocHandler32";
39     default:
40       MOZ_CRASH("Unknown ClassType");
41       return nullptr;
42   }
43 }
44 
ThreadingModelAsString(const Module::ThreadingModel aThreadingModel)45 static const Range<const wchar_t> ThreadingModelAsString(
46     const Module::ThreadingModel aThreadingModel) {
47   switch (aThreadingModel) {
48     case Module::ThreadingModel::DedicatedUiThreadOnly:
49       return LiteralToRange(L"Apartment");
50     case Module::ThreadingModel::MultiThreadedApartmentOnly:
51       return LiteralToRange(L"Free");
52     case Module::ThreadingModel::DedicatedUiThreadXorMultiThreadedApartment:
53       return LiteralToRange(L"Both");
54     case Module::ThreadingModel::AllThreadsAllApartments:
55       return LiteralToRange(L"Neutral");
56     default:
57       MOZ_CRASH("Unknown ThreadingModel");
58       return Range<const wchar_t>();
59   }
60 }
61 
62 /* static */
Register(const CLSID * const * aClsids,const size_t aNumClsids,const ThreadingModel aThreadingModel,const ClassType aClassType,const GUID * const aAppId)63 HRESULT Module::Register(const CLSID* const* aClsids, const size_t aNumClsids,
64                          const ThreadingModel aThreadingModel,
65                          const ClassType aClassType, const GUID* const aAppId) {
66   MOZ_ASSERT(aClsids && aNumClsids);
67   if (!aClsids || !aNumClsids) {
68     return E_INVALIDARG;
69   }
70 
71   const wchar_t* inprocName = SubkeyNameFromClassType(aClassType);
72 
73   const Range<const wchar_t> threadingModelStr =
74       ThreadingModelAsString(aThreadingModel);
75   const DWORD threadingModelStrLenBytesInclNul =
76       threadingModelStr.length() * sizeof(wchar_t);
77 
78   wchar_t strAppId[kGuidRegFormatCharLenInclNul] = {};
79   if (aAppId) {
80     GUIDToString(*aAppId, strAppId);
81   }
82 
83   // Obtain the full path to this DLL
84   HMODULE thisModule;
85   if (!::GetModuleHandleExW(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS |
86                                 GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
87                             reinterpret_cast<LPCWSTR>(&Module::CanUnload),
88                             &thisModule)) {
89     return HRESULT_FROM_WIN32(::GetLastError());
90   }
91 
92   wchar_t absThisModulePath[MAX_PATH + 1] = {};
93   DWORD actualPathLenCharsExclNul = ::GetModuleFileNameW(
94       thisModule, absThisModulePath, ArrayLength(absThisModulePath));
95   if (!actualPathLenCharsExclNul ||
96       actualPathLenCharsExclNul == ArrayLength(absThisModulePath)) {
97     return HRESULT_FROM_WIN32(::GetLastError());
98   }
99   const DWORD actualPathLenBytesInclNul =
100       (actualPathLenCharsExclNul + 1) * sizeof(wchar_t);
101 
102   // Use the name of this DLL as the name of the transaction
103   wchar_t txnName[_MAX_FNAME] = {};
104   if (_wsplitpath_s(absThisModulePath, nullptr, 0, nullptr, 0, txnName,
105                     ArrayLength(txnName), nullptr, 0)) {
106     return HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND);
107   }
108 
109   // Manipulate the registry using a transaction so that any failures are
110   // rolled back.
111   nsAutoHandle txn(::CreateTransaction(
112       nullptr, nullptr, TRANSACTION_DO_NOT_PROMOTE, 0, 0, 0, txnName));
113   if (txn.get() == INVALID_HANDLE_VALUE) {
114     return HRESULT_FROM_WIN32(::GetLastError());
115   }
116 
117   HRESULT hr;
118   LSTATUS status;
119 
120   // A single DLL may serve multiple components. For each CLSID, we register
121   // this DLL as its server and, when an AppId is specified, set up a reference
122   // from the CLSID to the specified AppId.
123   for (size_t idx = 0; idx < aNumClsids; ++idx) {
124     if (!aClsids[idx]) {
125       return E_INVALIDARG;
126     }
127 
128     wchar_t clsidKeyPath[256];
129     hr = BuildClsidPath(*aClsids[idx], clsidKeyPath);
130     if (FAILED(hr)) {
131       return hr;
132     }
133 
134     // Create the CLSID key
135     HKEY rawClsidKey;
136     status = ::RegCreateKeyTransactedW(
137         HKEY_LOCAL_MACHINE, clsidKeyPath, 0, nullptr, REG_OPTION_NON_VOLATILE,
138         KEY_ALL_ACCESS, nullptr, &rawClsidKey, nullptr, txn, nullptr);
139     if (status != ERROR_SUCCESS) {
140       return HRESULT_FROM_WIN32(status);
141     }
142     nsAutoRegKey clsidKey(rawClsidKey);
143 
144     if (aAppId) {
145       // This value associates the registered CLSID with the specified AppID
146       status = ::RegSetValueExW(clsidKey, L"AppID", 0, REG_SZ,
147                                 reinterpret_cast<const BYTE*>(strAppId),
148                                 ArrayLength(strAppId) * sizeof(wchar_t));
149       if (status != ERROR_SUCCESS) {
150         return HRESULT_FROM_WIN32(status);
151       }
152     }
153 
154     HKEY rawInprocKey;
155     status = ::RegCreateKeyTransactedW(
156         clsidKey, inprocName, 0, nullptr, REG_OPTION_NON_VOLATILE,
157         KEY_ALL_ACCESS, nullptr, &rawInprocKey, nullptr, txn, nullptr);
158     if (status != ERROR_SUCCESS) {
159       return HRESULT_FROM_WIN32(status);
160     }
161     nsAutoRegKey inprocKey(rawInprocKey);
162 
163     // Set the component's path to this DLL
164     status = ::RegSetValueExW(inprocKey, nullptr, 0, REG_EXPAND_SZ,
165                               reinterpret_cast<const BYTE*>(absThisModulePath),
166                               actualPathLenBytesInclNul);
167     if (status != ERROR_SUCCESS) {
168       return HRESULT_FROM_WIN32(status);
169     }
170 
171     status = ::RegSetValueExW(
172         inprocKey, L"ThreadingModel", 0, REG_SZ,
173         reinterpret_cast<const BYTE*>(threadingModelStr.begin().get()),
174         threadingModelStrLenBytesInclNul);
175     if (status != ERROR_SUCCESS) {
176       return HRESULT_FROM_WIN32(status);
177     }
178   }
179 
180   if (aAppId) {
181     // When specified, we must also create a key for the AppID.
182     wchar_t appidKeyPath[256];
183     hr = BuildAppidPath(*aAppId, appidKeyPath);
184     if (FAILED(hr)) {
185       return hr;
186     }
187 
188     HKEY rawAppidKey;
189     status = ::RegCreateKeyTransactedW(
190         HKEY_LOCAL_MACHINE, appidKeyPath, 0, nullptr, REG_OPTION_NON_VOLATILE,
191         KEY_ALL_ACCESS, nullptr, &rawAppidKey, nullptr, txn, nullptr);
192     if (status != ERROR_SUCCESS) {
193       return HRESULT_FROM_WIN32(status);
194     }
195     nsAutoRegKey appidKey(rawAppidKey);
196 
197     // Setting DllSurrogate to a null or empty string indicates to Windows that
198     // we want to use the default surrogate (i.e. dllhost.exe) to load our DLL.
199     status =
200         ::RegSetValueExW(appidKey, L"DllSurrogate", 0, REG_SZ,
201                          reinterpret_cast<const BYTE*>(L""), sizeof(wchar_t));
202     if (status != ERROR_SUCCESS) {
203       return HRESULT_FROM_WIN32(status);
204     }
205   }
206 
207   if (!::CommitTransaction(txn)) {
208     return HRESULT_FROM_WIN32(::GetLastError());
209   }
210 
211   return S_OK;
212 }
213 
214 /**
215  * Unfortunately the registry transaction APIs are not as well-developed for
216  * deleting things as they are for creating them. We just use RegDeleteTree
217  * for the implementation of this method.
218  */
Deregister(const CLSID * const * aClsids,const size_t aNumClsids,const GUID * const aAppId)219 HRESULT Module::Deregister(const CLSID* const* aClsids, const size_t aNumClsids,
220                            const GUID* const aAppId) {
221   MOZ_ASSERT(aClsids && aNumClsids);
222   if (!aClsids || !aNumClsids) {
223     return E_INVALIDARG;
224   }
225 
226   HRESULT hr;
227   LSTATUS status;
228 
229   // Delete the key for each CLSID. This will also delete any references to
230   // the AppId.
231   for (size_t idx = 0; idx < aNumClsids; ++idx) {
232     if (!aClsids[idx]) {
233       return E_INVALIDARG;
234     }
235 
236     wchar_t clsidKeyPath[256];
237     hr = BuildClsidPath(*aClsids[idx], clsidKeyPath);
238     if (FAILED(hr)) {
239       return hr;
240     }
241 
242     status = ::RegDeleteTreeW(HKEY_LOCAL_MACHINE, clsidKeyPath);
243     // We allow the deletion to succeed if the key was already gone
244     if (status != ERROR_SUCCESS && status != ERROR_FILE_NOT_FOUND) {
245       return HRESULT_FROM_WIN32(status);
246     }
247   }
248 
249   // Now delete the AppID key, if desired.
250   if (aAppId) {
251     wchar_t appidKeyPath[256];
252     hr = BuildAppidPath(*aAppId, appidKeyPath);
253     if (FAILED(hr)) {
254       return hr;
255     }
256 
257     status = ::RegDeleteTreeW(HKEY_LOCAL_MACHINE, appidKeyPath);
258     // We allow the deletion to succeed if the key was already gone
259     if (status != ERROR_SUCCESS && status != ERROR_FILE_NOT_FOUND) {
260       return HRESULT_FROM_WIN32(status);
261     }
262   }
263 
264   return S_OK;
265 }
266 
267 }  // namespace mscom
268 }  // namespace mozilla
269