!48340 fix windows gpu
Merge pull request !48340 from zhoufeng/fix-windows-gpu
This commit is contained in:
commit
f7c3cfad48
|
@ -199,10 +199,9 @@ if(ENABLE_GPU)
|
|||
COMPONENT mindspore
|
||||
)
|
||||
install(
|
||||
TARGETS mindspore_gpu LIBRARY
|
||||
TARGETS mindspore_gpu
|
||||
DESTINATION ${INSTALL_PLUGIN_DIR}
|
||||
COMPONENT mindspore
|
||||
NAMELINK_SKIP
|
||||
)
|
||||
endif()
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
file(GLOB_RECURSE KERNEL_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
"kernel_factory.cc"
|
||||
"kernel_build_info.cc"
|
||||
"kernel.cc"
|
||||
"common_utils.cc"
|
||||
|
|
|
@ -14,11 +14,23 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/ascend/kernel/bisheng/bisheng_kernel_factory.h"
|
||||
#include "kernel/kernel_factory.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
BiShengKernelFactory &BiShengKernelFactory::GetInstance() {
|
||||
static BiShengKernelFactory instance;
|
||||
return instance;
|
||||
FactoryBase *FactoryBase::GetInstance(const std::string &name) {
|
||||
auto iter = Map().find(name);
|
||||
if (iter != Map().end()) {
|
||||
return iter->second.get();
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void FactoryBase::CreateFactory(const std::string &name, std::unique_ptr<FactoryBase> &&factory) {
|
||||
Map().emplace(name, std::move(factory));
|
||||
}
|
||||
|
||||
std::map<std::string, std::unique_ptr<FactoryBase>> &FactoryBase::Map() {
|
||||
static std::map<std::string, std::unique_ptr<FactoryBase>> map;
|
||||
return map;
|
||||
}
|
||||
} // namespace mindspore::kernel
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
* Copyright 2022-2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -14,25 +14,32 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_ASCEND_BISHENG_KERNEL_FACTORY_H
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_ASCEND_BISHENG_KERNEL_FACTORY_H
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_KERNEL_FACTORY_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_KERNEL_FACTORY_H_
|
||||
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "include/backend/visible.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class BiShengKernelMod;
|
||||
|
||||
class BACKEND_EXPORT BiShengKernelFactory : public Factory<BiShengKernelMod> {
|
||||
class BACKEND_EXPORT FactoryBase {
|
||||
public:
|
||||
static BiShengKernelFactory &GetInstance();
|
||||
virtual ~FactoryBase() = default;
|
||||
|
||||
protected:
|
||||
static FactoryBase *GetInstance(const std::string &name);
|
||||
static void CreateFactory(const std::string &name, std::unique_ptr<FactoryBase> &&factory);
|
||||
|
||||
private:
|
||||
static std::map<std::string, std::unique_ptr<FactoryBase>> &Map();
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_ASCEND_BISHENG_KERNEL_FACTORY_H
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_KERNEL_FACTORY_H_
|
|
@ -17,7 +17,7 @@ if(${try_result})
|
|||
message("Bisheng toolchain seems to work.")
|
||||
add_subdirectory(impl)
|
||||
file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
||||
list(REMOVE_ITEM SRC_LIST "bisheng_kernel_build.cc" "custom_bisheng_kernel.cc" "bisheng_kernel_factory.cc")
|
||||
list(REMOVE_ITEM SRC_LIST "bisheng_kernel_build.cc" "custom_bisheng_kernel.cc")
|
||||
add_library(bisheng_kernels SHARED ${SRC_LIST})
|
||||
target_link_libraries(bisheng_kernels PRIVATE mindspore_ascend bisheng_kernels_impl)
|
||||
set_target_properties(bisheng_kernels PROPERTIES INSTALL_RPATH $ORIGIN)
|
||||
|
|
|
@ -16,7 +16,6 @@
|
|||
|
||||
#include "plugin/device/ascend/kernel/bisheng/add_bisheng_kernel.h"
|
||||
#include <algorithm>
|
||||
#include "plugin/device/ascend/kernel/bisheng/bisheng_kernel_factory.h"
|
||||
#include "plugin/device/ascend/kernel/bisheng/bisheng_op_info.h"
|
||||
#include "plugin/device/ascend/kernel/bisheng/impl/add.h"
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@
|
|||
#include <map>
|
||||
#include "plugin/device/ascend/kernel/bisheng/custom_bisheng_kernel.h"
|
||||
#include "plugin/device/ascend/kernel/bisheng/bisheng_kernel_mod.h"
|
||||
#include "plugin/device/ascend/kernel/bisheng/bisheng_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
#include "utils/dlopen_macro.h"
|
||||
|
||||
|
@ -54,7 +54,7 @@ KernelModPtr BiShengOpBuild(const AnfNodePtr &anf_node) {
|
|||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
const auto &kernel_name = common::AnfAlgo::GetCNodeName(cnode);
|
||||
if (!BiShengKernelFactory::GetInstance().IsRegistered(kernel_name)) {
|
||||
if (!Factory<BiShengKernelMod>::Instance().IsRegistered(kernel_name)) {
|
||||
MS_LOG(INFO) << "Bisheng custom op " << kernel_name;
|
||||
auto kernel_mod_ptr = std::make_shared<CustomBiShengKernel>(cnode);
|
||||
MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
|
||||
|
@ -66,7 +66,7 @@ KernelModPtr BiShengOpBuild(const AnfNodePtr &anf_node) {
|
|||
}
|
||||
|
||||
MS_LOG(INFO) << "Bisheng internal op " << kernel_name;
|
||||
auto kernel_mod = BiShengKernelFactory::GetInstance().Create(kernel_name);
|
||||
auto kernel_mod = Factory<BiShengKernelMod>::Instance().Create(kernel_name);
|
||||
auto args = AbstractArgsFromCNode(cnode, false);
|
||||
auto inputs_tensor_map = std::map<uint32_t, tensor::TensorPtr>();
|
||||
SetInputsByConstInputs(cnode, &inputs_tensor_map);
|
||||
|
|
|
@ -25,7 +25,7 @@
|
|||
#include "include/api/format.h"
|
||||
#include "include/api/data_type.h"
|
||||
#include "kernel/oplib/opinfo.h"
|
||||
#include "plugin/device/ascend/kernel/bisheng/bisheng_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace cl::sycl::detail::half_impl {
|
||||
class half;
|
||||
|
@ -33,6 +33,7 @@ class half;
|
|||
|
||||
namespace mindspore::kernel {
|
||||
using half = cl::sycl::detail::half_impl::half;
|
||||
class BiShengKernelMod;
|
||||
|
||||
#define REG(Clazz) const BishengOpInfoRegister<Clazz> Clazz::reg_ = BishengOpInfoRegister<Clazz>()
|
||||
|
||||
|
@ -380,8 +381,8 @@ class BishengOpInfoRegister : public BishengOpInfoRegisterHelper {
|
|||
BishengOpInfoRegister() : BishengOpInfoRegisterHelper(), func_list_(T::func_list_) {}
|
||||
const BishengOpInfoRegister<T> &End() {
|
||||
BishengOpInfoRegisterHelper::End();
|
||||
BiShengKernelFactory::GetInstance().Register(op_info_->op_name(),
|
||||
std::move([]() { return std::make_shared<T>(); }));
|
||||
Factory<BiShengKernelMod>::Instance().Register(op_info_->op_name(),
|
||||
std::move([]() { return std::make_shared<T>(); }));
|
||||
return *this;
|
||||
}
|
||||
BishengOpInfoRegister<T> &OpName(const std::string &name) {
|
||||
|
|
|
@ -24,22 +24,28 @@
|
|||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "include/backend/visible.h"
|
||||
#include "kernel/kernel_factory.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <class C>
|
||||
class BACKEND_EXPORT Factory {
|
||||
class Factory : public FactoryBase {
|
||||
using CreatorFunc = std::function<std::shared_ptr<C>()>;
|
||||
|
||||
public:
|
||||
Factory(const Factory &) = delete;
|
||||
void operator=(const Factory &) = delete;
|
||||
|
||||
static Factory &Instance() {
|
||||
static Factory instance;
|
||||
return instance;
|
||||
static Factory<C> &Instance() {
|
||||
std::string key = typeid(C).name();
|
||||
FactoryBase *instance = FactoryBase::GetInstance(key);
|
||||
if (instance == nullptr) {
|
||||
FactoryBase::CreateFactory(key, std::make_unique<Factory<C>>());
|
||||
instance = FactoryBase::GetInstance(key);
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(instance);
|
||||
return *static_cast<Factory<C> *>(instance);
|
||||
}
|
||||
|
||||
void Register(const std::string &name, CreatorFunc &&creator) {
|
||||
|
@ -71,7 +77,6 @@ class BACKEND_EXPORT Factory {
|
|||
return false;
|
||||
}
|
||||
|
||||
protected:
|
||||
Factory() = default;
|
||||
~Factory() = default;
|
||||
|
||||
|
@ -80,7 +85,7 @@ class BACKEND_EXPORT Factory {
|
|||
};
|
||||
|
||||
template <class C>
|
||||
class BACKEND_EXPORT KernelRegistrar {
|
||||
class KernelRegistrar {
|
||||
public:
|
||||
explicit KernelRegistrar(const std::string &name, std::function<std::shared_ptr<C>()> creator) noexcept {
|
||||
Factory<C>::Instance().Register(name, std::move(creator));
|
||||
|
|
|
@ -319,7 +319,7 @@ bool PluginLoader::GetPluginPath(std::string *file_path) {
|
|||
#ifndef _WIN32
|
||||
auto plugin_so_path = cur_so_path.substr(0, pos) + "/plugin";
|
||||
#else
|
||||
auto plugin_so_path = cur_so_path.substr(0, pos) + "\\bin";
|
||||
auto plugin_so_path = cur_so_path.substr(0, pos);
|
||||
#endif
|
||||
if (plugin_so_path.size() >= PATH_MAX) {
|
||||
MS_LOG(INFO) << "Current path [" << plugin_so_path << "] is invalid.";
|
||||
|
|
|
@ -113,6 +113,7 @@ if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE)
|
|||
${CCSRC_DIR}/runtime/hardware/device_type.cc
|
||||
${CCSRC_DIR}/kernel/kernel_build_info.cc
|
||||
${CCSRC_DIR}/kernel/common_utils.cc
|
||||
${CCSRC_DIR}/kernel/kernel_factory.cc
|
||||
${CCSRC_DIR}/kernel/kernel.cc
|
||||
${CCSRC_DIR}/kernel/kash/kernel_pack.cc
|
||||
${CCSRC_DIR}/kernel/oplib/oplib.cc
|
||||
|
|
|
@ -13,6 +13,7 @@ set(CCSRC_SRC
|
|||
${CCSRC_DIR}/backend/common/optimizer/visit.cc
|
||||
${CCSRC_DIR}/backend/common/optimizer/graph_optimizer.cc
|
||||
${CCSRC_DIR}/backend/operator/ops_backend_infer_function.cc
|
||||
${CCSRC_DIR}/kernel/kernel_factory.cc
|
||||
)
|
||||
|
||||
if(NOT WIN32)
|
||||
|
|
|
@ -520,12 +520,12 @@ def check_version_and_env_config():
|
|||
logger.warning("Pre-Load Lirary libgomp.so.1 failed, which might cause TLS memory allocation failure. If "
|
||||
"the failure occurs, you can find a solution in FAQ in "
|
||||
"https://www.mindspore.cn/docs/en/master/faq/installation.html.")
|
||||
if not os.getenv("MS_DEV_CLOSE_VERSION_CHECK") is None:
|
||||
return
|
||||
MSContext.get_instance().register_check_env_callback(check_env)
|
||||
MSContext.get_instance().register_set_env_callback(set_env)
|
||||
MSContext.get_instance().set_param(ms_ctx_param.device_target,
|
||||
MSContext.get_instance().get_param(ms_ctx_param.device_target))
|
||||
if not os.getenv("MS_DEV_CLOSE_VERSION_CHECK") is None:
|
||||
return
|
||||
MSContext.get_instance().register_check_env_callback(check_env)
|
||||
MSContext.get_instance().register_set_env_callback(set_env)
|
||||
MSContext.get_instance().set_param(ms_ctx_param.device_target,
|
||||
MSContext.get_instance().get_param(ms_ctx_param.device_target))
|
||||
|
||||
|
||||
def _set_pb_env():
|
||||
|
|
Loading…
Reference in New Issue