1 // Copyright (c) 2010 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "base/win/scoped_variant.h"
6 
7 #include <wrl/client.h>
8 
9 #include <algorithm>
10 #include <functional>
11 
12 #include "base/check.h"
13 #include "base/logging.h"
14 #include "base/numerics/ranges.h"
15 #include "base/win/propvarutil.h"
16 #include "base/win/variant_util.h"
17 
18 namespace base {
19 namespace win {
20 
// Global, const instance of an empty variant.
// Only |vt| is given explicitly (each brace level steps into VARIANT's nested
// union/struct); every other member of this static object is zero-initialized.
const VARIANT ScopedVariant::kEmptyVariant = {{{VT_EMPTY}}};
23 
ScopedVariant(ScopedVariant && var)24 ScopedVariant::ScopedVariant(ScopedVariant&& var) {
25   var_.vt = VT_EMPTY;
26   Reset(var.Release());
27 }
28 
~ScopedVariant()29 ScopedVariant::~ScopedVariant() {
30   static_assert(sizeof(ScopedVariant) == sizeof(VARIANT), "ScopedVariantSize");
31   ::VariantClear(&var_);
32 }
33 
ScopedVariant(const wchar_t * str)34 ScopedVariant::ScopedVariant(const wchar_t* str) {
35   var_.vt = VT_EMPTY;
36   Set(str);
37 }
38 
ScopedVariant(const wchar_t * str,UINT length)39 ScopedVariant::ScopedVariant(const wchar_t* str, UINT length) {
40   var_.vt = VT_BSTR;
41   var_.bstrVal = ::SysAllocStringLen(str, length);
42 }
43 
ScopedVariant(long value,VARTYPE vt)44 ScopedVariant::ScopedVariant(long value, VARTYPE vt) {
45   var_.vt = vt;
46   var_.lVal = value;
47 }
48 
ScopedVariant(int value)49 ScopedVariant::ScopedVariant(int value) {
50   var_.vt = VT_I4;
51   var_.lVal = value;
52 }
53 
ScopedVariant(bool value)54 ScopedVariant::ScopedVariant(bool value) {
55   var_.vt = VT_BOOL;
56   var_.boolVal = value ? VARIANT_TRUE : VARIANT_FALSE;
57 }
58 
ScopedVariant(double value,VARTYPE vt)59 ScopedVariant::ScopedVariant(double value, VARTYPE vt) {
60   DCHECK(vt == VT_R8 || vt == VT_DATE);
61   var_.vt = vt;
62   var_.dblVal = value;
63 }
64 
ScopedVariant(IDispatch * dispatch)65 ScopedVariant::ScopedVariant(IDispatch* dispatch) {
66   var_.vt = VT_EMPTY;
67   Set(dispatch);
68 }
69 
ScopedVariant(IUnknown * unknown)70 ScopedVariant::ScopedVariant(IUnknown* unknown) {
71   var_.vt = VT_EMPTY;
72   Set(unknown);
73 }
74 
ScopedVariant(SAFEARRAY * safearray)75 ScopedVariant::ScopedVariant(SAFEARRAY* safearray) {
76   var_.vt = VT_EMPTY;
77   Set(safearray);
78 }
79 
ScopedVariant(const VARIANT & var)80 ScopedVariant::ScopedVariant(const VARIANT& var) {
81   var_.vt = VT_EMPTY;
82   Set(var);
83 }
84 
Reset(const VARIANT & var)85 void ScopedVariant::Reset(const VARIANT& var) {
86   if (&var != &var_) {
87     ::VariantClear(&var_);
88     var_ = var;
89   }
90 }
91 
Release()92 VARIANT ScopedVariant::Release() {
93   VARIANT var = var_;
94   var_.vt = VT_EMPTY;
95   return var;
96 }
97 
Swap(ScopedVariant & var)98 void ScopedVariant::Swap(ScopedVariant& var) {
99   VARIANT tmp = var_;
100   var_ = var.var_;
101   var.var_ = tmp;
102 }
103 
Receive()104 VARIANT* ScopedVariant::Receive() {
105   DCHECK(!IsLeakableVarType(var_.vt)) << "variant leak. type: " << var_.vt;
106   return &var_;
107 }
108 
Copy() const109 VARIANT ScopedVariant::Copy() const {
110   VARIANT ret = {{{VT_EMPTY}}};
111   ::VariantCopy(&ret, &var_);
112   return ret;
113 }
114 
Compare(const VARIANT & other,bool ignore_case) const115 int ScopedVariant::Compare(const VARIANT& other, bool ignore_case) const {
116   DCHECK(!V_ISARRAY(&var_))
117       << "Comparison is not supported when |this| owns a SAFEARRAY";
118   DCHECK(!V_ISARRAY(&other))
119       << "Comparison is not supported when |other| owns a SAFEARRAY";
120 
121   const bool this_is_empty = var_.vt == VT_EMPTY || var_.vt == VT_NULL;
122   const bool other_is_empty = other.vt == VT_EMPTY || other.vt == VT_NULL;
123 
124   // 1. VT_NULL and VT_EMPTY is always considered less-than any other VARTYPE.
125   if (this_is_empty)
126     return other_is_empty ? 0 : -1;
127   if (other_is_empty)
128     return 1;
129 
130   // 2. If both VARIANTS have either VT_UNKNOWN or VT_DISPATCH even if the
131   //    VARTYPEs do not match, the address of its IID_IUnknown is compared to
132   //    guarantee a logical ordering even though it is not a meaningful order.
133   //    e.g. (a.Compare(b) != b.Compare(a)) unless (a == b).
134   const bool this_is_unknown = var_.vt == VT_UNKNOWN || var_.vt == VT_DISPATCH;
135   const bool other_is_unknown =
136       other.vt == VT_UNKNOWN || other.vt == VT_DISPATCH;
137   if (this_is_unknown && other_is_unknown) {
138     // https://docs.microsoft.com/en-us/windows/win32/com/rules-for-implementing-queryinterface
139     // Query IID_IUnknown to determine whether the two variants point
140     // to the same instance of an object
141     Microsoft::WRL::ComPtr<IUnknown> this_unknown;
142     Microsoft::WRL::ComPtr<IUnknown> other_unknown;
143     V_UNKNOWN(&var_)->QueryInterface(IID_PPV_ARGS(&this_unknown));
144     V_UNKNOWN(&other)->QueryInterface(IID_PPV_ARGS(&other_unknown));
145     if (this_unknown.Get() == other_unknown.Get())
146       return 0;
147     // std::less for any pointer type yields a strict total order even if the
148     // built-in operator< does not.
149     return std::less<>{}(this_unknown.Get(), other_unknown.Get()) ? -1 : 1;
150   }
151 
152   // 3. If the VARTYPEs do not match, then the value of the VARTYPE is compared.
153   if (V_VT(&var_) != V_VT(&other))
154     return (V_VT(&var_) < V_VT(&other)) ? -1 : 1;
155 
156   const VARTYPE shared_vartype = V_VT(&var_);
157   // 4. Comparing VT_BSTR values is a lexicographical comparison of the contents
158   //    of the BSTR, taking into account |ignore_case|.
159   if (shared_vartype == VT_BSTR) {
160     ULONG flags = ignore_case ? NORM_IGNORECASE : 0;
161     HRESULT hr =
162         ::VarBstrCmp(V_BSTR(&var_), V_BSTR(&other), LOCALE_USER_DEFAULT, flags);
163     DCHECK(SUCCEEDED(hr) && hr != VARCMP_NULL)
164         << "unsupported variant comparison: " << var_.vt << " and " << other.vt;
165 
166     switch (hr) {
167       case VARCMP_LT:
168         return -1;
169       case VARCMP_GT:
170       case VARCMP_NULL:
171         return 1;
172       default:
173         return 0;
174     }
175   }
176 
177   // 5. Otherwise returns the lexicographical comparison of the values held by
178   //    the two VARIANTS that share the same VARTYPE.
179   return ::VariantCompare(var_, other);
180 }
181 
Set(const wchar_t * str)182 void ScopedVariant::Set(const wchar_t* str) {
183   DCHECK(!IsLeakableVarType(var_.vt)) << "leaking variant: " << var_.vt;
184   var_.vt = VT_BSTR;
185   var_.bstrVal = ::SysAllocString(str);
186 }
187 
Set(int8_t i8)188 void ScopedVariant::Set(int8_t i8) {
189   DCHECK(!IsLeakableVarType(var_.vt)) << "leaking variant: " << var_.vt;
190   var_.vt = VT_I1;
191   var_.cVal = i8;
192 }
193 
Set(uint8_t ui8)194 void ScopedVariant::Set(uint8_t ui8) {
195   DCHECK(!IsLeakableVarType(var_.vt)) << "leaking variant: " << var_.vt;
196   var_.vt = VT_UI1;
197   var_.bVal = ui8;
198 }
199 
Set(int16_t i16)200 void ScopedVariant::Set(int16_t i16) {
201   DCHECK(!IsLeakableVarType(var_.vt)) << "leaking variant: " << var_.vt;
202   var_.vt = VT_I2;
203   var_.iVal = i16;
204 }
205 
Set(uint16_t ui16)206 void ScopedVariant::Set(uint16_t ui16) {
207   DCHECK(!IsLeakableVarType(var_.vt)) << "leaking variant: " << var_.vt;
208   var_.vt = VT_UI2;
209   var_.uiVal = ui16;
210 }
211 
Set(int32_t i32)212 void ScopedVariant::Set(int32_t i32) {
213   DCHECK(!IsLeakableVarType(var_.vt)) << "leaking variant: " << var_.vt;
214   var_.vt = VT_I4;
215   var_.lVal = i32;
216 }
217 
Set(uint32_t ui32)218 void ScopedVariant::Set(uint32_t ui32) {
219   DCHECK(!IsLeakableVarType(var_.vt)) << "leaking variant: " << var_.vt;
220   var_.vt = VT_UI4;
221   var_.ulVal = ui32;
222 }
223 
Set(int64_t i64)224 void ScopedVariant::Set(int64_t i64) {
225   DCHECK(!IsLeakableVarType(var_.vt)) << "leaking variant: " << var_.vt;
226   var_.vt = VT_I8;
227   var_.llVal = i64;
228 }
229 
Set(uint64_t ui64)230 void ScopedVariant::Set(uint64_t ui64) {
231   DCHECK(!IsLeakableVarType(var_.vt)) << "leaking variant: " << var_.vt;
232   var_.vt = VT_UI8;
233   var_.ullVal = ui64;
234 }
235 
Set(float r32)236 void ScopedVariant::Set(float r32) {
237   DCHECK(!IsLeakableVarType(var_.vt)) << "leaking variant: " << var_.vt;
238   var_.vt = VT_R4;
239   var_.fltVal = r32;
240 }
241 
Set(double r64)242 void ScopedVariant::Set(double r64) {
243   DCHECK(!IsLeakableVarType(var_.vt)) << "leaking variant: " << var_.vt;
244   var_.vt = VT_R8;
245   var_.dblVal = r64;
246 }
247 
SetDate(DATE date)248 void ScopedVariant::SetDate(DATE date) {
249   DCHECK(!IsLeakableVarType(var_.vt)) << "leaking variant: " << var_.vt;
250   var_.vt = VT_DATE;
251   var_.date = date;
252 }
253 
Set(IDispatch * disp)254 void ScopedVariant::Set(IDispatch* disp) {
255   DCHECK(!IsLeakableVarType(var_.vt)) << "leaking variant: " << var_.vt;
256   var_.vt = VT_DISPATCH;
257   var_.pdispVal = disp;
258   if (disp)
259     disp->AddRef();
260 }
261 
Set(bool b)262 void ScopedVariant::Set(bool b) {
263   DCHECK(!IsLeakableVarType(var_.vt)) << "leaking variant: " << var_.vt;
264   var_.vt = VT_BOOL;
265   var_.boolVal = b ? VARIANT_TRUE : VARIANT_FALSE;
266 }
267 
Set(IUnknown * unk)268 void ScopedVariant::Set(IUnknown* unk) {
269   DCHECK(!IsLeakableVarType(var_.vt)) << "leaking variant: " << var_.vt;
270   var_.vt = VT_UNKNOWN;
271   var_.punkVal = unk;
272   if (unk)
273     unk->AddRef();
274 }
275 
Set(SAFEARRAY * array)276 void ScopedVariant::Set(SAFEARRAY* array) {
277   DCHECK(!IsLeakableVarType(var_.vt)) << "leaking variant: " << var_.vt;
278   if (SUCCEEDED(::SafeArrayGetVartype(array, &var_.vt))) {
279     var_.vt |= VT_ARRAY;
280     var_.parray = array;
281   } else {
282     DCHECK(!array) << "Unable to determine safearray vartype";
283     var_.vt = VT_EMPTY;
284   }
285 }
286 
Set(const VARIANT & var)287 void ScopedVariant::Set(const VARIANT& var) {
288   DCHECK(!IsLeakableVarType(var_.vt)) << "leaking variant: " << var_.vt;
289   if (FAILED(::VariantCopy(&var_, &var))) {
290     DLOG(ERROR) << "VariantCopy failed";
291     var_.vt = VT_EMPTY;
292   }
293 }
294 
operator =(ScopedVariant && var)295 ScopedVariant& ScopedVariant::operator=(ScopedVariant&& var) {
296   if (var.ptr() != &var_)
297     Reset(var.Release());
298   return *this;
299 }
300 
operator =(const VARIANT & var)301 ScopedVariant& ScopedVariant::operator=(const VARIANT& var) {
302   if (&var != &var_) {
303     VariantClear(&var_);
304     Set(var);
305   }
306   return *this;
307 }
308 
IsLeakableVarType(VARTYPE vt)309 bool ScopedVariant::IsLeakableVarType(VARTYPE vt) {
310   bool leakable = false;
311   switch (vt & VT_TYPEMASK) {
312     case VT_BSTR:
313     case VT_DISPATCH:
314     // we treat VT_VARIANT as leakable to err on the safe side.
315     case VT_VARIANT:
316     case VT_UNKNOWN:
317     case VT_SAFEARRAY:
318 
319     // very rarely used stuff (if ever):
320     case VT_VOID:
321     case VT_PTR:
322     case VT_CARRAY:
323     case VT_USERDEFINED:
324     case VT_LPSTR:
325     case VT_LPWSTR:
326     case VT_RECORD:
327     case VT_INT_PTR:
328     case VT_UINT_PTR:
329     case VT_FILETIME:
330     case VT_BLOB:
331     case VT_STREAM:
332     case VT_STORAGE:
333     case VT_STREAMED_OBJECT:
334     case VT_STORED_OBJECT:
335     case VT_BLOB_OBJECT:
336     case VT_VERSIONED_STREAM:
337     case VT_BSTR_BLOB:
338       leakable = true;
339       break;
340   }
341 
342   if (!leakable && (vt & VT_ARRAY) != 0) {
343     leakable = true;
344   }
345 
346   return leakable;
347 }
348 
349 }  // namespace win
350 }  // namespace base
351