fix windows gpu

Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
This commit is contained in:
zhoufeng 2023-02-02 15:45:43 +08:00
parent 7cf3697fa6
commit 836f0355ac
13 changed files with 65 additions and 39 deletions

View File

@ -199,10 +199,9 @@ if(ENABLE_GPU)
COMPONENT mindspore
)
install(
TARGETS mindspore_gpu LIBRARY
TARGETS mindspore_gpu
DESTINATION ${INSTALL_PLUGIN_DIR}
COMPONENT mindspore
NAMELINK_SKIP
)
endif()

View File

@ -1,4 +1,5 @@
file(GLOB_RECURSE KERNEL_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"kernel_factory.cc"
"kernel_build_info.cc"
"kernel.cc"
"common_utils.cc"

View File

@ -14,11 +14,23 @@
* limitations under the License.
*/
#include "plugin/device/ascend/kernel/bisheng/bisheng_kernel_factory.h"
#include "kernel/kernel_factory.h"
namespace mindspore::kernel {
BiShengKernelFactory &BiShengKernelFactory::GetInstance() {
static BiShengKernelFactory instance;
return instance;
FactoryBase *FactoryBase::GetInstance(const std::string &name) {
auto iter = Map().find(name);
if (iter != Map().end()) {
return iter->second.get();
}
return nullptr;
}
void FactoryBase::CreateFactory(const std::string &name, std::unique_ptr<FactoryBase> &&factory) {
Map().emplace(name, std::move(factory));
}
std::map<std::string, std::unique_ptr<FactoryBase>> &FactoryBase::Map() {
static std::map<std::string, std::unique_ptr<FactoryBase>> map;
return map;
}
} // namespace mindspore::kernel

View File

@ -1,5 +1,5 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
* Copyright 2022-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -14,25 +14,32 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_ASCEND_BISHENG_KERNEL_FACTORY_H
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_ASCEND_BISHENG_KERNEL_FACTORY_H
#ifndef MINDSPORE_CCSRC_KERNEL_KERNEL_FACTORY_H_
#define MINDSPORE_CCSRC_KERNEL_KERNEL_FACTORY_H_
#include <string>
#include <algorithm>
#include <functional>
#include <vector>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "include/backend/visible.h"
#include "plugin/factory/ms_factory.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace kernel {
class BiShengKernelMod;
class BACKEND_EXPORT BiShengKernelFactory : public Factory<BiShengKernelMod> {
class BACKEND_EXPORT FactoryBase {
public:
static BiShengKernelFactory &GetInstance();
virtual ~FactoryBase() = default;
protected:
static FactoryBase *GetInstance(const std::string &name);
static void CreateFactory(const std::string &name, std::unique_ptr<FactoryBase> &&factory);
private:
static std::map<std::string, std::unique_ptr<FactoryBase>> &Map();
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_ASCEND_BISHENG_KERNEL_FACTORY_H
#endif // MINDSPORE_CCSRC_KERNEL_KERNEL_FACTORY_H_

View File

@ -17,7 +17,7 @@ if(${try_result})
message("Bisheng toolchain seems to work.")
add_subdirectory(impl)
file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
list(REMOVE_ITEM SRC_LIST "bisheng_kernel_build.cc" "custom_bisheng_kernel.cc" "bisheng_kernel_factory.cc")
list(REMOVE_ITEM SRC_LIST "bisheng_kernel_build.cc" "custom_bisheng_kernel.cc")
add_library(bisheng_kernels SHARED ${SRC_LIST})
target_link_libraries(bisheng_kernels PRIVATE mindspore_ascend bisheng_kernels_impl)
set_target_properties(bisheng_kernels PROPERTIES INSTALL_RPATH $ORIGIN)

View File

@ -16,7 +16,6 @@
#include "plugin/device/ascend/kernel/bisheng/add_bisheng_kernel.h"
#include <algorithm>
#include "plugin/device/ascend/kernel/bisheng/bisheng_kernel_factory.h"
#include "plugin/device/ascend/kernel/bisheng/bisheng_op_info.h"
#include "plugin/device/ascend/kernel/bisheng/impl/add.h"

View File

@ -22,7 +22,7 @@
#include <map>
#include "plugin/device/ascend/kernel/bisheng/custom_bisheng_kernel.h"
#include "plugin/device/ascend/kernel/bisheng/bisheng_kernel_mod.h"
#include "plugin/device/ascend/kernel/bisheng/bisheng_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
#include "include/common/utils/anfalgo.h"
#include "utils/dlopen_macro.h"
@ -54,7 +54,7 @@ KernelModPtr BiShengOpBuild(const AnfNodePtr &anf_node) {
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
const auto &kernel_name = common::AnfAlgo::GetCNodeName(cnode);
if (!BiShengKernelFactory::GetInstance().IsRegistered(kernel_name)) {
if (!Factory<BiShengKernelMod>::Instance().IsRegistered(kernel_name)) {
MS_LOG(INFO) << "Bisheng custom op " << kernel_name;
auto kernel_mod_ptr = std::make_shared<CustomBiShengKernel>(cnode);
MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
@ -66,7 +66,7 @@ KernelModPtr BiShengOpBuild(const AnfNodePtr &anf_node) {
}
MS_LOG(INFO) << "Bisheng internal op " << kernel_name;
auto kernel_mod = BiShengKernelFactory::GetInstance().Create(kernel_name);
auto kernel_mod = Factory<BiShengKernelMod>::Instance().Create(kernel_name);
auto args = AbstractArgsFromCNode(cnode, false);
auto inputs_tensor_map = std::map<uint32_t, tensor::TensorPtr>();
SetInputsByConstInputs(cnode, &inputs_tensor_map);

View File

@ -25,7 +25,7 @@
#include "include/api/format.h"
#include "include/api/data_type.h"
#include "kernel/oplib/opinfo.h"
#include "plugin/device/ascend/kernel/bisheng/bisheng_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
namespace cl::sycl::detail::half_impl {
class half;
@ -33,6 +33,7 @@ class half;
namespace mindspore::kernel {
using half = cl::sycl::detail::half_impl::half;
class BiShengKernelMod;
#define REG(Clazz) const BishengOpInfoRegister<Clazz> Clazz::reg_ = BishengOpInfoRegister<Clazz>()
@ -380,7 +381,7 @@ class BishengOpInfoRegister : public BishengOpInfoRegisterHelper {
BishengOpInfoRegister() : BishengOpInfoRegisterHelper(), func_list_(T::func_list_) {}
const BishengOpInfoRegister<T> &End() {
BishengOpInfoRegisterHelper::End();
BiShengKernelFactory::GetInstance().Register(op_info_->op_name(),
Factory<BiShengKernelMod>::Instance().Register(op_info_->op_name(),
std::move([]() { return std::make_shared<T>(); }));
return *this;
}

View File

@ -24,22 +24,28 @@
#include <string>
#include <utility>
#include <vector>
#include "include/backend/visible.h"
#include "kernel/kernel_factory.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace kernel {
template <class C>
class BACKEND_EXPORT Factory {
class Factory : public FactoryBase {
using CreatorFunc = std::function<std::shared_ptr<C>()>;
public:
Factory(const Factory &) = delete;
void operator=(const Factory &) = delete;
static Factory &Instance() {
static Factory instance;
return instance;
static Factory<C> &Instance() {
std::string key = typeid(C).name();
FactoryBase *instance = FactoryBase::GetInstance(key);
if (instance == nullptr) {
FactoryBase::CreateFactory(key, std::make_unique<Factory<C>>());
instance = FactoryBase::GetInstance(key);
}
MS_EXCEPTION_IF_NULL(instance);
return *static_cast<Factory<C> *>(instance);
}
void Register(const std::string &name, CreatorFunc &&creator) {
@ -71,7 +77,6 @@ class BACKEND_EXPORT Factory {
return false;
}
protected:
Factory() = default;
~Factory() = default;
@ -80,7 +85,7 @@ class BACKEND_EXPORT Factory {
};
template <class C>
class BACKEND_EXPORT KernelRegistrar {
class KernelRegistrar {
public:
explicit KernelRegistrar(const std::string &name, std::function<std::shared_ptr<C>()> creator) noexcept {
Factory<C>::Instance().Register(name, std::move(creator));

View File

@ -319,7 +319,7 @@ bool PluginLoader::GetPluginPath(std::string *file_path) {
#ifndef _WIN32
auto plugin_so_path = cur_so_path.substr(0, pos) + "/plugin";
#else
auto plugin_so_path = cur_so_path.substr(0, pos) + "\\bin";
auto plugin_so_path = cur_so_path.substr(0, pos);
#endif
if (plugin_so_path.size() >= PATH_MAX) {
MS_LOG(INFO) << "Current path [" << plugin_so_path << "] is invalid.";

View File

@ -113,6 +113,7 @@ if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE)
${CCSRC_DIR}/runtime/hardware/device_type.cc
${CCSRC_DIR}/kernel/kernel_build_info.cc
${CCSRC_DIR}/kernel/common_utils.cc
${CCSRC_DIR}/kernel/kernel_factory.cc
${CCSRC_DIR}/kernel/kernel.cc
${CCSRC_DIR}/kernel/kash/kernel_pack.cc
${CCSRC_DIR}/kernel/oplib/oplib.cc

View File

@ -13,6 +13,7 @@ set(CCSRC_SRC
${CCSRC_DIR}/backend/common/optimizer/visit.cc
${CCSRC_DIR}/backend/common/optimizer/graph_optimizer.cc
${CCSRC_DIR}/backend/operator/ops_backend_infer_function.cc
${CCSRC_DIR}/kernel/kernel_factory.cc
)
if(NOT WIN32)