forked from mindspore-Ecosystem/mindspore
Enable mindir to load initialize weight from python
This commit is contained in:
parent
821aba8907
commit
21df240f23
|
@ -32,7 +32,7 @@ std::string GetOnnxProtoString(const FuncGraphPtr &func_graph);
|
|||
|
||||
std::string GetBinaryProtoString(const FuncGraphPtr &func_graph);
|
||||
|
||||
ModelProtoPtr GetBinaryProto(const FuncGraphPtr &func_graph, bool save_tensor_data = false);
|
||||
ModelProtoPtr GetBinaryProto(const FuncGraphPtr &func_graph);
|
||||
|
||||
void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix);
|
||||
|
||||
|
|
|
@ -207,7 +207,7 @@ void ExportBpropToMindIR(const PrimitivePtr &prim, const FuncGraphPtr &func_grap
|
|||
MS_LOG(ERROR) << "Open cache file '" << bprop_mindir_realpath.value() << "' failed!" << ErrnoToString(errno);
|
||||
return;
|
||||
}
|
||||
ModelProtoPtr fg_model = GetBinaryProto(func_graph, false);
|
||||
ModelProtoPtr fg_model = GetBinaryProto(func_graph);
|
||||
if (fg_model == nullptr) {
|
||||
MS_LOG(ERROR) << "Get binary proto for graph " << func_graph->ToString() << " failed.";
|
||||
fout.close();
|
||||
|
|
|
@ -110,7 +110,9 @@ PYBIND11_MODULE(_c_expression, m) {
|
|||
.def("set_enable_tuple_broaden", &GraphExecutorPy::set_enable_tuple_broaden,
|
||||
py::arg("enable_tuple_broaden") = py::bool_(false), "Set tuple broaden enable.")
|
||||
.def("set_compile_cache_dep_files", &GraphExecutorPy::set_compile_cache_dep_files,
|
||||
py::arg("compile_cache_dep_files") = py::list(), "Set the compilation cache dependent files.");
|
||||
py::arg("compile_cache_dep_files") = py::list(), "Set the compilation cache dependent files.")
|
||||
.def("set_weights_values", &GraphExecutorPy::set_weights_values, py::arg("weights") = py::dict(),
|
||||
"Set values of weights.");
|
||||
|
||||
(void)py::class_<EnvInstance, std::shared_ptr<EnvInstance>>(m, "EnvInstance_").def(py::init());
|
||||
|
||||
|
|
|
@ -281,7 +281,17 @@ bool CheckDepFilesHashConsistency(const std::string ¤t_dep_files_hash) {
|
|||
return true;
|
||||
}
|
||||
|
||||
FuncGraphPtr LoadFuncGraphFromMindIR(size_t idx) {
|
||||
std::map<string, ValuePtr> GenerateWeightsValueMap(const py::dict &weights) {
|
||||
std::map<string, ValuePtr> ret{};
|
||||
for (auto weight = weights.begin(); weight != weights.end(); ++weight) {
|
||||
auto weight_name = py::cast<std::string>(weight->first);
|
||||
auto weight_value = parse::data_converter::PyDataToValue(py::cast<py::object>(weight->second));
|
||||
ret[weight_name] = weight_value;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
FuncGraphPtr LoadFuncGraphFromMindIR(size_t idx, const py::dict &weights) {
|
||||
std::string compile_cache_path = GetCompileCachePath(idx);
|
||||
auto realpath = Common::CreatePrefixPath(compile_cache_path, true);
|
||||
if (!realpath.has_value()) {
|
||||
|
@ -297,10 +307,11 @@ FuncGraphPtr LoadFuncGraphFromMindIR(size_t idx) {
|
|||
}
|
||||
MindIRLoader mindir_loader;
|
||||
mindir_loader.set_need_renormalize(false);
|
||||
mindir_loader.set_weights_value_map(GenerateWeightsValueMap(weights));
|
||||
return mindir_loader.LoadMindIR(realpath.value());
|
||||
}
|
||||
|
||||
FuncGraphPtr GetCachedFuncGraph(const ResourcePtr &resource, const std::string &queue_name) {
|
||||
FuncGraphPtr GetCachedFuncGraph(const ResourcePtr &resource, const py::dict &weights, const std::string &queue_name) {
|
||||
MS_EXCEPTION_IF_NULL(resource);
|
||||
// Compare the dependency files hash.
|
||||
if (!CheckDepFilesHashConsistency(resource->compile_cache_dep_files_hash())) {
|
||||
|
@ -309,7 +320,7 @@ FuncGraphPtr GetCachedFuncGraph(const ResourcePtr &resource, const std::string &
|
|||
}
|
||||
|
||||
// Load the compilation cache file.
|
||||
FuncGraphPtr fg = LoadFuncGraphFromMindIR(resource->compile_cache_id());
|
||||
FuncGraphPtr fg = LoadFuncGraphFromMindIR(resource->compile_cache_id(), weights);
|
||||
if (fg == nullptr) {
|
||||
MS_LOG(WARNING) << "Failed to load the compilation cache file. Execute all the compilation actions.";
|
||||
return nullptr;
|
||||
|
@ -352,7 +363,7 @@ bool ExportFuncGraphToMindIR(const FuncGraphPtr &fg, size_t idx) {
|
|||
MS_LOG(ERROR) << "Open cache file '" << realpath.value() << "' failed!" << ErrnoToString(errno);
|
||||
return false;
|
||||
}
|
||||
ModelProtoPtr fg_model = GetBinaryProto(fg, true);
|
||||
ModelProtoPtr fg_model = GetBinaryProto(fg);
|
||||
if (fg_model == nullptr) {
|
||||
MS_LOG(ERROR) << "Get binary proto for graph " << fg->ToString() << " failed.";
|
||||
fout.close();
|
||||
|
@ -920,7 +931,7 @@ bool GraphExecutorPy::CompileInner(const py::object &source_obj, const py::tuple
|
|||
resource->set_enable_compile_cache(true);
|
||||
resource->set_compile_cache_id(GetCompileCacheGraphId());
|
||||
resource->set_compile_cache_dep_files_hash(GetCompileDepFilesHash(compile_cache_dep_files_));
|
||||
resource->set_func_graph(GetCachedFuncGraph(resource, queue_name_));
|
||||
resource->set_func_graph(GetCachedFuncGraph(resource, weights_, queue_name_));
|
||||
#ifdef ENABLE_PROFILE
|
||||
double t2 = GetTime();
|
||||
MsProfile::StatTime("LoadCachedFuncGraph", t2 - t1);
|
||||
|
|
|
@ -110,6 +110,7 @@ class GraphExecutorPy : public std::enable_shared_from_this<GraphExecutorPy> {
|
|||
void set_compile_cache_dep_files(const py::list &compile_cache_dep_files) {
|
||||
compile_cache_dep_files_ = compile_cache_dep_files;
|
||||
}
|
||||
void set_weights_values(const py::dict &weights) { weights_ = weights; }
|
||||
#ifdef ENABLE_DEBUGGER
|
||||
static bool GetDebugTerminate() { return debugger_terminate_; }
|
||||
static void DebugTerminate(bool val, bool exit_success) {
|
||||
|
@ -145,6 +146,7 @@ class GraphExecutorPy : public std::enable_shared_from_this<GraphExecutorPy> {
|
|||
std::string queue_name_;
|
||||
bool enable_tuple_broaden_{false};
|
||||
py::list compile_cache_dep_files_;
|
||||
py::dict weights_;
|
||||
};
|
||||
using GraphExecutorPyPtr = std::shared_ptr<GraphExecutorPy>;
|
||||
|
||||
|
|
|
@ -83,7 +83,7 @@ class IrExporter {
|
|||
explicit IrExporter(IrExportBuilderPtr builder) : builder_(std::move(builder)) {}
|
||||
virtual ~IrExporter() = default;
|
||||
std::string GetDumpString(const FuncGraphPtr &func_graph);
|
||||
ModelProtoPtr GetDumpProto(const FuncGraphPtr &func_graph, bool save_tensor_data = false);
|
||||
ModelProtoPtr GetDumpProto(const FuncGraphPtr &func_graph);
|
||||
|
||||
private:
|
||||
IrExportBuilderPtr builder_;
|
||||
|
@ -95,13 +95,11 @@ class IrExportBuilder {
|
|||
~IrExportBuilder() { google::protobuf::ShutdownProtobufLibrary(); }
|
||||
std::string GetProtoString() const;
|
||||
void BuildModelInfo();
|
||||
bool BuildModel(const FuncGraphPtr &func_graph, bool save_tensor_data = false);
|
||||
bool BuildModel(const FuncGraphPtr &func_graph);
|
||||
ModelProtoPtr Model() { return model_; }
|
||||
|
||||
bool BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
|
||||
bool save_tensor_data = false);
|
||||
bool BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
|
||||
bool save_tensor_data = false);
|
||||
bool BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto);
|
||||
bool BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto);
|
||||
bool BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto);
|
||||
bool BuildOutput(const CNodePtr &node, mind_ir::GraphProto *const graph_proto);
|
||||
bool BuildCNode(const CNodePtr &node, mind_ir::GraphProto *const graph_proto);
|
||||
|
@ -148,6 +146,7 @@ class IrExportBuilder {
|
|||
std::set<std::string> nodeName_;
|
||||
size_t node_index_{0};
|
||||
size_t shape_index_{0};
|
||||
|
||||
bool top_graph{true};
|
||||
};
|
||||
|
||||
|
@ -161,7 +160,7 @@ std::string IrExporter::GetDumpString(const FuncGraphPtr &func_graph) {
|
|||
return builder_->GetProtoString();
|
||||
}
|
||||
|
||||
ModelProtoPtr IrExporter::GetDumpProto(const FuncGraphPtr &func_graph, bool save_tensor_data) {
|
||||
ModelProtoPtr IrExporter::GetDumpProto(const FuncGraphPtr &func_graph) {
|
||||
if ((builder_ == nullptr) || (func_graph == nullptr)) {
|
||||
MS_LOG(EXCEPTION) << "Input params is null.";
|
||||
}
|
||||
|
@ -170,7 +169,7 @@ ModelProtoPtr IrExporter::GetDumpProto(const FuncGraphPtr &func_graph, bool save
|
|||
builder_->BuildModelInfo();
|
||||
|
||||
// Export model and return string
|
||||
if (!builder_->BuildModel(func_graph, save_tensor_data)) {
|
||||
if (!builder_->BuildModel(func_graph)) {
|
||||
return nullptr;
|
||||
}
|
||||
return builder_->Model();
|
||||
|
@ -190,7 +189,7 @@ void IrExportBuilder::BuildModelInfo() {
|
|||
model_->set_little_endian(common::IsLittleByteOrder());
|
||||
}
|
||||
|
||||
bool IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tensor_data) {
|
||||
bool IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
mind_ir::GraphProto *graph_proto = model_->mutable_graph();
|
||||
graph_proto->set_name(func_graph->ToString());
|
||||
|
@ -201,7 +200,7 @@ bool IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tenso
|
|||
// Build the main funcGraph
|
||||
(void)nodeName_.insert(func_graph->ToString());
|
||||
top_graph = true;
|
||||
if (!BuildFuncGraph(func_graph, graph_proto, save_tensor_data)) {
|
||||
if (!BuildFuncGraph(func_graph, graph_proto)) {
|
||||
MS_LOG(ERROR) << "Build func_graph " << func_graph->ToString() << " failed.";
|
||||
return false;
|
||||
}
|
||||
|
@ -222,7 +221,7 @@ bool IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tenso
|
|||
(void)nodeName_.insert(fg->ToString());
|
||||
(void)graphVisited.insert(fg);
|
||||
auto graph = model_->add_functions();
|
||||
if (!BuildFuncGraph(fg, graph, save_tensor_data)) {
|
||||
if (!BuildFuncGraph(fg, graph)) {
|
||||
MS_LOG(ERROR) << "Build func_graph " << fg->ToString() << " failed.";
|
||||
return false;
|
||||
}
|
||||
|
@ -233,14 +232,13 @@ bool IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tenso
|
|||
return true;
|
||||
}
|
||||
|
||||
bool IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
|
||||
bool save_tensor_data) {
|
||||
bool IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto) {
|
||||
// Export funcGraph name.
|
||||
graph_proto->set_name(func_graph->ToString());
|
||||
// Export parameters
|
||||
// 1. parameters should be mapped to ValueInfoProto
|
||||
// 2. parameters with default value should be mapped to Initializer
|
||||
if (!BuildParameters(func_graph, graph_proto, save_tensor_data)) {
|
||||
if (!BuildParameters(func_graph, graph_proto)) {
|
||||
MS_LOG(ERROR) << "Build parameters failed.";
|
||||
return false;
|
||||
}
|
||||
|
@ -249,8 +247,7 @@ bool IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::Gr
|
|||
return BuildNodes(func_graph, graph_proto);
|
||||
}
|
||||
|
||||
bool IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
|
||||
bool save_tensor_data) {
|
||||
bool IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(graph_proto);
|
||||
for (auto &item : func_graph->parameters()) {
|
||||
|
@ -269,10 +266,6 @@ bool IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::G
|
|||
MS_LOG(ERROR) << "Set parameter " << param->DebugString() << " to TensorProto failed.";
|
||||
return false;
|
||||
}
|
||||
auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param->default_param());
|
||||
if (tensor && save_tensor_data) {
|
||||
parameter_proto->set_raw_data(tensor->data_c(), static_cast<size_t>(tensor->data().nbytes()));
|
||||
}
|
||||
} else {
|
||||
mind_ir::ValueInfoProto *input_proto = graph_proto->add_input();
|
||||
input_proto->set_name(param_name);
|
||||
|
@ -1077,9 +1070,9 @@ std::string GetBinaryProtoString(const FuncGraphPtr &func_graph) {
|
|||
return exporter->GetDumpString(func_graph);
|
||||
}
|
||||
|
||||
ModelProtoPtr GetBinaryProto(const FuncGraphPtr &func_graph, bool save_tensor_data) {
|
||||
ModelProtoPtr GetBinaryProto(const FuncGraphPtr &func_graph) {
|
||||
auto exporter = std::make_shared<IrExporter>(std::make_shared<IrExportBuilder>());
|
||||
auto result = exporter->GetDumpProto(func_graph, save_tensor_data);
|
||||
auto result = exporter->GetDumpProto(func_graph);
|
||||
return result;
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -274,6 +274,7 @@ class _MindsporeFunctionExecutor:
|
|||
if self.obj is None:
|
||||
is_compile = self._graph_executor.compile(self.fn, args_list, phase, True)
|
||||
else:
|
||||
self._graph_executor.set_weights_values(self.obj.parameters_dict())
|
||||
is_compile = self._graph_executor.compile(self.obj, args_list, phase, True)
|
||||
if not is_compile:
|
||||
raise RuntimeError("Executor compile failed.")
|
||||
|
@ -651,7 +652,7 @@ class _CellGraphExecutor:
|
|||
|
||||
enable_ge = context.get_context("enable_ge")
|
||||
use_vm = self._use_vm_mode()
|
||||
|
||||
self._graph_executor.set_weights_values(obj.parameters_dict())
|
||||
result = self._graph_executor.compile(obj, args_list, phase, use_vm)
|
||||
obj.compile_cache.add(phase)
|
||||
if not result:
|
||||
|
|
|
@ -98,6 +98,7 @@ void Cloner::CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target,
|
|||
// Default parameter can be shared since it is readonly.
|
||||
new_param->set_default_param(old_param->default_param());
|
||||
}
|
||||
new_param->set_is_top_graph_param(old_param->is_top_graph_param());
|
||||
ScopePtr scope = ((node->scope() == kDefaultScope) && (this->scope() != nullptr)) ? this->scope() : node->scope();
|
||||
new_param->set_scope(scope);
|
||||
repl_node_[node] = std::move(new_param);
|
||||
|
|
|
@ -312,15 +312,17 @@ bool MSANFModelParser::BuildParameterForFuncGraph(const ParameterPtr &node,
|
|||
MS_LOG(ERROR) << "Build parameter occur memcpy_s error.";
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
node->set_default_param(tensor_info);
|
||||
} else if (parameter_proto.has_external_data()) {
|
||||
auto ret = GetTensorDataFromExternal(parameter_proto, tensor_info);
|
||||
if (!ret) {
|
||||
return false;
|
||||
}
|
||||
node->set_default_param(tensor_info);
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "Parameter will load initialized data.";
|
||||
}
|
||||
|
||||
node->set_default_param(tensor_info);
|
||||
|
||||
anfnode_build_map_[parameter_proto.name()] = node;
|
||||
return true;
|
||||
}
|
||||
|
@ -974,7 +976,7 @@ bool MSANFModelParser::CheckCNodePrim(CNodePtr cnode_ptr) {
|
|||
|
||||
// If the operator is not a primitive, the abstract will been set to null.
|
||||
// Because there are not some operators in front end, the abstract of primitive should be reserved.
|
||||
if (prim == nullptr) {
|
||||
if (prim == nullptr && need_renormalize()) {
|
||||
cnode_ptr->set_abstract(nullptr);
|
||||
return true;
|
||||
}
|
||||
|
@ -1214,7 +1216,39 @@ bool MSANFModelParser::MSANFParseModelConfigureInfo(const mind_ir::ModelProto &m
|
|||
return true;
|
||||
}
|
||||
|
||||
FuncGraphPtr MSANFModelParser::Parse(const mind_ir::ModelProto &model_proto) {
|
||||
bool MSANFModelParser::SetValueForTopGraphParameter(const FuncGraphPtr &topGraph,
|
||||
const std::map<std::string, ValuePtr> &weights) {
|
||||
size_t hyper_param_count = 0;
|
||||
auto parameters = topGraph->parameters();
|
||||
for (int i = parameters.size() - 1; i >= 0; --i) {
|
||||
auto parameter = parameters[i]->cast<ParameterPtr>();
|
||||
if (parameter == nullptr) {
|
||||
MS_LOG(ERROR) << "AnfNode " << parameters[i]->DebugString() << " should be Parameter.";
|
||||
return false;
|
||||
}
|
||||
auto type = parameter->Type();
|
||||
if (type == nullptr) {
|
||||
MS_LOG(ERROR) << "Parameter " << parameter->DebugString() << " has no type.";
|
||||
return false;
|
||||
}
|
||||
if (!type->isa<RefType>()) {
|
||||
break;
|
||||
}
|
||||
auto parameter_name = parameter->name();
|
||||
auto weights_iter = weights.find(parameter_name);
|
||||
if (weights_iter == weights.end()) {
|
||||
MS_LOG(ERROR) << "Find initial weight value for " << parameter_name << " failed.";
|
||||
return false;
|
||||
}
|
||||
parameter->set_default_param(weights_iter->second);
|
||||
hyper_param_count++;
|
||||
}
|
||||
topGraph->set_hyper_param_count(hyper_param_count);
|
||||
return true;
|
||||
}
|
||||
|
||||
FuncGraphPtr MSANFModelParser::Parse(const mind_ir::ModelProto &model_proto,
|
||||
const std::map<std::string, ValuePtr> &weights) {
|
||||
FuncGraphPtr dstGraph = std::make_shared<FuncGraph>();
|
||||
if (!MSANFParseModelConfigureInfo(model_proto)) {
|
||||
MS_LOG(ERROR) << "Parse configuration info for pb file failed!";
|
||||
|
@ -1252,6 +1286,14 @@ FuncGraphPtr MSANFModelParser::Parse(const mind_ir::ModelProto &model_proto) {
|
|||
MS_LOG(ERROR) << "Build funcgraph failed!";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (!weights.empty()) {
|
||||
if (!SetValueForTopGraphParameter(dstGraph, weights)) {
|
||||
MS_LOG(ERROR) << "Set value for top graph fail.";
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "Parse pb to build FuncGraph Success! " << graphBuild.name();
|
||||
top_graph_ = dstGraph;
|
||||
for (int i = 0; i < model_proto.functions_size(); ++i) {
|
||||
|
@ -1263,6 +1305,7 @@ FuncGraphPtr MSANFModelParser::Parse(const mind_ir::ModelProto &model_proto) {
|
|||
}
|
||||
MS_LOG(DEBUG) << "Parse pb to build FuncGraph Success! " << graph_proto.name();
|
||||
}
|
||||
|
||||
// Release resource
|
||||
anfnode_build_map_.clear();
|
||||
return dstGraph;
|
||||
|
|
|
@ -35,7 +35,7 @@ class MSANFModelParser {
|
|||
~MSANFModelParser() = default;
|
||||
|
||||
static void LoadTensorMapClear() { load_tensor_map_.clear(); }
|
||||
FuncGraphPtr Parse(const mind_ir::ModelProto &model_proto);
|
||||
FuncGraphPtr Parse(const mind_ir::ModelProto &model_proto, const std::map<std::string, ValuePtr> &weights = {});
|
||||
bool MSANFParseModelConfigureInfo(const mind_ir::ModelProto &model_proto);
|
||||
|
||||
std::string GetProducerName() { return producer_name_; }
|
||||
|
@ -57,6 +57,7 @@ class MSANFModelParser {
|
|||
bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
|
||||
bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
|
||||
bool BuildParameterForFuncGraph(const ParameterPtr &node, const mind_ir::TensorProto &tensor_proto);
|
||||
bool SetValueForTopGraphParameter(const FuncGraphPtr &topGraph, const std::map<std::string, ValuePtr> &weights);
|
||||
bool GetTensorDataFromExternal(const mind_ir::TensorProto &tensor_proto, const tensor::TensorPtr &tensor_info);
|
||||
bool BuildInputForFuncGraph(const ParameterPtr &node, const mind_ir::ValueInfoProto &value_proto);
|
||||
tensor::TensorPtr BuildTensorInfoForFuncGraph(const mind_ir::TensorProto &tensor_proto);
|
||||
|
|
|
@ -229,7 +229,7 @@ FuncGraphPtr MindIRLoader::LoadMindIR(const std::string &file_name) {
|
|||
if (is_lite_) {
|
||||
model_parser.SetLite();
|
||||
}
|
||||
FuncGraphPtr dstgraph_ptr = model_parser.Parse(origin_model);
|
||||
FuncGraphPtr dstgraph_ptr = model_parser.Parse(origin_model, weights_value_map_);
|
||||
return dstgraph_ptr;
|
||||
}
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#ifndef MINDSPORE_CORE_LOAD_MODEL_H
|
||||
#define MINDSPORE_CORE_LOAD_MODEL_H
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
@ -34,7 +35,9 @@ class MindIRLoader {
|
|||
|
||||
bool get_need_renormalize() const { return need_renormalize_; }
|
||||
void set_need_renormalize(bool need_renormalize) { need_renormalize_ = need_renormalize; }
|
||||
|
||||
void set_weights_value_map(const std::map<string, ValuePtr> &weights_value_map) {
|
||||
weights_value_map_ = weights_value_map;
|
||||
}
|
||||
FuncGraphPtr LoadMindIR(const void *buffer, const size_t &size);
|
||||
FuncGraphPtr LoadMindIR(const std::string &file_name);
|
||||
std::vector<FuncGraphPtr> LoadMindIRs(const std::vector<std::string> file_names);
|
||||
|
@ -48,6 +51,7 @@ class MindIRLoader {
|
|||
std::string dec_mode_ = std::string("AES-GCM");
|
||||
bool inc_load_ = false;
|
||||
bool need_renormalize_ = true;
|
||||
std::map<string, ValuePtr> weights_value_map_;
|
||||
};
|
||||
|
||||
std::string LoadPreprocess(const std::string &file_name);
|
||||
|
|
|
@ -26,7 +26,7 @@ std::string GetOnnxProtoString(const FuncGraphPtr &func_graph) { return ""; }
|
|||
|
||||
std::string GetBinaryProtoString(const FuncGraphPtr &func_graph) { return ""; }
|
||||
|
||||
ModelProtoPtr GetBinaryProto(const FuncGraphPtr &func_graph, bool save_tensor_data) {
|
||||
ModelProtoPtr GetBinaryProto(const FuncGraphPtr &func_graph) {
|
||||
ModelProtoPtr empty_model;
|
||||
return empty_model;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue