1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
2 /* vim: set ts=8 sts=2 et sw=2 tw=80: */
3 /* This Source Code Form is subject to the terms of the Mozilla Public
4 * License, v. 2.0. If a copy of the MPL was not distributed with this
5 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
6
7 #define INITGUID
8
9 #include "mozilla/mscom/MainThreadHandoff.h"
10
11 #include <utility>
12
13 #include "mozilla/Assertions.h"
14 #include "mozilla/Attributes.h"
15 #include "mozilla/DebugOnly.h"
16 #include "mozilla/ThreadLocal.h"
17 #include "mozilla/TimeStamp.h"
18 #include "mozilla/Unused.h"
19 #include "mozilla/mscom/AgileReference.h"
20 #include "mozilla/mscom/InterceptorLog.h"
21 #include "mozilla/mscom/Registration.h"
22 #include "mozilla/mscom/Utils.h"
23 #include "nsProxyRelease.h"
24 #include "nsThreadUtils.h"
25
26 using mozilla::DebugOnly;
27 using mozilla::Unused;
28 using mozilla::mscom::AgileReference;
29
30 namespace {
31
32 class MOZ_NON_TEMPORARY_CLASS InParamWalker : private ICallFrameWalker {
33 public:
InParamWalker()34 InParamWalker() : mPreHandoff(true) {}
35
SetHandoffDone()36 void SetHandoffDone() {
37 mPreHandoff = false;
38 mAgileRefsItr = mAgileRefs.begin();
39 }
40
Walk(ICallFrame * aFrame)41 HRESULT Walk(ICallFrame* aFrame) {
42 MOZ_ASSERT(aFrame);
43 if (!aFrame) {
44 return E_INVALIDARG;
45 }
46
47 return aFrame->WalkFrame(CALLFRAME_WALK_IN, this);
48 }
49
50 private:
51 // IUnknown
QueryInterface(REFIID aIid,void ** aOutInterface)52 STDMETHODIMP QueryInterface(REFIID aIid, void** aOutInterface) override {
53 if (!aOutInterface) {
54 return E_INVALIDARG;
55 }
56 *aOutInterface = nullptr;
57
58 if (aIid == IID_IUnknown || aIid == IID_ICallFrameWalker) {
59 *aOutInterface = static_cast<ICallFrameWalker*>(this);
60 return S_OK;
61 }
62
63 return E_NOINTERFACE;
64 }
65
AddRef()66 STDMETHODIMP_(ULONG) AddRef() override { return 2; }
67
Release()68 STDMETHODIMP_(ULONG) Release() override { return 1; }
69
70 // ICallFrameWalker
OnWalkInterface(REFIID aIid,PVOID * aInterface,BOOL aIn,BOOL aOut)71 STDMETHODIMP OnWalkInterface(REFIID aIid, PVOID* aInterface, BOOL aIn,
72 BOOL aOut) override {
73 MOZ_ASSERT(aIn);
74 if (!aIn) {
75 return E_UNEXPECTED;
76 }
77
78 IUnknown* origInterface = static_cast<IUnknown*>(*aInterface);
79 if (!origInterface) {
80 // Nothing to do
81 return S_OK;
82 }
83
84 if (mPreHandoff) {
85 mAgileRefs.AppendElement(AgileReference(aIid, origInterface));
86 return S_OK;
87 }
88
89 MOZ_ASSERT(mAgileRefsItr != mAgileRefs.end());
90 if (mAgileRefsItr == mAgileRefs.end()) {
91 return E_UNEXPECTED;
92 }
93
94 HRESULT hr = mAgileRefsItr->Resolve(aIid, aInterface);
95 MOZ_ASSERT(SUCCEEDED(hr));
96 if (SUCCEEDED(hr)) {
97 ++mAgileRefsItr;
98 }
99
100 return hr;
101 }
102
103 InParamWalker(const InParamWalker&) = delete;
104 InParamWalker(InParamWalker&&) = delete;
105 InParamWalker& operator=(const InParamWalker&) = delete;
106 InParamWalker& operator=(InParamWalker&&) = delete;
107
108 private:
109 bool mPreHandoff;
110 AutoTArray<AgileReference, 1> mAgileRefs;
111 nsTArray<AgileReference>::iterator mAgileRefsItr;
112 };
113
114 class HandoffRunnable : public mozilla::Runnable {
115 public:
HandoffRunnable(ICallFrame * aCallFrame,IUnknown * aTargetInterface)116 explicit HandoffRunnable(ICallFrame* aCallFrame, IUnknown* aTargetInterface)
117 : Runnable("HandoffRunnable"),
118 mCallFrame(aCallFrame),
119 mTargetInterface(aTargetInterface),
120 mResult(E_UNEXPECTED) {
121 DebugOnly<HRESULT> hr = mInParamWalker.Walk(aCallFrame);
122 MOZ_ASSERT(SUCCEEDED(hr));
123 }
124
Run()125 NS_IMETHOD Run() override {
126 mInParamWalker.SetHandoffDone();
127 // We declare hr a DebugOnly because if mInParamWalker.Walk() fails, then
128 // mCallFrame->Invoke will fail anyway.
129 DebugOnly<HRESULT> hr = mInParamWalker.Walk(mCallFrame);
130 MOZ_ASSERT(SUCCEEDED(hr));
131 mResult = mCallFrame->Invoke(mTargetInterface);
132 return NS_OK;
133 }
134
GetResult() const135 HRESULT GetResult() const { return mResult; }
136
137 private:
138 ICallFrame* mCallFrame;
139 InParamWalker mInParamWalker;
140 IUnknown* mTargetInterface;
141 HRESULT mResult;
142 };
143
144 class MOZ_RAII SavedCallFrame final {
145 public:
SavedCallFrame(mozilla::NotNull<ICallFrame * > aFrame)146 explicit SavedCallFrame(mozilla::NotNull<ICallFrame*> aFrame)
147 : mCallFrame(aFrame) {
148 static const bool sIsInit = tlsFrame.init();
149 MOZ_ASSERT(sIsInit);
150 MOZ_ASSERT(!tlsFrame.get());
151 tlsFrame.set(this);
152 Unused << sIsInit;
153 }
154
~SavedCallFrame()155 ~SavedCallFrame() {
156 MOZ_ASSERT(tlsFrame.get());
157 tlsFrame.set(nullptr);
158 }
159
GetIidAndMethod(mozilla::NotNull<IID * > aIid,mozilla::NotNull<ULONG * > aMethod) const160 HRESULT GetIidAndMethod(mozilla::NotNull<IID*> aIid,
161 mozilla::NotNull<ULONG*> aMethod) const {
162 return mCallFrame->GetIIDAndMethod(aIid, aMethod);
163 }
164
Get()165 static const SavedCallFrame& Get() {
166 SavedCallFrame* saved = tlsFrame.get();
167 MOZ_ASSERT(saved);
168
169 return *saved;
170 }
171
172 SavedCallFrame(const SavedCallFrame&) = delete;
173 SavedCallFrame(SavedCallFrame&&) = delete;
174 SavedCallFrame& operator=(const SavedCallFrame&) = delete;
175 SavedCallFrame& operator=(SavedCallFrame&&) = delete;
176
177 private:
178 ICallFrame* mCallFrame;
179
180 private:
181 static MOZ_THREAD_LOCAL(SavedCallFrame*) tlsFrame;
182 };
183
184 MOZ_THREAD_LOCAL(SavedCallFrame*) SavedCallFrame::tlsFrame;
185
186 class MOZ_RAII LogEvent final {
187 public:
LogEvent()188 LogEvent() : mCallStart(mozilla::TimeStamp::Now()) {}
189
~LogEvent()190 ~LogEvent() {
191 if (mCapturedFrame.IsEmpty()) {
192 return;
193 }
194
195 mozilla::TimeStamp callEnd(mozilla::TimeStamp::Now());
196 mozilla::TimeDuration totalTime(callEnd - mCallStart);
197 mozilla::TimeDuration overhead(totalTime - mGeckoDuration -
198 mCaptureDuration);
199
200 mozilla::mscom::InterceptorLog::Event(mCapturedFrame, overhead,
201 mGeckoDuration);
202 }
203
CaptureFrame(ICallFrame * aFrame,IUnknown * aTarget,const mozilla::TimeDuration & aGeckoDuration)204 void CaptureFrame(ICallFrame* aFrame, IUnknown* aTarget,
205 const mozilla::TimeDuration& aGeckoDuration) {
206 mozilla::TimeStamp captureStart(mozilla::TimeStamp::Now());
207
208 mozilla::mscom::InterceptorLog::CaptureFrame(aFrame, aTarget,
209 mCapturedFrame);
210 mGeckoDuration = aGeckoDuration;
211
212 mozilla::TimeStamp captureEnd(mozilla::TimeStamp::Now());
213
214 // Make sure that the time we spent in CaptureFrame isn't charged against
215 // overall overhead
216 mCaptureDuration = captureEnd - captureStart;
217 }
218
219 LogEvent(const LogEvent&) = delete;
220 LogEvent(LogEvent&&) = delete;
221 LogEvent& operator=(const LogEvent&) = delete;
222 LogEvent& operator=(LogEvent&&) = delete;
223
224 private:
225 mozilla::TimeStamp mCallStart;
226 mozilla::TimeDuration mGeckoDuration;
227 mozilla::TimeDuration mCaptureDuration;
228 nsAutoCString mCapturedFrame;
229 };
230
231 } // anonymous namespace
232
233 namespace mozilla {
234 namespace mscom {
235
236 /* static */
Create(IHandlerProvider * aHandlerProvider,IInterceptorSink ** aOutput)237 HRESULT MainThreadHandoff::Create(IHandlerProvider* aHandlerProvider,
238 IInterceptorSink** aOutput) {
239 RefPtr<MainThreadHandoff> handoff(new MainThreadHandoff(aHandlerProvider));
240 return handoff->QueryInterface(IID_IInterceptorSink, (void**)aOutput);
241 }
242
MainThreadHandoff(IHandlerProvider * aHandlerProvider)243 MainThreadHandoff::MainThreadHandoff(IHandlerProvider* aHandlerProvider)
244 : mRefCnt(0), mHandlerProvider(aHandlerProvider) {}
245
~MainThreadHandoff()246 MainThreadHandoff::~MainThreadHandoff() { MOZ_ASSERT(NS_IsMainThread()); }
247
248 HRESULT
QueryInterface(REFIID riid,void ** ppv)249 MainThreadHandoff::QueryInterface(REFIID riid, void** ppv) {
250 IUnknown* punk = nullptr;
251 if (!ppv) {
252 return E_INVALIDARG;
253 }
254
255 if (riid == IID_IUnknown || riid == IID_ICallFrameEvents ||
256 riid == IID_IInterceptorSink || riid == IID_IMainThreadHandoff) {
257 punk = static_cast<IMainThreadHandoff*>(this);
258 } else if (riid == IID_ICallFrameWalker) {
259 punk = static_cast<ICallFrameWalker*>(this);
260 }
261
262 *ppv = punk;
263 if (!punk) {
264 return E_NOINTERFACE;
265 }
266
267 punk->AddRef();
268 return S_OK;
269 }
270
271 ULONG
AddRef()272 MainThreadHandoff::AddRef() {
273 return (ULONG)InterlockedIncrement((LONG*)&mRefCnt);
274 }
275
276 ULONG
Release()277 MainThreadHandoff::Release() {
278 ULONG newRefCnt = (ULONG)InterlockedDecrement((LONG*)&mRefCnt);
279 if (newRefCnt == 0) {
280 // It is possible for the last Release() call to happen off-main-thread.
281 // If so, we need to dispatch an event to delete ourselves.
282 if (NS_IsMainThread()) {
283 delete this;
284 } else {
285 // We need to delete this object on the main thread, but we aren't on the
286 // main thread right now, so we send a reference to ourselves to the main
287 // thread to be re-released there.
288 RefPtr<MainThreadHandoff> self = this;
289 NS_ReleaseOnMainThread("MainThreadHandoff", self.forget());
290 }
291 }
292 return newRefCnt;
293 }
294
295 HRESULT
FixIServiceProvider(ICallFrame * aFrame)296 MainThreadHandoff::FixIServiceProvider(ICallFrame* aFrame) {
297 MOZ_ASSERT(aFrame);
298
299 CALLFRAMEPARAMINFO iidOutParamInfo;
300 HRESULT hr = aFrame->GetParamInfo(1, &iidOutParamInfo);
301 if (FAILED(hr)) {
302 return hr;
303 }
304
305 VARIANT varIfaceOut;
306 hr = aFrame->GetParam(2, &varIfaceOut);
307 if (FAILED(hr)) {
308 return hr;
309 }
310
311 MOZ_ASSERT(varIfaceOut.vt == (VT_UNKNOWN | VT_BYREF));
312 if (varIfaceOut.vt != (VT_UNKNOWN | VT_BYREF)) {
313 return DISP_E_BADVARTYPE;
314 }
315
316 IID** iidOutParam =
317 reinterpret_cast<IID**>(static_cast<BYTE*>(aFrame->GetStackLocation()) +
318 iidOutParamInfo.stackOffset);
319
320 return OnWalkInterface(**iidOutParam,
321 reinterpret_cast<void**>(varIfaceOut.ppunkVal), FALSE,
322 TRUE);
323 }
324
325 HRESULT
OnCall(ICallFrame * aFrame)326 MainThreadHandoff::OnCall(ICallFrame* aFrame) {
327 LogEvent logEvent;
328
329 // (1) Get info about the method call
330 HRESULT hr;
331 IID iid;
332 ULONG method;
333 hr = aFrame->GetIIDAndMethod(&iid, &method);
334 if (FAILED(hr)) {
335 return hr;
336 }
337
338 RefPtr<IInterceptor> interceptor;
339 hr = mInterceptor->Resolve(IID_IInterceptor,
340 (void**)getter_AddRefs(interceptor));
341 if (FAILED(hr)) {
342 return hr;
343 }
344
345 InterceptorTargetPtr<IUnknown> targetInterface;
346 hr = interceptor->GetTargetForIID(iid, targetInterface);
347 if (FAILED(hr)) {
348 return hr;
349 }
350
351 // (2) Execute the method call synchronously on the main thread
352 RefPtr<HandoffRunnable> handoffInfo(
353 new HandoffRunnable(aFrame, targetInterface.get()));
354 MainThreadInvoker invoker;
355 if (!invoker.Invoke(do_AddRef(handoffInfo))) {
356 MOZ_ASSERT(false);
357 return E_UNEXPECTED;
358 }
359 hr = handoffInfo->GetResult();
360 MOZ_ASSERT(SUCCEEDED(hr));
361 if (FAILED(hr)) {
362 return hr;
363 }
364
365 // (3) Capture *before* wrapping outputs so that the log will contain pointers
366 // to the true target interface, not the wrapped ones.
367 logEvent.CaptureFrame(aFrame, targetInterface.get(), invoker.GetDuration());
368
369 // (4) Scan the function call for outparams that contain interface pointers.
370 // Those will need to be wrapped with MainThreadHandoff so that they too will
371 // be exeuted on the main thread.
372
373 hr = aFrame->GetReturnValue();
374 if (FAILED(hr)) {
375 // If the call resulted in an error then there's not going to be anything
376 // that needs to be wrapped.
377 return S_OK;
378 }
379
380 if (iid == IID_IServiceProvider) {
381 // The only possible method index for IID_IServiceProvider is for
382 // QueryService at index 3; its other methods are inherited from IUnknown
383 // and are not processed here.
384 MOZ_ASSERT(method == 3);
385 // (5) If our interface is IServiceProvider, we need to manually ensure
386 // that the correct IID is provided for the interface outparam in
387 // IServiceProvider::QueryService.
388 hr = FixIServiceProvider(aFrame);
389 if (FAILED(hr)) {
390 return hr;
391 }
392 } else if (const ArrayData* arrayData = FindArrayData(iid, method)) {
393 // (6) Unfortunately ICallFrame::WalkFrame does not correctly handle array
394 // outparams. Instead, we find out whether anybody has called
395 // mscom::RegisterArrayData to supply array parameter information and use it
396 // if available. This is a terrible hack, but it works for the short term.
397 // In the longer term we want to be able to use COM proxy/stub metadata to
398 // resolve array information for us.
399 hr = FixArrayElements(aFrame, *arrayData);
400 if (FAILED(hr)) {
401 return hr;
402 }
403 } else {
404 SavedCallFrame savedFrame(WrapNotNull(aFrame));
405
406 // (7) Scan the outputs looking for any outparam interfaces that need
407 // wrapping. NB: WalkFrame does not correctly handle array outparams. It
408 // processes the first element of an array but not the remaining elements
409 // (if any).
410 hr = aFrame->WalkFrame(CALLFRAME_WALK_OUT, this);
411 if (FAILED(hr)) {
412 return hr;
413 }
414 }
415
416 return S_OK;
417 }
418
ResolveArrayPtr(VARIANT & aVariant)419 static PVOID ResolveArrayPtr(VARIANT& aVariant) {
420 if (!(aVariant.vt & VT_BYREF)) {
421 return nullptr;
422 }
423 return aVariant.byref;
424 }
425
ResolveInterfacePtr(PVOID aArrayPtr,VARTYPE aVartype,LONG aIndex)426 static PVOID* ResolveInterfacePtr(PVOID aArrayPtr, VARTYPE aVartype,
427 LONG aIndex) {
428 if (aVartype != (VT_VARIANT | VT_BYREF)) {
429 IUnknown** ifaceArray = reinterpret_cast<IUnknown**>(aArrayPtr);
430 return reinterpret_cast<PVOID*>(&ifaceArray[aIndex]);
431 }
432 VARIANT* variantArray = reinterpret_cast<VARIANT*>(aArrayPtr);
433 VARIANT& element = variantArray[aIndex];
434 return &element.byref;
435 }
436
437 HRESULT
FixArrayElements(ICallFrame * aFrame,const ArrayData & aArrayData)438 MainThreadHandoff::FixArrayElements(ICallFrame* aFrame,
439 const ArrayData& aArrayData) {
440 // Extract the array length
441 VARIANT paramVal;
442 VariantInit(¶mVal);
443 HRESULT hr = aFrame->GetParam(aArrayData.mLengthParamIndex, ¶mVal);
444 MOZ_ASSERT(SUCCEEDED(hr) && (paramVal.vt == (VT_I4 | VT_BYREF) ||
445 paramVal.vt == (VT_UI4 | VT_BYREF)));
446 if (FAILED(hr) || (paramVal.vt != (VT_I4 | VT_BYREF) &&
447 paramVal.vt != (VT_UI4 | VT_BYREF))) {
448 return hr;
449 }
450
451 const LONG arrayLength = *(paramVal.plVal);
452 if (!arrayLength) {
453 // Nothing to do
454 return S_OK;
455 }
456
457 // Extract the array parameter
458 VariantInit(¶mVal);
459 PVOID arrayPtr = nullptr;
460 hr = aFrame->GetParam(aArrayData.mArrayParamIndex, ¶mVal);
461 if (hr == DISP_E_BADVARTYPE) {
462 // ICallFrame::GetParam is not able to coerce the param into a VARIANT.
463 // That's ok, we can try to do it ourselves.
464 CALLFRAMEPARAMINFO paramInfo;
465 hr = aFrame->GetParamInfo(aArrayData.mArrayParamIndex, ¶mInfo);
466 if (FAILED(hr)) {
467 return hr;
468 }
469 PVOID stackBase = aFrame->GetStackLocation();
470 if (aArrayData.mFlag == ArrayData::Flag::eAllocatedByServer) {
471 // In order for the server to allocate the array's buffer and store it in
472 // an outparam, the parameter must be typed as Type***. Since the base
473 // of the array is Type*, we must dereference twice.
474 arrayPtr = **reinterpret_cast<PVOID**>(
475 reinterpret_cast<PBYTE>(stackBase) + paramInfo.stackOffset);
476 } else {
477 // We dereference because we need to obtain the value of a parameter
478 // from a stack offset. This pointer is the base of the array.
479 arrayPtr = *reinterpret_cast<PVOID*>(reinterpret_cast<PBYTE>(stackBase) +
480 paramInfo.stackOffset);
481 }
482 } else if (FAILED(hr)) {
483 return hr;
484 } else {
485 arrayPtr = ResolveArrayPtr(paramVal);
486 }
487
488 MOZ_ASSERT(arrayPtr);
489 if (!arrayPtr) {
490 return DISP_E_BADVARTYPE;
491 }
492
493 // We walk the elements of the array and invoke OnWalkInterface to wrap each
494 // one, just as ICallFrame::WalkFrame would do.
495 for (LONG index = 0; index < arrayLength; ++index) {
496 hr = OnWalkInterface(aArrayData.mArrayParamIid,
497 ResolveInterfacePtr(arrayPtr, paramVal.vt, index),
498 FALSE, TRUE);
499 if (FAILED(hr)) {
500 return hr;
501 }
502 }
503 return S_OK;
504 }
505
506 HRESULT
SetInterceptor(IWeakReference * aInterceptor)507 MainThreadHandoff::SetInterceptor(IWeakReference* aInterceptor) {
508 mInterceptor = aInterceptor;
509 return S_OK;
510 }
511
512 HRESULT
GetHandler(NotNull<CLSID * > aHandlerClsid)513 MainThreadHandoff::GetHandler(NotNull<CLSID*> aHandlerClsid) {
514 if (!mHandlerProvider) {
515 return E_NOTIMPL;
516 }
517
518 return mHandlerProvider->GetHandler(aHandlerClsid);
519 }
520
521 HRESULT
GetHandlerPayloadSize(NotNull<IInterceptor * > aInterceptor,NotNull<DWORD * > aOutPayloadSize)522 MainThreadHandoff::GetHandlerPayloadSize(NotNull<IInterceptor*> aInterceptor,
523 NotNull<DWORD*> aOutPayloadSize) {
524 if (!mHandlerProvider) {
525 return E_NOTIMPL;
526 }
527 return mHandlerProvider->GetHandlerPayloadSize(aInterceptor, aOutPayloadSize);
528 }
529
530 HRESULT
WriteHandlerPayload(NotNull<IInterceptor * > aInterceptor,NotNull<IStream * > aStream)531 MainThreadHandoff::WriteHandlerPayload(NotNull<IInterceptor*> aInterceptor,
532 NotNull<IStream*> aStream) {
533 if (!mHandlerProvider) {
534 return E_NOTIMPL;
535 }
536 return mHandlerProvider->WriteHandlerPayload(aInterceptor, aStream);
537 }
538
539 REFIID
MarshalAs(REFIID aIid)540 MainThreadHandoff::MarshalAs(REFIID aIid) {
541 if (!mHandlerProvider) {
542 return aIid;
543 }
544 return mHandlerProvider->MarshalAs(aIid);
545 }
546
547 HRESULT
DisconnectHandlerRemotes()548 MainThreadHandoff::DisconnectHandlerRemotes() {
549 if (!mHandlerProvider) {
550 return E_NOTIMPL;
551 }
552
553 return mHandlerProvider->DisconnectHandlerRemotes();
554 }
555
556 HRESULT
IsInterfaceMaybeSupported(REFIID aIid)557 MainThreadHandoff::IsInterfaceMaybeSupported(REFIID aIid) {
558 if (!mHandlerProvider) {
559 return S_OK;
560 }
561 return mHandlerProvider->IsInterfaceMaybeSupported(aIid);
562 }
563
564 HRESULT
OnWalkInterface(REFIID aIid,PVOID * aInterface,BOOL aIsInParam,BOOL aIsOutParam)565 MainThreadHandoff::OnWalkInterface(REFIID aIid, PVOID* aInterface,
566 BOOL aIsInParam, BOOL aIsOutParam) {
567 MOZ_ASSERT(aInterface && aIsOutParam);
568 if (!aInterface || !aIsOutParam) {
569 return E_UNEXPECTED;
570 }
571
572 // Adopt aInterface for the time being. We can't touch its refcount off
573 // the main thread, so we'll use STAUniquePtr so that we can safely
574 // Release() it if necessary.
575 STAUniquePtr<IUnknown> origInterface(static_cast<IUnknown*>(*aInterface));
576 *aInterface = nullptr;
577
578 if (!origInterface) {
579 // Nothing to wrap.
580 return S_OK;
581 }
582
583 // First make sure that aInterface isn't a proxy - we don't want to wrap
584 // those.
585 if (IsProxy(origInterface.get())) {
586 *aInterface = origInterface.release();
587 return S_OK;
588 }
589
590 RefPtr<IInterceptor> interceptor;
591 HRESULT hr = mInterceptor->Resolve(IID_IInterceptor,
592 (void**)getter_AddRefs(interceptor));
593 MOZ_ASSERT(SUCCEEDED(hr));
594 if (FAILED(hr)) {
595 return hr;
596 }
597
598 // Now make sure that origInterface isn't referring to the same IUnknown
599 // as an interface that we are already managing. We can determine this by
600 // querying (NOT casting!) both objects for IUnknown and then comparing the
601 // resulting pointers.
602 InterceptorTargetPtr<IUnknown> existingTarget;
603 hr = interceptor->GetTargetForIID(aIid, existingTarget);
604 if (SUCCEEDED(hr)) {
605 // We'll start by checking the raw pointers. If they are equal, then the
606 // objects are equal. OTOH, if they differ, we must compare their
607 // IUnknown pointers to know for sure.
608 bool areTargetsEqual = existingTarget.get() == origInterface.get();
609
610 if (!areTargetsEqual) {
611 // This check must be done on the main thread
612 auto checkFn = [&existingTarget, &origInterface,
613 &areTargetsEqual]() -> void {
614 RefPtr<IUnknown> unkExisting;
615 HRESULT hrExisting = existingTarget->QueryInterface(
616 IID_IUnknown, (void**)getter_AddRefs(unkExisting));
617 RefPtr<IUnknown> unkNew;
618 HRESULT hrNew = origInterface->QueryInterface(
619 IID_IUnknown, (void**)getter_AddRefs(unkNew));
620 areTargetsEqual =
621 SUCCEEDED(hrExisting) && SUCCEEDED(hrNew) && unkExisting == unkNew;
622 };
623
624 MainThreadInvoker invoker;
625 invoker.Invoke(NS_NewRunnableFunction(
626 "MainThreadHandoff::OnWalkInterface", checkFn));
627 }
628
629 if (areTargetsEqual) {
630 // The existing interface and the new interface both belong to the same
631 // target object. Let's just use the existing one.
632 void* intercepted = nullptr;
633 hr = interceptor->GetInterceptorForIID(aIid, &intercepted);
634 MOZ_ASSERT(SUCCEEDED(hr));
635 if (FAILED(hr)) {
636 return hr;
637 }
638 *aInterface = intercepted;
639 return S_OK;
640 }
641 }
642
643 IID effectiveIid = aIid;
644
645 RefPtr<IHandlerProvider> payload;
646 if (mHandlerProvider) {
647 if (aIid == IID_IUnknown) {
648 const SavedCallFrame& curFrame = SavedCallFrame::Get();
649
650 IID callIid;
651 ULONG callMethod;
652 hr = curFrame.GetIidAndMethod(WrapNotNull(&callIid),
653 WrapNotNull(&callMethod));
654 if (FAILED(hr)) {
655 return hr;
656 }
657
658 effectiveIid =
659 mHandlerProvider->GetEffectiveOutParamIid(callIid, callMethod);
660 }
661
662 hr = mHandlerProvider->NewInstance(
663 effectiveIid, ToInterceptorTargetPtr(origInterface),
664 WrapNotNull((IHandlerProvider**)getter_AddRefs(payload)));
665 MOZ_ASSERT(SUCCEEDED(hr));
666 if (FAILED(hr)) {
667 return hr;
668 }
669 }
670
671 // Now create a new MainThreadHandoff wrapper...
672 RefPtr<IInterceptorSink> handoff;
673 hr = MainThreadHandoff::Create(payload, getter_AddRefs(handoff));
674 MOZ_ASSERT(SUCCEEDED(hr));
675 if (FAILED(hr)) {
676 return hr;
677 }
678
679 REFIID interceptorIid =
680 payload ? payload->MarshalAs(effectiveIid) : effectiveIid;
681
682 RefPtr<IUnknown> wrapped;
683 hr = Interceptor::Create(std::move(origInterface), handoff, interceptorIid,
684 getter_AddRefs(wrapped));
685 MOZ_ASSERT(SUCCEEDED(hr));
686 if (FAILED(hr)) {
687 return hr;
688 }
689
690 // And replace the original interface pointer with the wrapped one.
691 wrapped.forget(reinterpret_cast<IUnknown**>(aInterface));
692
693 return S_OK;
694 }
695
696 } // namespace mscom
697 } // namespace mozilla
698