1 /*
2  * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved.
3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4  *
5  * This code is free software; you can redistribute it and/or modify it
6  * under the terms of the GNU General Public License version 2 only, as
7  * published by the Free Software Foundation.  Oracle designates this
8  * particular file as subject to the "Classpath" exception as provided
9  * by Oracle in the LICENSE file that accompanied this code.
10  *
11  * This code is distributed in the hope that it will be useful, but WITHOUT
12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
14  * version 2 for more details (a copy is included in the LICENSE file that
15  * accompanied this code).
16  *
17  * You should have received a copy of the GNU General Public License version
18  * 2 along with this work; if not, write to the Free Software Foundation,
19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20  *
21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22  * or visit www.oracle.com if you need additional information or have any
23  * questions.
24  */
25 
#include <windows.h>
#include <shellapi.h>

#include <memory>

#include "WinSysInfo.h"
#include "FileUtils.h"
#include "WinErrorHandling.h"
32 
33 #pragma comment(lib, "Shell32")
34 
35 namespace SysInfo {
36 
getTempDir()37 tstring getTempDir() {
38     std::vector<TCHAR> buffer(MAX_PATH);
39     DWORD res = GetTempPath(static_cast<DWORD>(buffer.size()), buffer.data());
40     if (res > buffer.size()) {
41         buffer.resize(res);
42         GetTempPath(static_cast<DWORD>(buffer.size()), buffer.data());
43     }
44     return FileUtils::removeTrailingSlash(buffer.data());
45 }
46 
47 namespace {
48 
49 template <class Func>
getSystemDirImpl(Func func,const std::string & label)50 tstring getSystemDirImpl(Func func, const std::string& label) {
51     std::vector<TCHAR> buffer(MAX_PATH);
52     for (int i=0; i<2; i++) {
53         DWORD res = func(buffer.data(), static_cast<DWORD>(buffer.size()));
54         if (!res) {
55             JP_THROW(SysError(label + " failed", func));
56         }
57         if (res < buffer.size()) {
58             return FileUtils::removeTrailingSlash(buffer.data());
59         }
60         buffer.resize(res + 1);
61     }
62     JP_THROW("Unexpected reply from" + label);
63 }
64 
65 } // namespace
66 
getSystem32Dir()67 tstring getSystem32Dir() {
68     return getSystemDirImpl(GetSystemDirectory, "GetSystemDirectory");
69 }
70 
getWIPath()71 tstring getWIPath() {
72     return FileUtils::mkpath() << getSystem32Dir() << _T("msiexec.exe");
73 }
74 
75 namespace {
76 
getModulePath(HMODULE h)77 tstring getModulePath(HMODULE h)
78 {
79     std::vector<TCHAR> buf(MAX_PATH);
80     DWORD len = 0;
81     while (true) {
82         len = GetModuleFileName(h, buf.data(), (DWORD)buf.size());
83         if (len < buf.size()) {
84             break;
85         }
86         // buffer is too small, increase it
87         buf.resize(buf.size() * 2);
88     }
89 
90     if (len == 0) {
91         // error occured
92         JP_THROW(SysError("GetModuleFileName failed", GetModuleFileName));
93     }
94     return tstring(buf.begin(), buf.begin() + len);
95 }
96 
97 } // namespace
98 
getProcessModulePath()99 tstring getProcessModulePath() {
100     return FileUtils::toAbsolutePath(getModulePath(NULL));
101 }
102 
getCurrentModuleHandle()103 HMODULE getCurrentModuleHandle()
104 {
105     // get module handle for the address of this function
106     LPCWSTR address = reinterpret_cast<LPCWSTR>(getCurrentModuleHandle);
107     HMODULE hmodule = NULL;
108     if (!GetModuleHandleExW(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS
109             | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT, address, &hmodule))
110     {
111         JP_THROW(SysError(tstrings::any() << "GetModuleHandleExW failed",
112                 GetModuleHandleExW));
113     }
114     return hmodule;
115 }
116 
setEnvVariable(const tstring & name,const tstring & value)117 void setEnvVariable(const tstring& name, const tstring& value)
118 {
119     if (!SetEnvironmentVariable(name.c_str(), value.c_str())) {
120         JP_THROW(SysError(tstrings::any()
121                 << "SetEnvironmentVariable("
122                 << name << ", " << value
123                 << ") failed", SetEnvironmentVariable));
124     }
125 }
126 
getCurrentModulePath()127 tstring getCurrentModulePath()
128 {
129     return getModulePath(getCurrentModuleHandle());
130 }
131 
getCommandArgs(CommandArgProgramNameMode progNameMode)132 tstring_array getCommandArgs(CommandArgProgramNameMode progNameMode)
133 {
134     int argc = 0;
135     tstring_array result;
136 
137     LPWSTR *parsedArgs = CommandLineToArgvW(GetCommandLineW(), &argc);
138     if (parsedArgs == NULL) {
139         JP_THROW(SysError("CommandLineToArgvW failed", CommandLineToArgvW));
140     }
141     // the 1st element contains program name
142     for (int i = progNameMode == ExcludeProgramName ? 1 : 0; i < argc; i++) {
143         result.push_back(parsedArgs[i]);
144     }
145     LocalFree(parsedArgs);
146 
147     return result;
148 }
149 
150 namespace {
151 
/**
 * Reads the value of environment variable `name`.
 *
 * @param name          variable name.
 * @param errorOccured  optional out flag. When non-null, a lookup failure
 *                      (including "variable not set") sets it to true and an
 *                      empty string is returned instead of throwing. On
 *                      success it is set to false.
 * @return the variable's value; empty if the variable is set to "".
 * @throws SysError (via JP_THROW) on lookup failure when errorOccured is null.
 */
tstring getEnvVariableImpl(const tstring& name, bool* errorOccured=0) {
    std::vector<TCHAR> buf(10);
    for (;;) {
        // GetEnvironmentVariable() returns 0 both for an empty value and on
        // failure, so clear the last error before every call (including
        // retries) to tell the two apart.
        SetLastError(ERROR_SUCCESS);
        const DWORD size = GetEnvironmentVariable(name.c_str(), buf.data(),
                                                            DWORD(buf.size()));
        if (size == 0) {
            const DWORD err = GetLastError();
            if (err != ERROR_SUCCESS) {
                if (errorOccured) {
                    *errorOccured = true;
                    return tstring();
                }
                if (err == ERROR_ENVVAR_NOT_FOUND) {
                    JP_THROW(SysError(tstrings::any() << "GetEnvironmentVariable("
                        << name << ") failed. Variable not set", GetEnvironmentVariable));
                }
                JP_THROW(SysError(tstrings::any() << "GetEnvironmentVariable("
                                << name << ") failed", GetEnvironmentVariable));
            }
        }

        // On success `size` is the value length without the terminating
        // null, so it is strictly smaller than the buffer.
        if (size < buf.size()) {
            if (errorOccured) {
                *errorOccured = false;
            }
            return tstring(buf.data(), size);
        }

        // Buffer too small: `size` is the required length including the
        // terminating null, and the buffer contents are not usable. Retry,
        // because the variable may change between calls. The fallback
        // doubling guarantees progress.
        buf.resize(size > buf.size() ? size : buf.size() * 2);
    }
}
184 
185 } // namespace
186 
/**
 * Returns the value of environment variable `name`.
 * Lookup failures, including an unset variable, are thrown by
 * getEnvVariableImpl().
 */
tstring getEnvVariable(const tstring& name) {
    const tstring value = getEnvVariableImpl(name);
    return value;
}
190 
/**
 * Returns the value of environment variable `name`. If the lookup fails
 * (e.g. the variable is not set), returns `defValue` instead of reporting
 * an error.
 */
tstring getEnvVariable(const std::nothrow_t&, const tstring& name,
                                                    const tstring& defValue) {
    bool lookupFailed = false;
    const tstring value = getEnvVariableImpl(name, &lookupFailed);
    return lookupFailed ? defValue : value;
}
200 
isEnvVariableSet(const tstring & name)201 bool isEnvVariableSet(const tstring& name) {
202     TCHAR unused[1];
203     SetLastError(ERROR_SUCCESS);
204     GetEnvironmentVariable(name.c_str(), unused, _countof(unused));
205     return GetLastError() != ERROR_ENVVAR_NOT_FOUND;
206 }
207 
208 } // end of namespace SysInfo
209