!9393 modify c_ops register when loading mindir
From: @wangnan39 Reviewed-by: @guoqi1024, @kingxian Signed-off-by: @kingxian
Commit: ab15d11d9c
@@ -60,7 +60,7 @@ MSInferSession::~MSInferSession() = default;
 Status MSInferSession::LoadModelFromFile(const std::string &file_name, uint32_t &model_id) {
   Py_Initialize();
-  auto graph = RunLoadMindIR(file_name);
+  auto graph = mindspore::LoadMindIR(file_name);
   if (graph == nullptr) {
     MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str();
     return FAILED;
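For context, a minimal caller sketch (my illustration, not part of this patch) of the renamed entry point. It assumes only what the hunk above and the load_model.h hunk below show: mindspore::LoadMindIR(path) returns a std::shared_ptr<FuncGraph> and yields nullptr when the file cannot be read or parsed, which is why the session logs and returns FAILED. The helper name LoadGraphOrNull is hypothetical.

// Hypothetical caller sketch; only mindspore::LoadMindIR and its headers are taken from the diff.
#include <memory>
#include <string>
#include "ir/func_graph.h"
#include "load_mindir/load_model.h"

std::shared_ptr<mindspore::FuncGraph> LoadGraphOrNull(const std::string &path) {
  auto graph = mindspore::LoadMindIR(path);  // nullptr when the MindIR file cannot be loaded
  if (graph == nullptr) {
    // The caller decides how to report the failure; the session above returns FAILED.
    return nullptr;
  }
  return graph;
}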
@@ -74,8 +74,6 @@ PYBIND11_MODULE(_c_expression, m) {
     .def("get_func_graph", &ExecutorPy::GetFuncGraph, py::arg("phase") = py::str(""), "Get graph pointer.")
     .def("get_func_graph_proto", &ExecutorPy::GetFuncGraphProto, py::arg("phase") = py::str(""),
          py::arg("type") = py::str("onnx_ir"), "Get graph proto string by specifying ir type.")
-    .def("convert_funcgraph_to_mindir", &ExecutorPy::ConvertFuncGraphToMindIR, py::arg("graph"),
-         "Convert FuncGraph to MindIR proto.")
     .def("compile", &ExecutorPy::Compile, py::arg("obj"), py::arg("args"), py::arg("phase") = py::str(""),
          py::arg("use_vm") = py::bool_(false), "Compile obj by executor.")
     .def("updata_param_node_default_input", &ExecutorPy::UpdataParamNodeDefaultInput, py::arg("phase"),
@@ -110,7 +108,6 @@ PYBIND11_MODULE(_c_expression, m) {
   (void)m.def("init_backend", &mindspore::pipeline::InitBackend, "Init Backend.");

   (void)m.def("export_graph", &mindspore::pipeline::ExportGraph, "Export Graph.");
-  (py::object) m.def("load_mindir", &mindspore::pipeline::LoadMindIR, py::arg("file_name"), "Load MindIR as Graph.");

   (void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig")
     .def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.")
@@ -45,7 +45,6 @@
 #include "debug/draw.h"
 #include "pipeline/pynative/pynative_execute.h"
 #include "frontend/optimizer/py_pass_manager.h"
-#include "load_mindir/load_model.h"
 #include "pybind_api/pybind_patch.h"
 #include "utils/shape_utils.h"
 #include "utils/info.h"
@@ -104,16 +103,6 @@ void CheckArgIsTensor(const ValuePtr &arg, std::size_t idx) {
 }
 }  // namespace

-py::bytes ExecutorPy::ConvertFuncGraphToMindIR(const FuncGraphPtr &fg_ptr) {
-  std::string proto_str = GetBinaryProtoString(fg_ptr);
-  if (proto_str.empty()) {
-    MS_LOG(EXCEPTION) << "Graph proto is empty.";
-  }
-  return proto_str;
-}
-
-FuncGraphPtr LoadMindIR(const std::string &file_name) { return mindspore::RunLoadMindIR(file_name); }
-
 py::tuple GenerateKey(const std::string &name, const std::unordered_map<std::string, py::object> &defaults) {
   MS_LOG(DEBUG) << "GenerateKey args size:" << defaults.size();
   abstract::AbstractBasePtrList args_spec;
@@ -82,7 +82,6 @@ class ExecutorPy : public std::enable_shared_from_this<ExecutorPy> {
   ResourcePtr GetResource(const std::string &phase);
   FuncGraphPtr GetFuncGraph(const std::string &phase);
   py::bytes GetFuncGraphProto(const std::string &phase, const std::string &type);
-  py::bytes ConvertFuncGraphToMindIR(const FuncGraphPtr &fg_ptr);
   compile::VmEvalFuncPtr GetVmEvalFunc(const std::string &phase);
   bool HasCompiled(const std::string &phase) const;
@@ -139,7 +138,6 @@ void ClearResAtexit();
 void ReleaseGeTsd();

 void ExportGraph(const std::string &file_name, const std::string &, const std::string &phase);
-FuncGraphPtr LoadMindIR(const std::string &file_name);

 // init and exec dataset sub graph
 bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size,
@@ -50,5 +50,5 @@ AbstractBasePtr TensorAddInfer(const abstract::AnalysisEnginePtr &, const Primit
                                     InferShape(primitive, input_args)->shape());
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(TensorAdd, prim::kPrimTensorAdd, TensorAddInfer);
-REGISTER_PRIMITIVE_C(TensorAdd);
+REGISTER_PRIMITIVE_C(kNameTensorAdd, TensorAdd);
 }  // namespace mindspore
@@ -102,5 +102,5 @@ AbstractBasePtr AvgPoolInfer(const abstract::AnalysisEnginePtr &, const Primitiv
                                     InferShape(primitive, input_args)->shape());
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(AvgPool, prim::kPrimAvgPool, AvgPoolInfer);
-REGISTER_PRIMITIVE_C(AvgPool);
+REGISTER_PRIMITIVE_C(kNameAvgPool, AvgPool);
 }  // namespace mindspore
@@ -107,7 +107,7 @@ TypePtr Conv2dInferType(const PrimitivePtr &prim, const std::vector<AbstractBase
   return TypeIdToType(infer_type);
 }
 }  // namespace
-Conv2D::Conv2D() : PrimitiveC(kConv2DName) { InitIOName({"x", "w"}, {"output"}); }
+Conv2D::Conv2D() : PrimitiveC(kNameConv2D) { InitIOName({"x", "w"}, {"output"}); }

 void Conv2D::Init(int64_t out_channel, const std::vector<int64_t> &kernel_size, int64_t mode,
                   const std::string &pad_mode, const std::vector<int64_t> &pad, const std::vector<int64_t> &stride,
@@ -193,5 +193,5 @@ AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const Primitive
                                     Conv2dInferShape(primitive, input_args)->shape());
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(Conv2D, prim::kPrimConv2D, Conv2dInfer);
-REGISTER_PRIMITIVE_C(Conv2D);
+REGISTER_PRIMITIVE_C(kNameConv2D, Conv2D);
 }  // namespace mindspore
@@ -26,7 +26,7 @@
 #include "abstract/abstract_value.h"
 #include "utils/check_convert_utils.h"
 namespace mindspore {
-constexpr auto kConv2DName = "Conv2D";
+constexpr auto kNameConv2D = "Conv2D";
 class Conv2D : public PrimitiveC {
  public:
   Conv2D();
@@ -206,4 +206,5 @@ AbstractBasePtr DepthWiseConv2DInfer(const abstract::AnalysisEnginePtr &, const
   return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
                                                     InferShape(primitive, input_args)->shape());
 }
+REGISTER_PRIMITIVE_C(kNameDepthWiseConv2D, DepthWiseConv2D);
 }  // namespace mindspore
@@ -38,6 +38,6 @@ OpPrimCRegister &OpPrimCRegister::GetInstance() {
 }

 std::map<std::string, OpPrimCDefineFunc> OpPrimCRegister::GetPrimCMap() { return op_primc_fns_; }
-void OpPrimCRegister::SetPrimCMap(const std::string &name, const OpPrimCDefineFunc &fn) { op_primc_fns_[name] = fn; }
+void OpPrimCRegister::SetPrimCMap(const std::string &kname, const OpPrimCDefineFunc &fn) { op_primc_fns_[kname] = fn; }

 }  // namespace mindspore
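As a reading aid, here is a minimal lookup sketch (a hypothetical helper, not part of this patch) showing how a MindIR loader could resolve a PrimitiveC factory from the map maintained above. It assumes only GetInstance/GetPrimCMap as declared in these hunks and that OpPrimCDefineFunc is callable with no arguments, as the registered GetDefaultPrimC* functions suggest.

// Hypothetical helper: map the op name stored in a MindIR node back to a PrimitiveC instance.
std::shared_ptr<PrimitiveC> MakePrimCByName(const std::string &op_name) {
  auto fns = OpPrimCRegister::GetInstance().GetPrimCMap();  // name -> factory, filled via REGISTER_PRIMITIVE_C
  auto iter = fns.find(op_name);
  if (iter == fns.end()) {
    return nullptr;  // the op was never registered, so it cannot be rebuilt from MindIR
  }
  return iter->second();  // invoke the registered factory, e.g. GetDefaultPrimCTensorAdd for "TensorAdd"
}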
@@ -41,7 +41,7 @@ class OpPrimCRegister {
   ~OpPrimCRegister() {}
   static OpPrimCRegister &GetInstance();
   std::map<std::string, OpPrimCDefineFunc> GetPrimCMap();
-  void SetPrimCMap(const std::string &name, const OpPrimCDefineFunc &fn);
+  void SetPrimCMap(const std::string &kname, const OpPrimCDefineFunc &fn);

  private:
   OpPrimCRegister() {}
@@ -50,17 +50,17 @@ class OpPrimCRegister {

 class OpPrimCRegisterHelper {
  public:
-  OpPrimCRegisterHelper(const std::string &name, const OpPrimCDefineFunc &fn) {
-    OpPrimCRegister::GetInstance().SetPrimCMap(name, fn);
+  OpPrimCRegisterHelper(const std::string &kname, const OpPrimCDefineFunc &fn) {
+    OpPrimCRegister::GetInstance().SetPrimCMap(kname, fn);
   }
   ~OpPrimCRegisterHelper() = default;
 };

-#define REGISTER_PRIMITIVE_C(name)                        \
-  std::shared_ptr<PrimitiveC> GetDefaultPrimC##name() {   \
-    auto out = std::make_shared<name>();                  \
-    return out;                                           \
-  }                                                       \
-  OpPrimCRegisterHelper primc_gen_##name(#name, GetDefaultPrimC##name);
+#define REGISTER_PRIMITIVE_C(kname, primc)                \
+  std::shared_ptr<PrimitiveC> GetDefaultPrimC##primc() {  \
+    auto out = std::make_shared<primc>();                 \
+    return out;                                           \
+  }                                                       \
+  OpPrimCRegisterHelper primc_gen_##kname(kname, GetDefaultPrimC##primc);
 }  // namespace mindspore
 #endif  // MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H_
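To make the rename concrete, this is a hand expansion (written out by me from the definition above) of the reworked two-argument macro for TensorAdd; kNameTensorAdd is presumably the "TensorAdd" name constant, by analogy with kNameConv2D in the conv2d header hunk. Under the old single-argument form the registration key was #name, the stringized class identifier, which need not match the op name recorded in a MindIR file; the new form registers under the op-name constant explicitly.

// Hand expansion of REGISTER_PRIMITIVE_C(kNameTensorAdd, TensorAdd) with the new macro:
std::shared_ptr<PrimitiveC> GetDefaultPrimCTensorAdd() {
  auto out = std::make_shared<TensorAdd>();
  return out;
}
// Registers the factory under the value of kNameTensorAdd (the op-name string),
// not under the stringized class identifier as the old one-argument macro did.
OpPrimCRegisterHelper primc_gen_kNameTensorAdd(kNameTensorAdd, GetDefaultPrimCTensorAdd);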
@@ -49,5 +49,5 @@ AbstractBasePtr Relu6Infer(const abstract::AnalysisEnginePtr &, const PrimitiveP
                                     Relu6InferShape(primitive, input_args)->shape());
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(Relu6, prim::kPrimRelu6, Relu6Infer);
-REGISTER_PRIMITIVE_C(Relu6);
+REGISTER_PRIMITIVE_C(kNameRelu6, Relu6);
 }  // namespace mindspore
@@ -41,5 +41,5 @@ AbstractBasePtr ReshapeInfer(const abstract::AnalysisEnginePtr &, const Primitiv
 }

 REGISTER_PRIMITIVE_EVAL_IMPL(Reshape, prim::kPrimReshape, ReshapeInfer);
-REGISTER_PRIMITIVE_C(Reshape);
+REGISTER_PRIMITIVE_C(kNameReshape, Reshape);
 }  // namespace mindspore
@@ -75,5 +75,5 @@ AbstractBasePtr SoftmaxInfer(const abstract::AnalysisEnginePtr &, const Primitiv
 }

 REGISTER_PRIMITIVE_EVAL_IMPL(Softmax, prim::kPrimSoftmax, SoftmaxInfer);
-REGISTER_PRIMITIVE_C(Softmax);
+REGISTER_PRIMITIVE_C(kNameSoftmax, Softmax);
 }  // namespace mindspore
@@ -76,5 +76,5 @@ AbstractBasePtr SqueezeInfer(const abstract::AnalysisEnginePtr &, const Primitiv
 }

 REGISTER_PRIMITIVE_EVAL_IMPL(Squeeze, prim::kPrimSqueeze, SqueezeInfer);
-REGISTER_PRIMITIVE_C(Squeeze);
+REGISTER_PRIMITIVE_C(kNameSqueeze, Squeeze);
 }  // namespace mindspore
@@ -71,7 +71,7 @@ std::shared_ptr<std::vector<char>> ReadProtoFile(const std::string &file) {
   return buf;
 }

-std::shared_ptr<FuncGraph> RunLoadMindIR(const std::string &file_name) {
+std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name) {
   auto graphBuf = ReadProtoFile(file_name);
   if (graphBuf == nullptr) {
     MS_LOG(ERROR) << "Read Mind IR failed, file name is " << file_name.c_str();
@@ -24,7 +24,7 @@
 #include "ir/func_graph.h"

 namespace mindspore {
-std::shared_ptr<FuncGraph> RunLoadMindIR(const std::string &file_name);
+std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name);
 std::shared_ptr<std::vector<char>> ReadProtoFile(const std::string &file);
 std::shared_ptr<FuncGraph> ConvertStreamToFuncGraph(const char *buf, const size_t buf_size);
 }  // namespace mindspore