diff --git a/mindspore/ccsrc/debug/dump_proto.h b/mindspore/ccsrc/debug/dump_proto.h index 7b5ee34ff67..b0806bdd4f7 100644 --- a/mindspore/ccsrc/debug/dump_proto.h +++ b/mindspore/ccsrc/debug/dump_proto.h @@ -32,7 +32,7 @@ std::string GetOnnxProtoString(const FuncGraphPtr &func_graph); std::string GetBinaryProtoString(const FuncGraphPtr &func_graph); -ModelProtoPtr GetBinaryProto(const FuncGraphPtr &func_graph, bool save_tensor_data = false); +ModelProtoPtr GetBinaryProto(const FuncGraphPtr &func_graph); void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix); diff --git a/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc b/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc index d9478d0295c..de418961f55 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc +++ b/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc @@ -207,7 +207,7 @@ void ExportBpropToMindIR(const PrimitivePtr &prim, const FuncGraphPtr &func_grap MS_LOG(ERROR) << "Open cache file '" << bprop_mindir_realpath.value() << "' failed!" 
<< ErrnoToString(errno); return; } - ModelProtoPtr fg_model = GetBinaryProto(func_graph, false); + ModelProtoPtr fg_model = GetBinaryProto(func_graph); if (fg_model == nullptr) { MS_LOG(ERROR) << "Get binary proto for graph " << func_graph->ToString() << " failed."; fout.close(); diff --git a/mindspore/ccsrc/pipeline/jit/init.cc b/mindspore/ccsrc/pipeline/jit/init.cc index 611997f6b8e..375de941525 100644 --- a/mindspore/ccsrc/pipeline/jit/init.cc +++ b/mindspore/ccsrc/pipeline/jit/init.cc @@ -110,7 +110,9 @@ PYBIND11_MODULE(_c_expression, m) { .def("set_enable_tuple_broaden", &GraphExecutorPy::set_enable_tuple_broaden, py::arg("enable_tuple_broaden") = py::bool_(false), "Set tuple broaden enable.") .def("set_compile_cache_dep_files", &GraphExecutorPy::set_compile_cache_dep_files, - py::arg("compile_cache_dep_files") = py::list(), "Set the compilation cache dependent files."); + py::arg("compile_cache_dep_files") = py::list(), "Set the compilation cache dependent files.") + .def("set_weights_values", &GraphExecutorPy::set_weights_values, py::arg("weights") = py::dict(), + "Set values of weights."); (void)py::class_>(m, "EnvInstance_").def(py::init()); diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc index dbfa3d11771..985550e47a5 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc @@ -281,7 +281,17 @@ bool CheckDepFilesHashConsistency(const std::string &current_dep_files_hash) { return true; } -FuncGraphPtr LoadFuncGraphFromMindIR(size_t idx) { +std::map GenerateWeightsValueMap(const py::dict &weights) { + std::map ret{}; + for (auto weight = weights.begin(); weight != weights.end(); ++weight) { + auto weight_name = py::cast(weight->first); + auto weight_value = parse::data_converter::PyDataToValue(py::cast(weight->second)); + ret[weight_name] = weight_value; + } + return ret; +} + +FuncGraphPtr LoadFuncGraphFromMindIR(size_t idx, const py::dict &weights) { std::string 
compile_cache_path = GetCompileCachePath(idx); auto realpath = Common::CreatePrefixPath(compile_cache_path, true); if (!realpath.has_value()) { @@ -297,10 +307,11 @@ FuncGraphPtr LoadFuncGraphFromMindIR(size_t idx) { } MindIRLoader mindir_loader; mindir_loader.set_need_renormalize(false); + mindir_loader.set_weights_value_map(GenerateWeightsValueMap(weights)); return mindir_loader.LoadMindIR(realpath.value()); } -FuncGraphPtr GetCachedFuncGraph(const ResourcePtr &resource, const std::string &queue_name) { +FuncGraphPtr GetCachedFuncGraph(const ResourcePtr &resource, const py::dict &weights, const std::string &queue_name) { MS_EXCEPTION_IF_NULL(resource); // Compare the dependency files hash. if (!CheckDepFilesHashConsistency(resource->compile_cache_dep_files_hash())) { @@ -309,7 +320,7 @@ FuncGraphPtr GetCachedFuncGraph(const ResourcePtr &resource, const std::string & } // Load the compilation cache file. - FuncGraphPtr fg = LoadFuncGraphFromMindIR(resource->compile_cache_id()); + FuncGraphPtr fg = LoadFuncGraphFromMindIR(resource->compile_cache_id(), weights); if (fg == nullptr) { MS_LOG(WARNING) << "Failed to load the compilation cache file. Execute all the compilation actions."; return nullptr; @@ -352,7 +363,7 @@ bool ExportFuncGraphToMindIR(const FuncGraphPtr &fg, size_t idx) { MS_LOG(ERROR) << "Open cache file '" << realpath.value() << "' failed!" 
<< ErrnoToString(errno); return false; } - ModelProtoPtr fg_model = GetBinaryProto(fg, true); + ModelProtoPtr fg_model = GetBinaryProto(fg); if (fg_model == nullptr) { MS_LOG(ERROR) << "Get binary proto for graph " << fg->ToString() << " failed."; fout.close(); @@ -920,7 +931,7 @@ bool GraphExecutorPy::CompileInner(const py::object &source_obj, const py::tuple resource->set_enable_compile_cache(true); resource->set_compile_cache_id(GetCompileCacheGraphId()); resource->set_compile_cache_dep_files_hash(GetCompileDepFilesHash(compile_cache_dep_files_)); - resource->set_func_graph(GetCachedFuncGraph(resource, queue_name_)); + resource->set_func_graph(GetCachedFuncGraph(resource, weights_, queue_name_)); #ifdef ENABLE_PROFILE double t2 = GetTime(); MsProfile::StatTime("LoadCachedFuncGraph", t2 - t1); diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.h b/mindspore/ccsrc/pipeline/jit/pipeline.h index 08e9f55e4ae..143b44e266f 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.h +++ b/mindspore/ccsrc/pipeline/jit/pipeline.h @@ -110,6 +110,7 @@ class GraphExecutorPy : public std::enable_shared_from_this { void set_compile_cache_dep_files(const py::list &compile_cache_dep_files) { compile_cache_dep_files_ = compile_cache_dep_files; } + void set_weights_values(const py::dict &weights) { weights_ = weights; } #ifdef ENABLE_DEBUGGER static bool GetDebugTerminate() { return debugger_terminate_; } static void DebugTerminate(bool val, bool exit_success) { @@ -145,6 +146,7 @@ class GraphExecutorPy : public std::enable_shared_from_this { std::string queue_name_; bool enable_tuple_broaden_{false}; py::list compile_cache_dep_files_; + py::dict weights_; }; using GraphExecutorPyPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc b/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc index 43864196b80..d78a047e594 100644 --- a/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc +++ b/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc 
@@ -83,7 +83,7 @@ class IrExporter { explicit IrExporter(IrExportBuilderPtr builder) : builder_(std::move(builder)) {} virtual ~IrExporter() = default; std::string GetDumpString(const FuncGraphPtr &func_graph); - ModelProtoPtr GetDumpProto(const FuncGraphPtr &func_graph, bool save_tensor_data = false); + ModelProtoPtr GetDumpProto(const FuncGraphPtr &func_graph); private: IrExportBuilderPtr builder_; @@ -95,13 +95,11 @@ class IrExportBuilder { ~IrExportBuilder() { google::protobuf::ShutdownProtobufLibrary(); } std::string GetProtoString() const; void BuildModelInfo(); - bool BuildModel(const FuncGraphPtr &func_graph, bool save_tensor_data = false); + bool BuildModel(const FuncGraphPtr &func_graph); ModelProtoPtr Model() { return model_; } - bool BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto, - bool save_tensor_data = false); - bool BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto, - bool save_tensor_data = false); + bool BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto); + bool BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto); bool BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto); bool BuildOutput(const CNodePtr &node, mind_ir::GraphProto *const graph_proto); bool BuildCNode(const CNodePtr &node, mind_ir::GraphProto *const graph_proto); @@ -148,6 +146,7 @@ class IrExportBuilder { std::set nodeName_; size_t node_index_{0}; size_t shape_index_{0}; + bool top_graph{true}; }; @@ -161,7 +160,7 @@ std::string IrExporter::GetDumpString(const FuncGraphPtr &func_graph) { return builder_->GetProtoString(); } -ModelProtoPtr IrExporter::GetDumpProto(const FuncGraphPtr &func_graph, bool save_tensor_data) { +ModelProtoPtr IrExporter::GetDumpProto(const FuncGraphPtr &func_graph) { if ((builder_ == nullptr) || (func_graph == nullptr)) { MS_LOG(EXCEPTION) << "Input params is null."; } @@ -170,7 
+169,7 @@ ModelProtoPtr IrExporter::GetDumpProto(const FuncGraphPtr &func_graph, bool save builder_->BuildModelInfo(); // Export model and return string - if (!builder_->BuildModel(func_graph, save_tensor_data)) { + if (!builder_->BuildModel(func_graph)) { return nullptr; } return builder_->Model(); @@ -190,7 +189,7 @@ void IrExportBuilder::BuildModelInfo() { model_->set_little_endian(common::IsLittleByteOrder()); } -bool IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tensor_data) { +bool IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(func_graph); mind_ir::GraphProto *graph_proto = model_->mutable_graph(); graph_proto->set_name(func_graph->ToString()); @@ -201,7 +200,7 @@ bool IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tenso // Build the main funcGraph (void)nodeName_.insert(func_graph->ToString()); top_graph = true; - if (!BuildFuncGraph(func_graph, graph_proto, save_tensor_data)) { + if (!BuildFuncGraph(func_graph, graph_proto)) { MS_LOG(ERROR) << "Build func_graph " << func_graph->ToString() << " failed."; return false; } @@ -222,7 +221,7 @@ bool IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tenso (void)nodeName_.insert(fg->ToString()); (void)graphVisited.insert(fg); auto graph = model_->add_functions(); - if (!BuildFuncGraph(fg, graph, save_tensor_data)) { + if (!BuildFuncGraph(fg, graph)) { MS_LOG(ERROR) << "Build func_graph " << fg->ToString() << " failed."; return false; } @@ -233,14 +232,13 @@ bool IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tenso return true; } -bool IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto, - bool save_tensor_data) { +bool IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto) { // Export funcGraph name. graph_proto->set_name(func_graph->ToString()); // Export parameters // 1. 
parameters should be mapped to ValueInfoProto // 2. parameters with default value should be mapped to Initializer - if (!BuildParameters(func_graph, graph_proto, save_tensor_data)) { + if (!BuildParameters(func_graph, graph_proto)) { MS_LOG(ERROR) << "Build parameters failed."; return false; } @@ -249,8 +247,7 @@ bool IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::Gr return BuildNodes(func_graph, graph_proto); } -bool IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto, - bool save_tensor_data) { +bool IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(graph_proto); for (auto &item : func_graph->parameters()) { @@ -269,10 +266,6 @@ bool IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::G MS_LOG(ERROR) << "Set parameter " << param->DebugString() << " to TensorProto failed."; return false; } - auto tensor = std::dynamic_pointer_cast(param->default_param()); - if (tensor && save_tensor_data) { - parameter_proto->set_raw_data(tensor->data_c(), static_cast(tensor->data().nbytes())); - } } else { mind_ir::ValueInfoProto *input_proto = graph_proto->add_input(); input_proto->set_name(param_name); @@ -1077,9 +1070,9 @@ std::string GetBinaryProtoString(const FuncGraphPtr &func_graph) { return exporter->GetDumpString(func_graph); } -ModelProtoPtr GetBinaryProto(const FuncGraphPtr &func_graph, bool save_tensor_data) { +ModelProtoPtr GetBinaryProto(const FuncGraphPtr &func_graph) { auto exporter = std::make_shared(std::make_shared()); - auto result = exporter->GetDumpProto(func_graph, save_tensor_data); + auto result = exporter->GetDumpProto(func_graph); return result; } } // namespace mindspore diff --git a/mindspore/common/api.py b/mindspore/common/api.py index dcd6b0232ac..411e4224ccf 100644 --- a/mindspore/common/api.py +++ b/mindspore/common/api.py @@ -274,6 
+274,7 @@ class _MindsporeFunctionExecutor: if self.obj is None: is_compile = self._graph_executor.compile(self.fn, args_list, phase, True) else: + self._graph_executor.set_weights_values(self.obj.parameters_dict()) is_compile = self._graph_executor.compile(self.obj, args_list, phase, True) if not is_compile: raise RuntimeError("Executor compile failed.") @@ -651,7 +652,7 @@ class _CellGraphExecutor: enable_ge = context.get_context("enable_ge") use_vm = self._use_vm_mode() - + self._graph_executor.set_weights_values(obj.parameters_dict()) result = self._graph_executor.compile(obj, args_list, phase, use_vm) obj.compile_cache.add(phase) if not result: diff --git a/mindspore/core/ir/func_graph_cloner.cc b/mindspore/core/ir/func_graph_cloner.cc index 0bc03063f8a..d092946e0ee 100644 --- a/mindspore/core/ir/func_graph_cloner.cc +++ b/mindspore/core/ir/func_graph_cloner.cc @@ -98,6 +98,7 @@ void Cloner::CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target, // Default parameter can be shared since it is readonly. new_param->set_default_param(old_param->default_param()); } + new_param->set_is_top_graph_param(old_param->is_top_graph_param()); ScopePtr scope = ((node->scope() == kDefaultScope) && (this->scope() != nullptr)) ? 
this->scope() : node->scope(); new_param->set_scope(scope); repl_node_[node] = std::move(new_param); diff --git a/mindspore/core/load_mindir/anf_model_parser.cc b/mindspore/core/load_mindir/anf_model_parser.cc index 9b1f440fd53..891ecb89a11 100644 --- a/mindspore/core/load_mindir/anf_model_parser.cc +++ b/mindspore/core/load_mindir/anf_model_parser.cc @@ -312,15 +312,17 @@ bool MSANFModelParser::BuildParameterForFuncGraph(const ParameterPtr &node, MS_LOG(ERROR) << "Build parameter occur memcpy_s error."; return false; } - } else { + node->set_default_param(tensor_info); + } else if (parameter_proto.has_external_data()) { auto ret = GetTensorDataFromExternal(parameter_proto, tensor_info); if (!ret) { return false; } + node->set_default_param(tensor_info); + } else { + MS_LOG(DEBUG) << "Parameter will load initialized data."; } - node->set_default_param(tensor_info); - anfnode_build_map_[parameter_proto.name()] = node; return true; } @@ -974,7 +976,7 @@ bool MSANFModelParser::CheckCNodePrim(CNodePtr cnode_ptr) { // If the operator is not a primitive, the abstract will been set to null. // Because there are not some operators in front end, the abstract of primitive should be reserved. 
- if (prim == nullptr) { + if (prim == nullptr && need_renormalize()) { cnode_ptr->set_abstract(nullptr); return true; } @@ -1214,7 +1216,39 @@ bool MSANFModelParser::MSANFParseModelConfigureInfo(const mind_ir::ModelProto &m return true; } -FuncGraphPtr MSANFModelParser::Parse(const mind_ir::ModelProto &model_proto) { +bool MSANFModelParser::SetValueForTopGraphParameter(const FuncGraphPtr &topGraph, + const std::map &weights) { + size_t hyper_param_count = 0; + auto parameters = topGraph->parameters(); + for (int i = parameters.size() - 1; i >= 0; --i) { + auto parameter = parameters[i]->cast(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "AnfNode " << parameters[i]->DebugString() << " should be Parameter."; + return false; + } + auto type = parameter->Type(); + if (type == nullptr) { + MS_LOG(ERROR) << "Parameter " << parameter->DebugString() << " has no type."; + return false; + } + if (!type->isa()) { + break; + } + auto parameter_name = parameter->name(); + auto weights_iter = weights.find(parameter_name); + if (weights_iter == weights.end()) { + MS_LOG(ERROR) << "Find initial weight value for " << parameter_name << " failed."; + return false; + } + parameter->set_default_param(weights_iter->second); + hyper_param_count++; + } + topGraph->set_hyper_param_count(hyper_param_count); + return true; +} + +FuncGraphPtr MSANFModelParser::Parse(const mind_ir::ModelProto &model_proto, + const std::map &weights) { FuncGraphPtr dstGraph = std::make_shared(); if (!MSANFParseModelConfigureInfo(model_proto)) { MS_LOG(ERROR) << "Parse configuration info for pb file failed!"; @@ -1252,6 +1286,14 @@ FuncGraphPtr MSANFModelParser::Parse(const mind_ir::ModelProto &model_proto) { MS_LOG(ERROR) << "Build funcgraph failed!"; return nullptr; } + + if (!weights.empty()) { + if (!SetValueForTopGraphParameter(dstGraph, weights)) { + MS_LOG(ERROR) << "Set value for top graph fail."; + return nullptr; + } + } + MS_LOG(DEBUG) << "Parse pb to build FuncGraph Success! 
" << graphBuild.name(); top_graph_ = dstGraph; for (int i = 0; i < model_proto.functions_size(); ++i) { @@ -1263,6 +1305,7 @@ FuncGraphPtr MSANFModelParser::Parse(const mind_ir::ModelProto &model_proto) { } MS_LOG(DEBUG) << "Parse pb to build FuncGraph Success! " << graph_proto.name(); } + // Release resource anfnode_build_map_.clear(); return dstGraph; diff --git a/mindspore/core/load_mindir/anf_model_parser.h b/mindspore/core/load_mindir/anf_model_parser.h index 2f098f4c030..ac317155a5f 100644 --- a/mindspore/core/load_mindir/anf_model_parser.h +++ b/mindspore/core/load_mindir/anf_model_parser.h @@ -35,7 +35,7 @@ class MSANFModelParser { ~MSANFModelParser() = default; static void LoadTensorMapClear() { load_tensor_map_.clear(); } - FuncGraphPtr Parse(const mind_ir::ModelProto &model_proto); + FuncGraphPtr Parse(const mind_ir::ModelProto &model_proto, const std::map &weights = {}); bool MSANFParseModelConfigureInfo(const mind_ir::ModelProto &model_proto); std::string GetProducerName() { return producer_name_; } @@ -57,6 +57,7 @@ class MSANFModelParser { bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto); bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto); bool BuildParameterForFuncGraph(const ParameterPtr &node, const mind_ir::TensorProto &tensor_proto); + bool SetValueForTopGraphParameter(const FuncGraphPtr &topGraph, const std::map &weights); bool GetTensorDataFromExternal(const mind_ir::TensorProto &tensor_proto, const tensor::TensorPtr &tensor_info); bool BuildInputForFuncGraph(const ParameterPtr &node, const mind_ir::ValueInfoProto &value_proto); tensor::TensorPtr BuildTensorInfoForFuncGraph(const mind_ir::TensorProto &tensor_proto); diff --git a/mindspore/core/load_mindir/load_model.cc b/mindspore/core/load_mindir/load_model.cc index a6f9da62037..f5929e78a3d 100644 --- a/mindspore/core/load_mindir/load_model.cc +++ 
b/mindspore/core/load_mindir/load_model.cc @@ -229,7 +229,7 @@ FuncGraphPtr MindIRLoader::LoadMindIR(const std::string &file_name) { if (is_lite_) { model_parser.SetLite(); } - FuncGraphPtr dstgraph_ptr = model_parser.Parse(origin_model); + FuncGraphPtr dstgraph_ptr = model_parser.Parse(origin_model, weights_value_map_); return dstgraph_ptr; } diff --git a/mindspore/core/load_mindir/load_model.h b/mindspore/core/load_mindir/load_model.h index b36c46ee3e4..64a8affc697 100644 --- a/mindspore/core/load_mindir/load_model.h +++ b/mindspore/core/load_mindir/load_model.h @@ -16,6 +16,7 @@ #ifndef MINDSPORE_CORE_LOAD_MODEL_H #define MINDSPORE_CORE_LOAD_MODEL_H +#include #include #include #include @@ -34,7 +35,9 @@ class MindIRLoader { bool get_need_renormalize() const { return need_renormalize_; } void set_need_renormalize(bool need_renormalize) { need_renormalize_ = need_renormalize; } - + void set_weights_value_map(const std::map &weights_value_map) { + weights_value_map_ = weights_value_map; + } FuncGraphPtr LoadMindIR(const void *buffer, const size_t &size); FuncGraphPtr LoadMindIR(const std::string &file_name); std::vector LoadMindIRs(const std::vector file_names); @@ -48,6 +51,7 @@ class MindIRLoader { std::string dec_mode_ = std::string("AES-GCM"); bool inc_load_ = false; bool need_renormalize_ = true; + std::map weights_value_map_; }; std::string LoadPreprocess(const std::string &file_name); diff --git a/tests/ut/cpp/stub/anf_ir/dump_proto_stub.cc b/tests/ut/cpp/stub/anf_ir/dump_proto_stub.cc index 64ab105d04d..95796bd172c 100644 --- a/tests/ut/cpp/stub/anf_ir/dump_proto_stub.cc +++ b/tests/ut/cpp/stub/anf_ir/dump_proto_stub.cc @@ -26,7 +26,7 @@ std::string GetOnnxProtoString(const FuncGraphPtr &func_graph) { return ""; } std::string GetBinaryProtoString(const FuncGraphPtr &func_graph) { return ""; } -ModelProtoPtr GetBinaryProto(const FuncGraphPtr &func_graph, bool save_tensor_data) { +ModelProtoPtr GetBinaryProto(const FuncGraphPtr &func_graph) { ModelProtoPtr 
empty_model; return empty_model; }