enable gpu on windows 0916

This commit is contained in:
taipingchangan 2022-09-16 15:21:12 +08:00
parent 04bd77f793
commit 408e6c9189
8 changed files with 41 additions and 12 deletions

View File

@ -49,7 +49,7 @@ set(INSTALL_PY_DIR ".")
set(INSTALL_BASE_DIR ".")
set(INSTALL_BIN_DIR "bin")
set(INSTALL_CFG_DIR "config")
set(INSTALL_PLUGIN_DIR "${INSTALL_LIB_DIR}/plugin")
set(INSTALL_PLUGIN_DIR ".")
set(INSTALL_LIB_DIR ".")
set(onednn_LIBPATH ${onednn_LIBPATH}/../bin/)

View File

@ -70,6 +70,6 @@ class DebuggerProtoExporter {
BACKEND_EXPORT void DumpIRProtoWithSrcInfo(const FuncGraphPtr &func_graph, const std::string &suffix,
const std::string &target_dir,
LocDebugDumpMode dump_location = kDebugWholeStack);
void DumpConstantInfo(const KernelGraphPtr &graph, const std::string &target_dir);
BACKEND_EXPORT void DumpConstantInfo(const KernelGraphPtr &graph, const std::string &target_dir);
} // namespace mindspore
#endif // MINDSPORE_CCSRC_DEBUG_DEBUGGER_MINDSPORE_PROTO_EXPORTER_H_

View File

@ -22,8 +22,9 @@
namespace mindspore {
enum LocDebugDumpMode { kDebugOff = 0, kDebugTopStack = 1, kDebugWholeStack = 2 };
void DumpIRProtoWithSrcInfo(const FuncGraphPtr &func_graph, const std::string &suffix, const std::string &target_dir,
LocDebugDumpMode dump_location = kDebugWholeStack);
void DumpConstantInfo(const KernelGraphPtr &graph, const std::string &target_dir);
BACKEND_EXPORT void DumpIRProtoWithSrcInfo(const FuncGraphPtr &func_graph, const std::string &suffix,
const std::string &target_dir,
LocDebugDumpMode dump_location = kDebugWholeStack);
BACKEND_EXPORT void DumpConstantInfo(const KernelGraphPtr &graph, const std::string &target_dir);
} // namespace mindspore
#endif // MINDSPORE_CCSRC_DEBUG_DEBUGGER_MINDSPORE_PROTO_EXPORTER_STUB_H_

View File

@ -121,7 +121,7 @@ class BACKEND_EXPORT KernelBuildInfo {
};
using KernelBuildInfoPtr = std::shared_ptr<KernelBuildInfo>;
class KernelBuildInfo::KernelBuildInfoBuilder {
class BACKEND_EXPORT KernelBuildInfo::KernelBuildInfoBuilder {
public:
KernelBuildInfoBuilder() { kernel_build_info_ = std::make_shared<KernelBuildInfo>(); }

View File

@ -24,6 +24,7 @@
#include <fstream>
#include "utils/ms_context.h"
#include "utils/dlopen_macro.h"
#include "utils/os.h"
namespace mindspore {
namespace plugin_loader {
@ -31,12 +32,18 @@ void PluginLoader::LoadDynamicLib(const std::string &plugin_file, std::map<std::
MS_EXCEPTION_IF_NULL(all_handles);
void *handle = nullptr;
std::string err_msg;
#ifndef _WIN32
if (plugin_file.find("libmindspore_") == std::string::npos) {
return;
}
#else
if (plugin_file.find("mindspore_") == std::string::npos) {
return;
}
#endif
auto so_name = GetDynamicLibName(plugin_file);
#if defined(_WIN32) || defined(_WIN64)
handle = LoadLibrary(plugin_file.c_str());
handle = LoadLibraryEx(plugin_file.c_str(), NULL, LOAD_WITH_ALTERED_SEARCH_PATH);
err_msg = std::to_string(GetLastError());
#else
handle = dlopen(plugin_file.c_str(), RTLD_NOW | RTLD_LOCAL);
@ -63,7 +70,7 @@ void PluginLoader::CloseDynamicLib(const std::string &dl_name, void *handle) {
}
std::string PluginLoader::GetDynamicLibName(const std::string &plugin_file) {
auto p1 = plugin_file.find_last_of('/') + 1;
auto p1 = plugin_file.find_last_of(PATH_SEPARATOR) + 1;
auto target_so = plugin_file.substr(p1);
auto pos = target_so.rfind('.');
if (pos == std::string::npos) {
@ -97,12 +104,16 @@ bool PluginLoader::GetPluginPath(std::string *file_path) {
}
cur_so_path = std::string(szPath);
#endif
auto pos = cur_so_path.find_last_of('/');
auto pos = cur_so_path.find_last_of(PATH_SEPARATOR);
if (cur_so_path.empty() || pos == std::string::npos) {
MS_LOG(INFO) << "Current so path empty or the path [" << cur_so_path << "] is invalid.";
return false;
}
#ifndef _WIN32
auto plugin_so_path = cur_so_path.substr(0, pos) + "/plugin";
#else
auto plugin_so_path = cur_so_path.substr(0, pos);
#endif
if (plugin_so_path.size() >= PATH_MAX) {
MS_LOG(INFO) << "Current path [" << plugin_so_path << "] is invalid.";
return false;
@ -145,6 +156,10 @@ void DeviceContextManager::LoadPlugin() {
MS_LOG(INFO) << "Plugin path is invalid, skip!";
return;
}
#ifdef _WIN32
auto plugin_file = plugin_path_ + "\\mindspore_gpu.dll";
plugin_loader::PluginLoader::LoadDynamicLib(plugin_file, &plugin_maps_);
#else
DIR *dir = opendir(plugin_path_.c_str());
if (dir == nullptr) {
MS_LOG(ERROR) << "Open plugin dir failed, plugin path:" << plugin_path_;
@ -152,10 +167,11 @@ void DeviceContextManager::LoadPlugin() {
}
struct dirent *entry;
while ((entry = readdir(dir)) != nullptr) {
auto plugin_file = plugin_path_ + "/" + entry->d_name;
auto plugin_file = plugin_path_ + PATH_SEPARATOR + entry->d_name;
plugin_loader::PluginLoader::LoadDynamicLib(plugin_file, &plugin_maps_);
}
(void)closedir(dir);
#endif
load_init_ = true;
}

View File

@ -60,7 +60,7 @@ class BACKEND_EXPORT DeviceContextManager {
void LoadPlugin();
std::map<std::string, void *> plugin_maps_;
bool load_init_;
bool load_init_{false};
std::string plugin_path_;
// The string converted from DeviceContextKey -> DeviceContextPtr.
@ -69,7 +69,7 @@ class BACKEND_EXPORT DeviceContextManager {
std::map<std::string, DeviceContextCreator> device_context_creators_;
};
class DeviceContextRegister {
class BACKEND_EXPORT DeviceContextRegister {
public:
DeviceContextRegister(const std::string &device_name, DeviceContextCreator &&runtime_creator) {
DeviceContextManager::GetInstance().Register(device_name, std::move(runtime_creator));

View File

@ -75,4 +75,10 @@ using pid_t = int;
#endif
#endif // _MSC_VER
#ifndef _WIN32
#define PATH_SEPARATOR '/'
#else
#define PATH_SEPARATOR '\\'
#endif
#endif // MINDSPORE_CORE_UTILS_OS_H_

View File

@ -459,6 +459,7 @@ def _set_pb_env():
def _add_cuda_path():
"""add cuda path on windows."""
if platform.system().lower() == 'windows':
if __package_name__.lower() == "mindspore_gpu":
cuda_home = os.environ.get('CUDA_PATH')
@ -466,6 +467,11 @@ def _add_cuda_path():
logger.error("mindspore-gpu on windows need CUDA_PATH, but not set it now")
else:
os.add_dll_directory(os.path.join(os.environ['CUDA_PATH'], 'bin'))
cudann_home = os.environ.get('CUDNN_HOME')
if cudann_home is None:
logger.error("mindspore-gpu on windows need CUDNN_HOME, but not set it now")
else:
os.add_dll_directory(os.path.join(os.environ['CUDNN_HOME'], 'bin'))
check_version_and_env_config()