Enable MindIR to load initial weight values from Python

With this change, the compile-cache exporter no longer embeds tensor raw data in the MindIR file; instead, the loader fills in the initial values of the top graph's weight parameters from a name-to-value map supplied by the Python side (built from the Cell's parameters_dict()).

l00591931 2021-11-09 15:25:19 +08:00
parent 821aba8907
commit 21df240f23
13 changed files with 98 additions and 40 deletions

View File

@@ -32,7 +32,7 @@ std::string GetOnnxProtoString(const FuncGraphPtr &func_graph);
std::string GetBinaryProtoString(const FuncGraphPtr &func_graph);
ModelProtoPtr GetBinaryProto(const FuncGraphPtr &func_graph, bool save_tensor_data = false);
ModelProtoPtr GetBinaryProto(const FuncGraphPtr &func_graph);
void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix);

View File

@@ -207,7 +207,7 @@ void ExportBpropToMindIR(const PrimitivePtr &prim, const FuncGraphPtr &func_grap
MS_LOG(ERROR) << "Open cache file '" << bprop_mindir_realpath.value() << "' failed!" << ErrnoToString(errno);
return;
}
ModelProtoPtr fg_model = GetBinaryProto(func_graph, false);
ModelProtoPtr fg_model = GetBinaryProto(func_graph);
if (fg_model == nullptr) {
MS_LOG(ERROR) << "Get binary proto for graph " << func_graph->ToString() << " failed.";
fout.close();

View File

@@ -110,7 +110,9 @@ PYBIND11_MODULE(_c_expression, m) {
.def("set_enable_tuple_broaden", &GraphExecutorPy::set_enable_tuple_broaden,
py::arg("enable_tuple_broaden") = py::bool_(false), "Set tuple broaden enable.")
.def("set_compile_cache_dep_files", &GraphExecutorPy::set_compile_cache_dep_files,
py::arg("compile_cache_dep_files") = py::list(), "Set the compilation cache dependent files.");
py::arg("compile_cache_dep_files") = py::list(), "Set the compilation cache dependent files.")
.def("set_weights_values", &GraphExecutorPy::set_weights_values, py::arg("weights") = py::dict(),
"Set values of weights.");
(void)py::class_<EnvInstance, std::shared_ptr<EnvInstance>>(m, "EnvInstance_").def(py::init());
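
For context, the binding added above follows the standard pybind11 chained-.def pattern with a defaulted py::dict argument. A minimal self-contained sketch of that pattern (the Executor class and module name are illustrative stand-ins, not the real GraphExecutorPy):

#include <memory>
#include <pybind11/pybind11.h>

namespace py = pybind11;

// Illustrative stand-in for GraphExecutorPy: stores the weights dict for later use.
class Executor {
 public:
  void set_weights_values(const py::dict &weights) { weights_ = weights; }

 private:
  py::dict weights_;
};

PYBIND11_MODULE(example, m) {
  (void)py::class_<Executor, std::shared_ptr<Executor>>(m, "Executor")
      .def(py::init<>())
      // The defaulted argument mirrors the new binding: calling with no argument passes {}.
      .def("set_weights_values", &Executor::set_weights_values, py::arg("weights") = py::dict(),
           "Set values of weights.");
}

From Python this is then callable as Executor().set_weights_values({"w": some_value}).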

View File

@@ -281,7 +281,17 @@ bool CheckDepFilesHashConsistency(const std::string &current_dep_files_hash) {
return true;
}
FuncGraphPtr LoadFuncGraphFromMindIR(size_t idx) {
std::map<string, ValuePtr> GenerateWeightsValueMap(const py::dict &weights) {
std::map<string, ValuePtr> ret{};
for (auto weight = weights.begin(); weight != weights.end(); ++weight) {
auto weight_name = py::cast<std::string>(weight->first);
auto weight_value = parse::data_converter::PyDataToValue(py::cast<py::object>(weight->second));
ret[weight_name] = weight_value;
}
return ret;
}
FuncGraphPtr LoadFuncGraphFromMindIR(size_t idx, const py::dict &weights) {
std::string compile_cache_path = GetCompileCachePath(idx);
auto realpath = Common::CreatePrefixPath(compile_cache_path, true);
if (!realpath.has_value()) {
@@ -297,10 +307,11 @@ FuncGraphPtr LoadFuncGraphFromMindIR(size_t idx) {
}
MindIRLoader mindir_loader;
mindir_loader.set_need_renormalize(false);
mindir_loader.set_weights_value_map(GenerateWeightsValueMap(weights));
return mindir_loader.LoadMindIR(realpath.value());
}
FuncGraphPtr GetCachedFuncGraph(const ResourcePtr &resource, const std::string &queue_name) {
FuncGraphPtr GetCachedFuncGraph(const ResourcePtr &resource, const py::dict &weights, const std::string &queue_name) {
MS_EXCEPTION_IF_NULL(resource);
// Compare the dependency files hash.
if (!CheckDepFilesHashConsistency(resource->compile_cache_dep_files_hash())) {
@@ -309,7 +320,7 @@ FuncGraphPtr GetCachedFuncGraph(const ResourcePtr &resource, const std::string &
}
// Load the compilation cache file.
FuncGraphPtr fg = LoadFuncGraphFromMindIR(resource->compile_cache_id());
FuncGraphPtr fg = LoadFuncGraphFromMindIR(resource->compile_cache_id(), weights);
if (fg == nullptr) {
MS_LOG(WARNING) << "Failed to load the compilation cache file. Execute all the compilation actions.";
return nullptr;
@@ -352,7 +363,7 @@ bool ExportFuncGraphToMindIR(const FuncGraphPtr &fg, size_t idx) {
MS_LOG(ERROR) << "Open cache file '" << realpath.value() << "' failed!" << ErrnoToString(errno);
return false;
}
ModelProtoPtr fg_model = GetBinaryProto(fg, true);
ModelProtoPtr fg_model = GetBinaryProto(fg);
if (fg_model == nullptr) {
MS_LOG(ERROR) << "Get binary proto for graph " << fg->ToString() << " failed.";
fout.close();
@@ -920,7 +931,7 @@ bool GraphExecutorPy::CompileInner(const py::object &source_obj, const py::tuple
resource->set_enable_compile_cache(true);
resource->set_compile_cache_id(GetCompileCacheGraphId());
resource->set_compile_cache_dep_files_hash(GetCompileDepFilesHash(compile_cache_dep_files_));
resource->set_func_graph(GetCachedFuncGraph(resource, queue_name_));
resource->set_func_graph(GetCachedFuncGraph(resource, weights_, queue_name_));
#ifdef ENABLE_PROFILE
double t2 = GetTime();
MsProfile::StatTime("LoadCachedFuncGraph", t2 - t1);
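
GenerateWeightsValueMap above is a straightforward pybind11 dict walk: each key is cast to std::string and each value is converted to a ValuePtr via parse::data_converter::PyDataToValue. The same iteration pattern in isolation, keeping the value as a plain py::object since the real converter is MindSpore-internal:

#include <map>
#include <string>
#include <pybind11/pybind11.h>

namespace py = pybind11;

// Sketch: walk a py::dict with pybind11 iterators, casting keys to std::string.
// The real code additionally converts each value to a MindSpore ValuePtr.
std::map<std::string, py::object> DictToMap(const py::dict &weights) {
  std::map<std::string, py::object> ret;
  for (auto item = weights.begin(); item != weights.end(); ++item) {
    auto name = py::cast<std::string>(item->first);
    ret[name] = py::cast<py::object>(item->second);
  }
  return ret;
}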

View File

@@ -110,6 +110,7 @@ class GraphExecutorPy : public std::enable_shared_from_this<GraphExecutorPy> {
void set_compile_cache_dep_files(const py::list &compile_cache_dep_files) {
compile_cache_dep_files_ = compile_cache_dep_files;
}
void set_weights_values(const py::dict &weights) { weights_ = weights; }
#ifdef ENABLE_DEBUGGER
static bool GetDebugTerminate() { return debugger_terminate_; }
static void DebugTerminate(bool val, bool exit_success) {
@@ -145,6 +146,7 @@ class GraphExecutorPy : public std::enable_shared_from_this<GraphExecutorPy> {
std::string queue_name_;
bool enable_tuple_broaden_{false};
py::list compile_cache_dep_files_;
py::dict weights_;
};
using GraphExecutorPyPtr = std::shared_ptr<GraphExecutorPy>;

View File

@@ -83,7 +83,7 @@ class IrExporter {
explicit IrExporter(IrExportBuilderPtr builder) : builder_(std::move(builder)) {}
virtual ~IrExporter() = default;
std::string GetDumpString(const FuncGraphPtr &func_graph);
ModelProtoPtr GetDumpProto(const FuncGraphPtr &func_graph, bool save_tensor_data = false);
ModelProtoPtr GetDumpProto(const FuncGraphPtr &func_graph);
private:
IrExportBuilderPtr builder_;
@@ -95,13 +95,11 @@ class IrExportBuilder {
~IrExportBuilder() { google::protobuf::ShutdownProtobufLibrary(); }
std::string GetProtoString() const;
void BuildModelInfo();
bool BuildModel(const FuncGraphPtr &func_graph, bool save_tensor_data = false);
bool BuildModel(const FuncGraphPtr &func_graph);
ModelProtoPtr Model() { return model_; }
bool BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
bool save_tensor_data = false);
bool BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
bool save_tensor_data = false);
bool BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto);
bool BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto);
bool BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto);
bool BuildOutput(const CNodePtr &node, mind_ir::GraphProto *const graph_proto);
bool BuildCNode(const CNodePtr &node, mind_ir::GraphProto *const graph_proto);
@@ -148,6 +146,7 @@ class IrExportBuilder {
std::set<std::string> nodeName_;
size_t node_index_{0};
size_t shape_index_{0};
bool top_graph{true};
};
@@ -161,7 +160,7 @@ std::string IrExporter::GetDumpString(const FuncGraphPtr &func_graph) {
return builder_->GetProtoString();
}
ModelProtoPtr IrExporter::GetDumpProto(const FuncGraphPtr &func_graph, bool save_tensor_data) {
ModelProtoPtr IrExporter::GetDumpProto(const FuncGraphPtr &func_graph) {
if ((builder_ == nullptr) || (func_graph == nullptr)) {
MS_LOG(EXCEPTION) << "Input params is null.";
}
@@ -170,7 +169,7 @@ ModelProtoPtr IrExporter::GetDumpProto(const FuncGraphPtr &func_graph, bool save
builder_->BuildModelInfo();
// Export model and return string
if (!builder_->BuildModel(func_graph, save_tensor_data)) {
if (!builder_->BuildModel(func_graph)) {
return nullptr;
}
return builder_->Model();
@@ -190,7 +189,7 @@ void IrExportBuilder::BuildModelInfo() {
model_->set_little_endian(common::IsLittleByteOrder());
}
bool IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tensor_data) {
bool IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
mind_ir::GraphProto *graph_proto = model_->mutable_graph();
graph_proto->set_name(func_graph->ToString());
@@ -201,7 +200,7 @@ bool IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tenso
// Build the main funcGraph
(void)nodeName_.insert(func_graph->ToString());
top_graph = true;
if (!BuildFuncGraph(func_graph, graph_proto, save_tensor_data)) {
if (!BuildFuncGraph(func_graph, graph_proto)) {
MS_LOG(ERROR) << "Build func_graph " << func_graph->ToString() << " failed.";
return false;
}
@@ -222,7 +221,7 @@ bool IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tenso
(void)nodeName_.insert(fg->ToString());
(void)graphVisited.insert(fg);
auto graph = model_->add_functions();
if (!BuildFuncGraph(fg, graph, save_tensor_data)) {
if (!BuildFuncGraph(fg, graph)) {
MS_LOG(ERROR) << "Build func_graph " << fg->ToString() << " failed.";
return false;
}
@@ -233,14 +232,13 @@ bool IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tenso
return true;
}
bool IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
bool save_tensor_data) {
bool IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto) {
// Export funcGraph name.
graph_proto->set_name(func_graph->ToString());
// Export parameters
// 1. Parameters should be mapped to ValueInfoProto.
// 2. Parameters with default values should be mapped to Initializer.
if (!BuildParameters(func_graph, graph_proto, save_tensor_data)) {
if (!BuildParameters(func_graph, graph_proto)) {
MS_LOG(ERROR) << "Build parameters failed.";
return false;
}
@@ -249,8 +247,7 @@ bool IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::Gr
return BuildNodes(func_graph, graph_proto);
}
bool IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
bool save_tensor_data) {
bool IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(graph_proto);
for (auto &item : func_graph->parameters()) {
@@ -269,10 +266,6 @@ bool IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::G
MS_LOG(ERROR) << "Set parameter " << param->DebugString() << " to TensorProto failed.";
return false;
}
auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param->default_param());
if (tensor && save_tensor_data) {
parameter_proto->set_raw_data(tensor->data_c(), static_cast<size_t>(tensor->data().nbytes()));
}
} else {
mind_ir::ValueInfoProto *input_proto = graph_proto->add_input();
input_proto->set_name(param_name);
@@ -1077,9 +1070,9 @@ std::string GetBinaryProtoString(const FuncGraphPtr &func_graph) {
return exporter->GetDumpString(func_graph);
}
ModelProtoPtr GetBinaryProto(const FuncGraphPtr &func_graph, bool save_tensor_data) {
ModelProtoPtr GetBinaryProto(const FuncGraphPtr &func_graph) {
auto exporter = std::make_shared<IrExporter>(std::make_shared<IrExportBuilder>());
auto result = exporter->GetDumpProto(func_graph, save_tensor_data);
auto result = exporter->GetDumpProto(func_graph);
return result;
}
} // namespace mindspore
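
The net effect of the exporter changes in this file: the save_tensor_data flag is gone, so BuildParameters now emits only the parameter's metadata (name, shape, type) as an Initializer and never copies the tensor buffer into the proto. A condensed sketch of the removed behavior versus the new one, using a plain struct in place of mind_ir::TensorProto:

#include <cstddef>
#include <string>

// Stand-in for mind_ir::TensorProto: metadata plus an optional raw payload.
struct TensorProtoLike {
  std::string name;
  std::string raw_data;  // left empty after this commit
  void set_raw_data(const void *data, size_t nbytes) {
    raw_data.assign(static_cast<const char *>(data), nbytes);
  }
};

// Before: if (tensor && save_tensor_data) proto->set_raw_data(tensor->data_c(), nbytes);
// After: the proto carries metadata only; initial values arrive through the
// weights map passed to MindIRLoader at load time.
void ExportParameterMetadataOnly(TensorProtoLike *proto, const std::string &param_name) {
  proto->name = param_name;
  // Intentionally no set_raw_data call.
}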

View File

@@ -274,6 +274,7 @@ class _MindsporeFunctionExecutor:
if self.obj is None:
is_compile = self._graph_executor.compile(self.fn, args_list, phase, True)
else:
self._graph_executor.set_weights_values(self.obj.parameters_dict())
is_compile = self._graph_executor.compile(self.obj, args_list, phase, True)
if not is_compile:
raise RuntimeError("Executor compile failed.")
@@ -651,7 +652,7 @@ class _CellGraphExecutor:
enable_ge = context.get_context("enable_ge")
use_vm = self._use_vm_mode()
self._graph_executor.set_weights_values(obj.parameters_dict())
result = self._graph_executor.compile(obj, args_list, phase, use_vm)
obj.compile_cache.add(phase)
if not result:

View File

@@ -98,6 +98,7 @@ void Cloner::CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target,
// Default parameter can be shared since it is readonly.
new_param->set_default_param(old_param->default_param());
}
new_param->set_is_top_graph_param(old_param->is_top_graph_param());
ScopePtr scope = ((node->scope() == kDefaultScope) && (this->scope() != nullptr)) ? this->scope() : node->scope();
new_param->set_scope(scope);
repl_node_[node] = std::move(new_param);
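
The added line above propagates the new is_top_graph_param flag when a parameter is cloned, so the marker survives graph cloning just like the shared default value does. A trivial sketch of the pattern with simplified stand-in types:

// Simplified stand-in for a Parameter node.
struct ParamNode {
  bool is_top_graph_param = false;
  int default_param = 0;  // stand-in for the shared default ValuePtr
};

// Cloning must copy every load-relevant attribute, including the new flag.
ParamNode CloneParam(const ParamNode &old_param) {
  ParamNode new_param;
  new_param.default_param = old_param.default_param;            // shared, read-only
  new_param.is_top_graph_param = old_param.is_top_graph_param;  // added by this commit
  return new_param;
}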

View File

@@ -312,15 +312,17 @@ bool MSANFModelParser::BuildParameterForFuncGraph(const ParameterPtr &node,
MS_LOG(ERROR) << "Build parameter occur memcpy_s error.";
return false;
}
} else {
node->set_default_param(tensor_info);
} else if (parameter_proto.has_external_data()) {
auto ret = GetTensorDataFromExternal(parameter_proto, tensor_info);
if (!ret) {
return false;
}
node->set_default_param(tensor_info);
} else {
MS_LOG(DEBUG) << "Parameter will load initialized data.";
}
node->set_default_param(tensor_info);
anfnode_build_map_[parameter_proto.name()] = node;
return true;
}
@@ -974,7 +976,7 @@ bool MSANFModelParser::CheckCNodePrim(CNodePtr cnode_ptr) {
// If the operator is not a primitive, its abstract will be set to null.
// Because some operators do not exist in the front end, the abstract of a primitive should be preserved.
if (prim == nullptr) {
if (prim == nullptr && need_renormalize()) {
cnode_ptr->set_abstract(nullptr);
return true;
}
@@ -1214,7 +1216,39 @@ bool MSANFModelParser::MSANFParseModelConfigureInfo(const mind_ir::ModelProto &m
return true;
}
FuncGraphPtr MSANFModelParser::Parse(const mind_ir::ModelProto &model_proto) {
bool MSANFModelParser::SetValueForTopGraphParameter(const FuncGraphPtr &topGraph,
const std::map<std::string, ValuePtr> &weights) {
size_t hyper_param_count = 0;
auto parameters = topGraph->parameters();
for (int i = parameters.size() - 1; i >= 0; --i) {
auto parameter = parameters[i]->cast<ParameterPtr>();
if (parameter == nullptr) {
MS_LOG(ERROR) << "AnfNode " << parameters[i]->DebugString() << " should be Parameter.";
return false;
}
auto type = parameter->Type();
if (type == nullptr) {
MS_LOG(ERROR) << "Parameter " << parameter->DebugString() << " has no type.";
return false;
}
if (!type->isa<RefType>()) {
break;
}
auto parameter_name = parameter->name();
auto weights_iter = weights.find(parameter_name);
if (weights_iter == weights.end()) {
MS_LOG(ERROR) << "Find initial weight value for " << parameter_name << " failed.";
return false;
}
parameter->set_default_param(weights_iter->second);
hyper_param_count++;
}
topGraph->set_hyper_param_count(hyper_param_count);
return true;
}
FuncGraphPtr MSANFModelParser::Parse(const mind_ir::ModelProto &model_proto,
const std::map<std::string, ValuePtr> &weights) {
FuncGraphPtr dstGraph = std::make_shared<FuncGraph>();
if (!MSANFParseModelConfigureInfo(model_proto)) {
MS_LOG(ERROR) << "Parse configuration info for pb file failed!";
@@ -1252,6 +1286,14 @@ FuncGraphPtr MSANFModelParser::Parse(const mind_ir::ModelProto &model_proto) {
MS_LOG(ERROR) << "Build funcgraph failed!";
return nullptr;
}
if (!weights.empty()) {
if (!SetValueForTopGraphParameter(dstGraph, weights)) {
MS_LOG(ERROR) << "Set value for top graph fail.";
return nullptr;
}
}
MS_LOG(DEBUG) << "Parse pb to build FuncGraph Success! " << graphBuild.name();
top_graph_ = dstGraph;
for (int i = 0; i < model_proto.functions_size(); ++i) {
@@ -1263,6 +1305,7 @@ FuncGraphPtr MSANFModelParser::Parse(const mind_ir::ModelProto &model_proto) {
}
MS_LOG(DEBUG) << "Parse pb to build FuncGraph Success! " << graph_proto.name();
}
// Release resource
anfnode_build_map_.clear();
return dstGraph;
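
SetValueForTopGraphParameter above exploits the layout convention that weight parameters (Ref-typed, counted as hyper-params) sit at the tail of the top graph's parameter list, so it scans backwards and stops at the first non-Ref parameter. A standalone sketch of that tail scan under the same convention, with simplified stand-in types:

#include <iostream>
#include <map>
#include <string>
#include <vector>

// Simplified stand-in: a parameter is a name plus a flag marking weight (Ref-typed) params.
struct Param {
  std::string name;
  bool is_ref = false;
  double default_value = 0.0;  // stand-in for the ValuePtr default
};

// Scan from the back: weights are appended after data inputs, so the first
// non-Ref parameter ends the scan (this mirrors the `break` in the diff).
bool SetTailWeights(std::vector<Param> *params, const std::map<std::string, double> &weights) {
  size_t hyper_param_count = 0;
  for (int i = static_cast<int>(params->size()) - 1; i >= 0; --i) {
    Param &p = (*params)[static_cast<size_t>(i)];
    if (!p.is_ref) {
      break;
    }
    auto it = weights.find(p.name);
    if (it == weights.end()) {
      std::cerr << "Find initial weight value for " << p.name << " failed.\n";
      return false;
    }
    p.default_value = it->second;
    ++hyper_param_count;
  }
  std::cout << hyper_param_count << " weight(s) restored\n";
  return true;
}

int main() {
  std::vector<Param> params = {{"x"}, {"conv.weight", true}, {"conv.bias", true}};
  std::map<std::string, double> weights = {{"conv.weight", 0.5}, {"conv.bias", 0.1}};
  return SetTailWeights(&params, weights) ? 0 : 1;
}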

View File

@@ -35,7 +35,7 @@ class MSANFModelParser {
~MSANFModelParser() = default;
static void LoadTensorMapClear() { load_tensor_map_.clear(); }
FuncGraphPtr Parse(const mind_ir::ModelProto &model_proto);
FuncGraphPtr Parse(const mind_ir::ModelProto &model_proto, const std::map<std::string, ValuePtr> &weights = {});
bool MSANFParseModelConfigureInfo(const mind_ir::ModelProto &model_proto);
std::string GetProducerName() { return producer_name_; }
@@ -57,6 +57,7 @@ class MSANFModelParser {
bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
bool BuildParameterForFuncGraph(const ParameterPtr &node, const mind_ir::TensorProto &tensor_proto);
bool SetValueForTopGraphParameter(const FuncGraphPtr &topGraph, const std::map<std::string, ValuePtr> &weights);
bool GetTensorDataFromExternal(const mind_ir::TensorProto &tensor_proto, const tensor::TensorPtr &tensor_info);
bool BuildInputForFuncGraph(const ParameterPtr &node, const mind_ir::ValueInfoProto &value_proto);
tensor::TensorPtr BuildTensorInfoForFuncGraph(const mind_ir::TensorProto &tensor_proto);

View File

@@ -229,7 +229,7 @@ FuncGraphPtr MindIRLoader::LoadMindIR(const std::string &file_name) {
if (is_lite_) {
model_parser.SetLite();
}
FuncGraphPtr dstgraph_ptr = model_parser.Parse(origin_model);
FuncGraphPtr dstgraph_ptr = model_parser.Parse(origin_model, weights_value_map_);
return dstgraph_ptr;
}
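
Tying the loader pieces together: the compile-cache path (LoadFuncGraphFromMindIR earlier in this commit) configures a MindIRLoader with the weight map before loading. A hedged usage sketch using only the API shown in this diff; the include path and the origin of `weights` are assumptions:

#include <map>
#include <string>
#include "load_mindir/load_model.h"  // MindIRLoader; actual include path may differ

// `weights` would come from something like GenerateWeightsValueMap(py_weights_dict).
FuncGraphPtr LoadCachedGraphWithWeights(const std::string &mindir_path,
                                        const std::map<std::string, ValuePtr> &weights) {
  MindIRLoader loader;
  loader.set_need_renormalize(false);     // cached graphs skip renormalization
  loader.set_weights_value_map(weights);  // new in this commit
  return loader.LoadMindIR(mindir_path);
}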

View File

@@ -16,6 +16,7 @@
#ifndef MINDSPORE_CORE_LOAD_MODEL_H
#define MINDSPORE_CORE_LOAD_MODEL_H
#include <map>
#include <vector>
#include <string>
#include <memory>
@@ -34,7 +35,9 @@ class MindIRLoader {
bool get_need_renormalize() const { return need_renormalize_; }
void set_need_renormalize(bool need_renormalize) { need_renormalize_ = need_renormalize; }
void set_weights_value_map(const std::map<string, ValuePtr> &weights_value_map) {
weights_value_map_ = weights_value_map;
}
FuncGraphPtr LoadMindIR(const void *buffer, const size_t &size);
FuncGraphPtr LoadMindIR(const std::string &file_name);
std::vector<FuncGraphPtr> LoadMindIRs(const std::vector<std::string> file_names);
@@ -48,6 +51,7 @@ class MindIRLoader {
std::string dec_mode_ = std::string("AES-GCM");
bool inc_load_ = false;
bool need_renormalize_ = true;
std::map<string, ValuePtr> weights_value_map_;
};
std::string LoadPreprocess(const std::string &file_name);

View File

@@ -26,7 +26,7 @@ std::string GetOnnxProtoString(const FuncGraphPtr &func_graph) { return ""; }
std::string GetBinaryProtoString(const FuncGraphPtr &func_graph) { return ""; }
ModelProtoPtr GetBinaryProto(const FuncGraphPtr &func_graph, bool save_tensor_data) {
ModelProtoPtr GetBinaryProto(const FuncGraphPtr &func_graph) {
ModelProtoPtr empty_model;
return empty_model;
}