1 // Copyright Vladimir Prus 2004. 2 // Copyright (c) 2005-2012 Hartmut Kaiser 3 // Distributed under the Boost Software License, Version 1.0. 4 // (See accompanying file LICENSE_1_0.txt 5 // or copy at http://www.boost.org/LICENSE_1_0.txt) 6 7 // hpxinspect:nodeprecatedinclude:boost/shared_ptr.hpp 8 // hpxinspect:nodeprecatedname:boost::shared_ptr 9 10 #ifndef HPX_DLL_WINDOWS_HPP_HK_2005_11_06 11 #define HPX_DLL_WINDOWS_HPP_HK_2005_11_06 12 13 #include <hpx/config.hpp> 14 #include <hpx/error_code.hpp> 15 #include <hpx/throw_exception.hpp> 16 #include <hpx/util/assert.hpp> 17 #include <hpx/util/plugin/config.hpp> 18 19 #include <boost/filesystem/convenience.hpp> 20 #include <boost/filesystem/path.hpp> 21 #include <boost/shared_ptr.hpp> 22 23 #include <iostream> 24 #include <sstream> 25 #include <stdexcept> 26 #include <string> 27 #include <type_traits> 28 #include <utility> 29 30 #include <Shlwapi.h> 31 #include <windows.h> 32 33 #if !defined(HPX_MSVC) && !defined(HPX_MINGW) 34 #error "This file shouldn't be included directly, use the file hpx/util/plugin/dll.hpp only." 35 #endif 36 37 /////////////////////////////////////////////////////////////////////////////// 38 namespace hpx { namespace util { namespace plugin { 39 40 namespace detail 41 { 42 template<typename T> 43 struct free_dll 44 { free_dllhpx::util::plugin::detail::free_dll45 free_dll(HMODULE h) : h(h) {} 46 operator ()hpx::util::plugin::detail::free_dll47 void operator()(T) 48 { 49 if (nullptr != h) 50 FreeLibrary(h); 51 } 52 53 HMODULE h; 54 }; 55 } 56 57 class dll 58 { 59 public: dll()60 dll() 61 : dll_handle(nullptr) 62 {} 63 dll(dll const & rhs)64 dll(dll const& rhs) 65 : dll_name(rhs.dll_name), map_name(rhs.map_name), dll_handle(nullptr) 66 {} 67 dll(std::string const & libname)68 dll(std::string const& libname) 69 : dll_name(libname), map_name(""), dll_handle(nullptr) 70 { 71 // map_name defaults to dll base name 72 namespace fs = boost::filesystem; 73 74 #if BOOST_FILESYSTEM_VERSION == 2 75 fs::path dll_path(dll_name, fs::native); 76 #else 77 fs::path dll_path(dll_name); 78 #endif 79 map_name = fs::basename(dll_path); 80 } 81 load_library(error_code & ec=throws)82 void load_library(error_code& ec = throws) 83 { 84 LoadLibrary(ec); 85 } 86 dll(std::string const & libname,std::string const & mapname)87 dll(std::string const& libname, std::string const& mapname) 88 : dll_name(libname), map_name(mapname), dll_handle(nullptr) 89 {} 90 dll(dll && rhs)91 dll(dll && rhs) 92 : dll_name(std::move(rhs.dll_name)) 93 , map_name(std::move(rhs.map_name)) 94 , dll_handle(rhs.dll_handle) 95 { 96 rhs.dll_handle = nullptr; 97 } 98 operator =(dll const & rhs)99 dll &operator=(dll const & rhs) 100 { 101 if (this != &rhs) { 102 // free any existing dll_handle 103 FreeLibrary(); 104 105 // load the library for this instance of the dll class 106 dll_name = rhs.dll_name; 107 map_name = rhs.map_name; 108 LoadLibrary(); 109 } 110 return *this; 111 } 112 operator =(dll && rhs)113 dll &operator=(dll && rhs) 114 { 115 if (&rhs != this) { 116 dll_name = std::move(rhs.dll_name); 117 map_name = std::move(rhs.map_name); 118 dll_handle = rhs.dll_handle; 119 rhs.dll_handle = nullptr; 120 } 121 return *this; 122 } 123 ~dll()124 ~dll() 125 { 126 FreeLibrary(); 127 } 128 get_name() const129 std::string get_name() const { return dll_name; } get_mapname() const130 std::string get_mapname() const { return map_name; } 131 132 template<typename SymbolType, typename Deleter> 133 std::pair<SymbolType, Deleter> get(std::string const & symbol_name,error_code & ec=throws) const134 get(std::string const& symbol_name, error_code& ec = throws) const 135 { 136 const_cast<dll&>(*this).LoadLibrary(ec); 137 // make sure everything is initialized 138 if (ec) return std::pair<SymbolType, Deleter>(); 139 140 static_assert( 141 std::is_pointer<SymbolType>::value, 142 "std::is_pointer<SymbolType>::value"); 143 144 // Cast the to right type. 145 SymbolType address = (SymbolType)GetProcAddress 146 (dll_handle, symbol_name.c_str()); 147 if (nullptr == address) 148 { 149 std::ostringstream str; 150 str << "Hpx.Plugin: Unable to locate the exported symbol name '" 151 << symbol_name << "' in the shared library '" 152 << dll_name << "'"; 153 154 // report error 155 HPX_THROWS_IF(ec, dynamic_link_failure, 156 "plugin::get", str.str()); 157 return std::pair<SymbolType, Deleter>(); 158 } 159 160 // Open the library. Yes, we do it on every access to 161 // a symbol, the LoadLibrary function increases the refcnt of the dll 162 // so in the end the dll class holds one refcnt and so does every 163 // symbol. 164 HMODULE handle = ::LoadLibraryA(dll_name.c_str()); 165 if (!handle) { 166 std::ostringstream str; 167 str << "Hpx.Plugin: Could not open shared library '" 168 << dll_name << "'"; 169 170 // report error 171 HPX_THROWS_IF(ec, filesystem_error, 172 "plugin::get", str.str()); 173 return std::pair<SymbolType, Deleter>(); 174 } 175 HPX_ASSERT(handle == dll_handle); 176 177 return std::make_pair(address, detail::free_dll<SymbolType>(handle)); 178 } 179 keep_alive(error_code & ec=throws)180 void keep_alive(error_code& ec = throws) 181 { 182 LoadLibrary(ec, true); 183 } 184 185 protected: LoadLibrary(error_code & ec=throws,bool force=false)186 void LoadLibrary(error_code& ec = throws, bool force = false) 187 { 188 if (!dll_handle || force) 189 { 190 if (dll_name.empty()) { 191 // load main module 192 char buffer[_MAX_PATH]; 193 ::GetModuleFileNameA(nullptr, buffer, sizeof(buffer)); 194 dll_name = buffer; 195 } 196 197 dll_handle = ::LoadLibraryA(dll_name.c_str()); 198 if (!dll_handle) { 199 std::ostringstream str; 200 str << "Hpx.Plugin: Could not open shared library '" 201 << dll_name << "'"; 202 203 HPX_THROWS_IF(ec, filesystem_error, 204 "plugin::LoadLibrary", 205 str.str()); 206 return; 207 } 208 } 209 210 if (&ec != &throws) 211 ec = make_success_code(); 212 } 213 214 public: get_directory(error_code & ec=throws) const215 std::string get_directory(error_code& ec = throws) const 216 { 217 char buffer[_MAX_PATH] = { '\0' }; 218 219 const_cast<dll&>(*this).LoadLibrary(ec); 220 // make sure everything is initialized 221 if (ec) return buffer; 222 223 DWORD name_length = 224 GetModuleFileNameA(dll_handle, buffer, sizeof(buffer)); 225 226 if (name_length <= 0) { 227 std::ostringstream str; 228 str << "Hpx.Plugin: Could not extract path the shared " 229 "library '" << dll_name << "' has been loaded from."; 230 231 HPX_THROWS_IF(ec, filesystem_error, 232 "plugin::get_directory", str.str()); 233 return buffer; 234 } 235 236 // extract the directory name 237 PathRemoveFileSpecA(buffer); 238 239 if (&ec != &throws) 240 ec = make_success_code(); 241 242 return buffer; 243 } 244 245 protected: FreeLibrary()246 void FreeLibrary() 247 { 248 if (nullptr != dll_handle) 249 ::FreeLibrary(dll_handle); 250 } 251 252 private: 253 std::string dll_name; 254 std::string map_name; 255 HMODULE dll_handle; 256 }; 257 258 /////////////////////////////////////////////////////////////////////////////// 259 }}} 260 261 #endif 262 263