forked from mindspore-Ecosystem/mindspore
reg op info from local config file
This commit is contained in:
parent
7b65c5483b
commit
53277f8c02
File diff suppressed because one or more lines are too long
|
@ -103,13 +103,13 @@ class OpInfo {
|
||||||
partial_flag_ = opinfo.partial_flag_;
|
partial_flag_ = opinfo.partial_flag_;
|
||||||
dynamic_format_ = opinfo.dynamic_format_;
|
dynamic_format_ = opinfo.dynamic_format_;
|
||||||
op_pattern_ = opinfo.op_pattern();
|
op_pattern_ = opinfo.op_pattern();
|
||||||
for (auto attr : opinfo.attrs_ptr()) {
|
for (const auto &attr : opinfo.attrs_ptr()) {
|
||||||
attrs_ptr_.push_back(std::make_shared<OpAttr>(*attr));
|
attrs_ptr_.push_back(std::make_shared<OpAttr>(*attr));
|
||||||
}
|
}
|
||||||
for (auto input : opinfo.inputs_ptr()) {
|
for (const auto &input : opinfo.inputs_ptr()) {
|
||||||
inputs_ptr_.push_back(std::make_shared<OpIOInfo>(*input));
|
inputs_ptr_.push_back(std::make_shared<OpIOInfo>(*input));
|
||||||
}
|
}
|
||||||
for (auto output : opinfo.outputs_ptr()) {
|
for (const auto &output : opinfo.outputs_ptr()) {
|
||||||
outputs_ptr_.push_back(std::make_shared<OpIOInfo>(*output));
|
outputs_ptr_.push_back(std::make_shared<OpIOInfo>(*output));
|
||||||
}
|
}
|
||||||
ref_infos_ = opinfo.ref_infos();
|
ref_infos_ = opinfo.ref_infos();
|
||||||
|
|
|
@ -19,6 +19,7 @@
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <map>
|
#include <map>
|
||||||
|
#include <fstream>
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
#include "utils/overload.h"
|
#include "utils/overload.h"
|
||||||
#include "utils/context/ms_context.h"
|
#include "utils/context/ms_context.h"
|
||||||
|
@ -59,7 +60,7 @@ constexpr auto kNeedCompile = "need_compile";
|
||||||
constexpr auto kShape = "shape";
|
constexpr auto kShape = "shape";
|
||||||
std::vector<std::shared_ptr<OpInfo>> OpLib::op_info_;
|
std::vector<std::shared_ptr<OpInfo>> OpLib::op_info_;
|
||||||
|
|
||||||
std::string ImplTypeToStr(OpImplyType impl_type) {
|
static std::string ImplTypeToStr(OpImplyType impl_type) {
|
||||||
switch (impl_type) {
|
switch (impl_type) {
|
||||||
case kTBE:
|
case kTBE:
|
||||||
return kTbe;
|
return kTbe;
|
||||||
|
@ -124,6 +125,50 @@ void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_p
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool OpLib::RegOpFromLocalInfo() {
|
||||||
|
MS_LOG(INFO) << "Start";
|
||||||
|
static bool has_load = false;
|
||||||
|
if (has_load) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
has_load = true;
|
||||||
|
std::string dir = common::GetEnv("MINDSPORE_OP_INFO_PATH");
|
||||||
|
if (dir.empty()) {
|
||||||
|
MS_LOG(INFO) << "MindSpore op info path does not been setted. use op info from python pass.";
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
char real_path[PATH_MAX] = {0};
|
||||||
|
if (dir.size() >= PATH_MAX) {
|
||||||
|
MS_LOG(ERROR) << "Op info path is invalid: " << dir;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
#if defined(_WIN32) || defined(_WIN64)
|
||||||
|
if (_fullpath(real_path, common::SafeCStr(dir), PATH_MAX) == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Op info path is invalid: " << dir;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
if (realpath(common::SafeCStr(dir), real_path) == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Op info path is invalid: " << dir;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
MS_LOG(INFO) << "Start to read op info from local file.";
|
||||||
|
std::ifstream file(real_path);
|
||||||
|
if (!file.is_open()) {
|
||||||
|
MS_LOG(ERROR) << "Find op info file failed.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
std::string line;
|
||||||
|
while (getline(file, line)) {
|
||||||
|
if (!line.empty()) {
|
||||||
|
(void)OpLib::RegOp(line, "");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "End";
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
bool OpLib::DecodeOpInfo(const nlohmann::json &obj, const mindspore::kernel::OpImplyType imply_type,
|
bool OpLib::DecodeOpInfo(const nlohmann::json &obj, const mindspore::kernel::OpImplyType imply_type,
|
||||||
const std::string &impl_path) {
|
const std::string &impl_path) {
|
||||||
std::shared_ptr<OpInfo> op_info = std::make_shared<OpInfo>();
|
std::shared_ptr<OpInfo> op_info = std::make_shared<OpInfo>();
|
||||||
|
@ -160,14 +205,16 @@ bool OpLib::DecodeOpInfo(const nlohmann::json &obj, const mindspore::kernel::OpI
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (CheckRepetition(op_info)) {
|
||||||
|
MS_LOG(WARNING) << "This op info has been already registed. op name: " << op_info->op_name()
|
||||||
|
<< ", impl type: " << ImplTypeToStr(op_info->imply_type())
|
||||||
|
<< ", impl path: " << op_info->impl_path();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
if (!GetRefInfo(op_info)) {
|
if (!GetRefInfo(op_info)) {
|
||||||
MS_LOG(ERROR) << "GetRefInfo Failed";
|
MS_LOG(ERROR) << "GetRefInfo Failed";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (!CheckRepetition(op_info)) {
|
|
||||||
MS_LOG(ERROR) << "CheckRepetition Failed";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
op_info_.push_back(op_info);
|
op_info_.push_back(op_info);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -269,6 +316,9 @@ bool OpLib::DecodeInputOutput(const nlohmann::json &obj, const OpImplyType imply
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<OpInfo> OpLib::FindOp(const std::string &op_name, OpImplyType imply_type) {
|
std::shared_ptr<OpInfo> OpLib::FindOp(const std::string &op_name, OpImplyType imply_type) {
|
||||||
|
if (!OpLib::RegOpFromLocalInfo()) {
|
||||||
|
MS_LOG(INFO) << "Warning reg local op info failed.";
|
||||||
|
}
|
||||||
auto context = MsContext::GetInstance();
|
auto context = MsContext::GetInstance();
|
||||||
MS_EXCEPTION_IF_NULL(context);
|
MS_EXCEPTION_IF_NULL(context);
|
||||||
bool is_gpu = (context->device_target() == kGPUDevice);
|
bool is_gpu = (context->device_target() == kGPUDevice);
|
||||||
|
@ -283,7 +333,7 @@ std::shared_ptr<OpInfo> OpLib::FindOp(const std::string &op_name, OpImplyType im
|
||||||
return op_info;
|
return op_info;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
MS_LOG(DEBUG) << "FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type)
|
MS_LOG(INFO) << "FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type)
|
||||||
<< ", current op num: " << op_info_.size();
|
<< ", current op num: " << op_info_.size();
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -313,17 +363,19 @@ bool OpLib::GetRefInfo(const std::shared_ptr<OpInfo> &op_info) {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool OpLib::CheckRepetition(const std::shared_ptr<OpInfo> &op_info) {
|
bool OpLib::CheckRepetition(const std::shared_ptr<OpInfo> &op_info) {
|
||||||
|
bool has_register = false;
|
||||||
MS_EXCEPTION_IF_NULL(op_info);
|
MS_EXCEPTION_IF_NULL(op_info);
|
||||||
for (const auto &exist_op_info : op_info_) {
|
for (const auto &exist_op_info : op_info_) {
|
||||||
MS_EXCEPTION_IF_NULL(exist_op_info);
|
MS_EXCEPTION_IF_NULL(exist_op_info);
|
||||||
if (exist_op_info->op_name() == op_info->op_name() && exist_op_info->imply_type() == op_info->imply_type() &&
|
if (exist_op_info->op_name() == op_info->op_name() && exist_op_info->imply_type() == op_info->imply_type() &&
|
||||||
exist_op_info->impl_path() != op_info->impl_path()) {
|
exist_op_info->impl_path() == op_info->impl_path()) {
|
||||||
MS_LOG(ERROR) << "Op has already exist, please use other name, op name: " << op_info->op_name()
|
MS_LOG(INFO) << "Op has already exist, please use other name, op name: " << op_info->op_name()
|
||||||
<< " op type: " << ImplTypeToStr(op_info->imply_type());
|
<< " op type: " << ImplTypeToStr(op_info->imply_type());
|
||||||
return false;
|
has_register = true;
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return true;
|
return has_register;
|
||||||
}
|
}
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -28,11 +28,8 @@ class OpLib {
|
||||||
public:
|
public:
|
||||||
OpLib() = default;
|
OpLib() = default;
|
||||||
virtual ~OpLib() = default;
|
virtual ~OpLib() = default;
|
||||||
bool RegOp(const std::string &json_string, const std::string &impl_path);
|
static bool RegOp(const std::string &json_string, const std::string &impl_path);
|
||||||
static void RegOpInfo(std::shared_ptr<OpInfo> opinfo) {
|
static void RegOpInfo(const std::shared_ptr<OpInfo> &opinfo) { op_info_.emplace_back(opinfo); }
|
||||||
op_info_.emplace_back(opinfo);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
static std::shared_ptr<OpInfo> FindOp(const std::string &op_name, OpImplyType imply_type);
|
static std::shared_ptr<OpInfo> FindOp(const std::string &op_name, OpImplyType imply_type);
|
||||||
static const std::vector<std::shared_ptr<OpInfo>> &GetAllOpsInfo() { return op_info_; }
|
static const std::vector<std::shared_ptr<OpInfo>> &GetAllOpsInfo() { return op_info_; }
|
||||||
|
|
||||||
|
@ -40,6 +37,7 @@ class OpLib {
|
||||||
static std::vector<std::shared_ptr<OpInfo>> op_info_;
|
static std::vector<std::shared_ptr<OpInfo>> op_info_;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
static bool RegOpFromLocalInfo();
|
||||||
static bool DecodeOpInfo(const nlohmann::json &obj, const OpImplyType imply_type, const std::string &impl_path);
|
static bool DecodeOpInfo(const nlohmann::json &obj, const OpImplyType imply_type, const std::string &impl_path);
|
||||||
static bool DecodeAttr(const nlohmann::json &obj, const OpImplyType imply_type,
|
static bool DecodeAttr(const nlohmann::json &obj, const OpImplyType imply_type,
|
||||||
const std::shared_ptr<OpInfo> &op_info);
|
const std::shared_ptr<OpInfo> &op_info);
|
||||||
|
|
|
@ -323,7 +323,7 @@ PYBIND11_MODULE(_c_expression, m) {
|
||||||
|
|
||||||
(void)py::class_<OpLib, std::shared_ptr<OpLib>>(m, "Oplib")
|
(void)py::class_<OpLib, std::shared_ptr<OpLib>>(m, "Oplib")
|
||||||
.def(py::init())
|
.def(py::init())
|
||||||
.def("reg_op", &OpLib::RegOp, "Register op info.");
|
.def_static("reg_op", &OpLib::RegOp, "Register op info.");
|
||||||
#ifdef ENABLE_GPU_COLLECTIVE
|
#ifdef ENABLE_GPU_COLLECTIVE
|
||||||
(void)m.def("init_gpu_collective", &mindspore::device::gpu::CollectiveInitializer::InitCollective,
|
(void)m.def("init_gpu_collective", &mindspore::device::gpu::CollectiveInitializer::InitCollective,
|
||||||
"Init gpu collective communication mode.");
|
"Init gpu collective communication mode.");
|
||||||
|
|
Loading…
Reference in New Issue