1 // Copyright Vladimir Prus 2004.
2 // Copyright (c) 2005-2012 Hartmut Kaiser
3 // Distributed under the Boost Software License, Version 1.0.
4 // (See accompanying file LICENSE_1_0.txt
5 // or copy at http://www.boost.org/LICENSE_1_0.txt)
6 
7 // hpxinspect:nodeprecatedinclude:boost/shared_ptr.hpp
8 // hpxinspect:nodeprecatedname:boost::shared_ptr
9 
10 #ifndef HPX_DLL_WINDOWS_HPP_HK_2005_11_06
11 #define HPX_DLL_WINDOWS_HPP_HK_2005_11_06
12 
13 #include <hpx/config.hpp>
14 #include <hpx/error_code.hpp>
15 #include <hpx/throw_exception.hpp>
16 #include <hpx/util/assert.hpp>
17 #include <hpx/util/plugin/config.hpp>
18 
19 #include <boost/filesystem/convenience.hpp>
20 #include <boost/filesystem/path.hpp>
21 #include <boost/shared_ptr.hpp>
22 
23 #include <iostream>
24 #include <sstream>
25 #include <stdexcept>
26 #include <string>
27 #include <type_traits>
28 #include <utility>
29 
30 #include <Shlwapi.h>
31 #include <windows.h>
32 
33 #if !defined(HPX_MSVC) && !defined(HPX_MINGW)
34 #error "This file shouldn't be included directly, use the file hpx/util/plugin/dll.hpp only."
35 #endif
36 
37 ///////////////////////////////////////////////////////////////////////////////
38 namespace hpx { namespace util { namespace plugin {
39 
40     namespace detail
41     {
42         template<typename T>
43         struct free_dll
44         {
free_dllhpx::util::plugin::detail::free_dll45             free_dll(HMODULE h) : h(h) {}
46 
operator ()hpx::util::plugin::detail::free_dll47             void operator()(T)
48             {
49                 if (nullptr != h)
50                     FreeLibrary(h);
51             }
52 
53             HMODULE h;
54         };
55     }
56 
57     class dll
58     {
59     public:
dll()60         dll()
61         :   dll_handle(nullptr)
62         {}
63 
dll(dll const & rhs)64         dll(dll const& rhs)
65         :   dll_name(rhs.dll_name), map_name(rhs.map_name), dll_handle(nullptr)
66         {}
67 
dll(std::string const & libname)68         dll(std::string const& libname)
69         :   dll_name(libname), map_name(""), dll_handle(nullptr)
70         {
71             // map_name defaults to dll base name
72             namespace fs = boost::filesystem;
73 
74 #if BOOST_FILESYSTEM_VERSION == 2
75             fs::path dll_path(dll_name, fs::native);
76 #else
77             fs::path dll_path(dll_name);
78 #endif
79             map_name = fs::basename(dll_path);
80         }
81 
load_library(error_code & ec=throws)82         void load_library(error_code& ec = throws)
83         {
84             LoadLibrary(ec);
85         }
86 
dll(std::string const & libname,std::string const & mapname)87         dll(std::string const& libname, std::string const& mapname)
88         :   dll_name(libname), map_name(mapname), dll_handle(nullptr)
89         {}
90 
dll(dll && rhs)91         dll(dll && rhs)
92           : dll_name(std::move(rhs.dll_name))
93           , map_name(std::move(rhs.map_name))
94           , dll_handle(rhs.dll_handle)
95         {
96             rhs.dll_handle = nullptr;
97         }
98 
operator =(dll const & rhs)99         dll &operator=(dll const & rhs)
100         {
101             if (this != &rhs) {
102             //  free any existing dll_handle
103                 FreeLibrary();
104 
105             //  load the library for this instance of the dll class
106                 dll_name = rhs.dll_name;
107                 map_name = rhs.map_name;
108                 LoadLibrary();
109             }
110             return *this;
111         }
112 
operator =(dll && rhs)113         dll &operator=(dll && rhs)
114         {
115             if (&rhs != this) {
116                 dll_name = std::move(rhs.dll_name);
117                 map_name = std::move(rhs.map_name);
118                 dll_handle = rhs.dll_handle;
119                 rhs.dll_handle = nullptr;
120             }
121             return *this;
122         }
123 
~dll()124         ~dll()
125         {
126             FreeLibrary();
127         }
128 
get_name() const129         std::string get_name() const { return dll_name; }
get_mapname() const130         std::string get_mapname() const { return map_name; }
131 
132         template<typename SymbolType, typename Deleter>
133         std::pair<SymbolType, Deleter>
get(std::string const & symbol_name,error_code & ec=throws) const134         get(std::string const& symbol_name, error_code& ec = throws) const
135         {
136             const_cast<dll&>(*this).LoadLibrary(ec);
137             // make sure everything is initialized
138             if (ec) return std::pair<SymbolType, Deleter>();
139 
140             static_assert(
141                 std::is_pointer<SymbolType>::value,
142                 "std::is_pointer<SymbolType>::value");
143 
144             // Cast the to right type.
145             SymbolType address = (SymbolType)GetProcAddress
146                 (dll_handle, symbol_name.c_str());
147             if (nullptr == address)
148             {
149                 std::ostringstream str;
150                 str << "Hpx.Plugin: Unable to locate the exported symbol name '"
151                     << symbol_name << "' in the shared library '"
152                     << dll_name << "'";
153 
154                 // report error
155                 HPX_THROWS_IF(ec, dynamic_link_failure,
156                     "plugin::get", str.str());
157                 return std::pair<SymbolType, Deleter>();
158             }
159 
160             // Open the library. Yes, we do it on every access to
161             // a symbol, the LoadLibrary function increases the refcnt of the dll
162             // so in the end the dll class holds one refcnt and so does every
163             // symbol.
164             HMODULE handle = ::LoadLibraryA(dll_name.c_str());
165             if (!handle) {
166                 std::ostringstream str;
167                 str << "Hpx.Plugin: Could not open shared library '"
168                     << dll_name << "'";
169 
170                 // report error
171                 HPX_THROWS_IF(ec, filesystem_error,
172                     "plugin::get", str.str());
173                 return std::pair<SymbolType, Deleter>();
174             }
175             HPX_ASSERT(handle == dll_handle);
176 
177             return std::make_pair(address, detail::free_dll<SymbolType>(handle));
178         }
179 
keep_alive(error_code & ec=throws)180         void keep_alive(error_code& ec = throws)
181         {
182             LoadLibrary(ec, true);
183         }
184 
185     protected:
LoadLibrary(error_code & ec=throws,bool force=false)186         void LoadLibrary(error_code& ec = throws, bool force = false)
187         {
188             if (!dll_handle || force)
189             {
190                 if (dll_name.empty()) {
191                 // load main module
192                     char buffer[_MAX_PATH];
193                     ::GetModuleFileNameA(nullptr, buffer, sizeof(buffer));
194                     dll_name = buffer;
195                 }
196 
197                 dll_handle = ::LoadLibraryA(dll_name.c_str());
198                 if (!dll_handle) {
199                     std::ostringstream str;
200                     str << "Hpx.Plugin: Could not open shared library '"
201                         << dll_name << "'";
202 
203                     HPX_THROWS_IF(ec, filesystem_error,
204                         "plugin::LoadLibrary",
205                         str.str());
206                     return;
207                 }
208             }
209 
210             if (&ec != &throws)
211                 ec = make_success_code();
212         }
213 
214     public:
get_directory(error_code & ec=throws) const215         std::string get_directory(error_code& ec = throws) const
216         {
217             char buffer[_MAX_PATH] = { '\0' };
218 
219             const_cast<dll&>(*this).LoadLibrary(ec);
220             // make sure everything is initialized
221             if (ec) return buffer;
222 
223             DWORD name_length =
224                 GetModuleFileNameA(dll_handle, buffer, sizeof(buffer));
225 
226             if (name_length <= 0) {
227                 std::ostringstream str;
228                 str << "Hpx.Plugin: Could not extract path the shared "
229                        "library '" << dll_name << "' has been loaded from.";
230 
231                 HPX_THROWS_IF(ec, filesystem_error,
232                     "plugin::get_directory", str.str());
233                 return buffer;
234             }
235 
236             // extract the directory name
237             PathRemoveFileSpecA(buffer);
238 
239             if (&ec != &throws)
240                 ec = make_success_code();
241 
242             return buffer;
243         }
244 
245     protected:
FreeLibrary()246         void FreeLibrary()
247         {
248             if (nullptr != dll_handle)
249                 ::FreeLibrary(dll_handle);
250         }
251 
252     private:
253         std::string dll_name;
254         std::string map_name;
255         HMODULE dll_handle;
256     };
257 
258 ///////////////////////////////////////////////////////////////////////////////
259 }}}
260 
261 #endif
262 
263