From d384db6c01f20e3b21cbe921162fa924f5e75e0f Mon Sep 17 00:00:00 2001
From: yujianfeng
Date: Sat, 9 Oct 2021 11:02:37 +0800
Subject: [PATCH] Continue execution when saving and loading mindir failed

---
 mindspore/ccsrc/debug/dump_proto.h            |   4 +-
 .../ccsrc/frontend/optimizer/ad/kprim.cc      |  19 +-
 mindspore/ccsrc/pipeline/jit/pipeline.cc      |  22 +-
 .../transform/express_ir/mindir_exporter.cc   | 370 ++++++++++++------
 .../core/load_mindir/anf_model_parser.cc      |  35 +-
 tests/st/compile_cache/test_ascend_lenet.py   |  25 ++
 tests/ut/cpp/stub/anf_ir/dump_proto_stub.cc   |   4 +-
 7 files changed, 342 insertions(+), 137 deletions(-)

diff --git a/mindspore/ccsrc/debug/dump_proto.h b/mindspore/ccsrc/debug/dump_proto.h
index 2e035c21376..79e2ce69a5c 100644
--- a/mindspore/ccsrc/debug/dump_proto.h
+++ b/mindspore/ccsrc/debug/dump_proto.h
@@ -17,19 +17,21 @@
 #define MINDSPORE_CCSRC_DEBUG_DUMP_PROTO_H_
 
 #include <string>
+#include <memory>
 
 #include "ir/func_graph.h"
 #include "proto/mind_ir.pb.h"
 #include "debug/common.h"
 
 namespace mindspore {
+using ModelProtoPtr = std::shared_ptr<mind_ir::ModelProto>;
 std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph);
 
 std::string GetOnnxProtoString(const FuncGraphPtr &func_graph);
 
 std::string GetBinaryProtoString(const FuncGraphPtr &func_graph);
 
-mind_ir::ModelProto GetBinaryProto(const FuncGraphPtr &func_graph, bool save_tensor_data = false);
+ModelProtoPtr GetBinaryProto(const FuncGraphPtr &func_graph, bool save_tensor_data = false);
 
 void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix);
 } // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc b/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc
index 5ebae2e1baa..12a60ce6e2a 100644
--- a/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc
+++ b/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc
@@ -158,10 +158,21 @@ void ExportBpropToMindIR(const PrimitivePtr &prim, const FuncGraphPtr &func_grap
     return;
   }
   std::ofstream fout(bprop_mindir_realpath.value());
-  mind_ir::ModelProto fg_model = GetBinaryProto(func_graph, false);
-  if (!fg_model.SerializeToOstream(&fout)) {
-    MS_LOG(WARNING) << "Failed to cache the bprop of op \"" << prim->name() << "\" to file \""
-                    << bprop_mindir_realpath.value() << "\".";
+  if (!fout.is_open()) {
+    MS_LOG(ERROR) << "Open cache file '" << bprop_mindir_realpath.value() << "' failed!" << ErrnoToString(errno);
+    return;
+  }
+  ModelProtoPtr fg_model = GetBinaryProto(func_graph, false);
+  if (fg_model == nullptr) {
+    MS_LOG(ERROR) << "Get binary proto for graph " << func_graph->ToString() << " failed.";
+    fout.close();
+    return;
+  }
+  if (!fg_model->SerializeToOstream(&fout)) {
+    MS_LOG(ERROR) << "Failed to cache the bprop of op \"" << prim->name() << "\" to file \""
+                  << bprop_mindir_realpath.value() << "\".";
+    fout.close();
+    return;
   }
   fout.close();
   ChangeFileMode(bprop_mindir_realpath.value(), S_IRUSR | S_IWUSR);
diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc
index 28f12c21b42..00921fb5d3e 100644
--- a/mindspore/ccsrc/pipeline/jit/pipeline.cc
+++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc
@@ -189,7 +189,8 @@ void GetCachedFuncGraph(const ResourcePtr &resource, const std::string &queue_na
   MS_LOG(INFO) << "Use the compilation cache \"" << realpath.value() << "\" and execute the backend actions only.";
   FuncGraphPtr fg = mindspore::LoadMindIR(realpath.value());
   if (fg == nullptr) {
-    MS_LOG(EXCEPTION) << "Failed to load the compilation cache file: " << realpath.value();
+    MS_LOG(ERROR) << "Failed to load the compilation cache file: " << realpath.value();
+    return;
   }
   FuncGraphManagerPtr mng = fg->manager();
   if (mng == nullptr) {
@@ -213,18 +214,27 @@ void CacheFuncGraph(const ResourcePtr &resource) {
   MS_EXCEPTION_IF_NULL(resource);
   auto realpath = Common::CreatePrefixPath(kCompileCacheFilePath);
   if (!realpath.has_value()) {
-    MS_LOG(EXCEPTION) << "Get real path failed. filename=" << kCompileCacheFilePath;
+    MS_LOG(ERROR) << "Get real path failed. filename=" << kCompileCacheFilePath;
+    return;
   }
 
   ChangeFileMode(realpath.value(), S_IRWXU);
   std::ofstream fout(realpath.value());
   if (!fout.is_open()) {
-    MS_LOG(EXCEPTION) << "Open cache file '" << realpath.value() << "' failed!" << ErrnoToString(errno);
+    MS_LOG(ERROR) << "Open cache file '" << realpath.value() << "' failed!" << ErrnoToString(errno);
+    return;
+  }
   FuncGraphPtr fg = resource->func_graph();
-  mind_ir::ModelProto fg_model = GetBinaryProto(fg, true);
-  if (!fg_model.SerializeToOstream(&fout)) {
-    MS_LOG(EXCEPTION) << "Failed to cache the graph to file " << realpath.value();
+  ModelProtoPtr fg_model = GetBinaryProto(fg, true);
+  if (fg_model == nullptr) {
+    MS_LOG(ERROR) << "Get binary proto for graph " << fg->ToString() << " failed.";
+    fout.close();
+    return;
+  }
+  if (!fg_model->SerializeToOstream(&fout)) {
+    MS_LOG(ERROR) << "Failed to cache the graph to file " << realpath.value();
+    fout.close();
+    return;
   }
   fout.close();
   ChangeFileMode(realpath.value(), S_IRUSR);
diff --git a/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc b/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc
index ef8981cb153..0281bf46c7c 100644
--- a/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc
+++ b/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc
@@ -27,6 +27,7 @@
 #include "base/core_ops.h"
 #include "proto/mind_ir.pb.h"
 #include "utils/check_convert_utils.h"
+#include "debug/dump_proto.h"
 
 namespace mindspore {
 using FloatPtr = std::shared_ptr<Float>;
@@ -78,7 +79,7 @@ class IrExporter {
   explicit IrExporter(IrExportBuilderPtr builder) : builder_(builder) {}
   virtual ~IrExporter() = default;
   std::string GetDumpString(const FuncGraphPtr &func_graph);
-  mind_ir::ModelProto GetDumpProto(const FuncGraphPtr &func_graph, bool save_tensor_data = false);
+  ModelProtoPtr GetDumpProto(const FuncGraphPtr &func_graph, bool save_tensor_data = false);
 
  private:
   IrExportBuilderPtr builder_;
 };
 
@@ -86,39 +87,38 @@ class IrExportBuilder {
  public:
-  IrExportBuilder() = default;
+  IrExportBuilder() : model_(std::make_shared<mind_ir::ModelProto>()) {}
   ~IrExportBuilder() { google::protobuf::ShutdownProtobufLibrary(); }
   std::string GetProtoString(const FuncGraphPtr &func_graph);
   void BuildModelInfo();
-  void BuildModel(const FuncGraphPtr &func_graph, bool save_tensor_data = false);
-  mind_ir::ModelProto Model() { return model_; }
+  bool BuildModel(const FuncGraphPtr &func_graph, bool save_tensor_data = false);
+  ModelProtoPtr Model() { return model_; }
 
  private:
-  void BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
+  bool BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
                       bool save_tensor_data = false);
-  void BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
+  bool BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
                        bool save_tensor_data = false);
-  void BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto);
-  void BuildOutput(const CNodePtr &node, mind_ir::GraphProto *const graph_proto);
-  void BuildCNode(const CNodePtr &node, mind_ir::GraphProto *const graph_proto);
+  bool BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto);
+  bool BuildOutput(const CNodePtr &node, mind_ir::GraphProto *const graph_proto);
+  bool BuildCNode(const CNodePtr &node, mind_ir::GraphProto *const graph_proto);
   std::string BuildInputNode(const AnfNodePtr &node, mind_ir::GraphProto *const graph_proto);
-  void SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueInfoProto *const value_proto);
-  void SetValueInfoProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::ValueInfoProto *const value_proto);
-  void SetParamToTensorProto(const ParameterPtr &param, mind_ir::TensorProto *const tensor_proto);
-  void SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::TensorProto *const tensor_proto);
-  void SetAttributeProto(const AnfNodePtr &node, mind_ir::NodeProto *const node_proto);
-  void SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodeProto *const node_proto);
-  void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::AttributeProto *const attr_proto,
-                           std::string *const seq_string);
+  bool SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueInfoProto *const value_proto);
+  bool SetParamToTensorProto(const ParameterPtr &param, mind_ir::TensorProto *const tensor_proto);
+  bool SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::TensorProto *const tensor_proto);
+  bool SetAttributeProto(const AnfNodePtr &node, mind_ir::NodeProto *const node_proto);
+  bool SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodeProto *const node_proto);
+  bool SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::AttributeProto *const attr_proto,
+                           std::string *const seq_string);
-  void SetValueToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
-  void SetTypeToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
-  void SetScalarToAttributeProto_ir(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
-  void SetScalarToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
-  void SetTensorToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
-  void SetSequenceToAttributeProto(const ValueSequeuePtr &value, mind_ir::AttributeProto *const attr_proto,
+  bool SetValueToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
+  bool SetTypeToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
+  bool SetScalarToAttributeProto_ir(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
+  bool SetScalarToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
+  bool SetTensorToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
+  bool SetSequenceToAttributeProto(const ValueSequeuePtr &value, mind_ir::AttributeProto *const attr_proto,
                                    std::string *const seq_string);
-  void SetSeqElemToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto,
+  bool SetSeqElemToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto,
                                   std::string *const seq_string);
 
   mind_ir::TensorProto_DataType GetMindirDataType(TypeId type_id);
@@ -134,7 +134,7 @@ class IrExportBuilder {
   void ResetTupleIndex() { shape_index_ = 0; }
 
  private:
-  mind_ir::ModelProto model_;
+  ModelProtoPtr model_;
   mind_ir::NodeProto *last_node_{nullptr};
   std::list<FuncGraphPtr> todo_;
   std::map<AnfNodePtr, size_t> node_index_map_;
@@ -147,11 +147,14 @@ using IrExporterPtr = std::shared_ptr<IrExporter>;
 
 std::string IrExporter::GetDumpString(const FuncGraphPtr &func_graph) {
-  (void)GetDumpProto(func_graph);
+  auto dump_proto = GetDumpProto(func_graph);
+  if (dump_proto == nullptr) {
+    MS_LOG(EXCEPTION) << "Get dump proto for graph " << func_graph->ToString() << " failed.";
+  }
   return builder_->GetProtoString(func_graph);
 }
 
-mind_ir::ModelProto IrExporter::GetDumpProto(const FuncGraphPtr &func_graph, bool save_tensor_data) {
+ModelProtoPtr IrExporter::GetDumpProto(const FuncGraphPtr &func_graph, bool save_tensor_data) {
   if ((builder_ == nullptr) || (func_graph == nullptr)) {
     MS_LOG(EXCEPTION) << "Input params is null.";
   }
@@ -160,26 +163,28 @@ mind_ir::ModelProto IrExporter::GetDumpProto(const FuncGraphPtr &func_graph, boo
   builder_->BuildModelInfo();
 
   // Export model and return string
-  builder_->BuildModel(func_graph, save_tensor_data);
+  if (!builder_->BuildModel(func_graph, save_tensor_data)) {
+    return nullptr;
+  }
   return builder_->Model();
 }
 
 std::string IrExportBuilder::GetProtoString(const FuncGraphPtr &func_graph) {
   MS_LOG(DEBUG) << "BuildModel complete!";
-  return model_.SerializeAsString();
+  return model_->SerializeAsString();
 }
 
 void IrExportBuilder::BuildModelInfo() {
   constexpr auto ir_version = "0.1.0";
   constexpr auto mindspore_name = "MindSpore";
-  model_.set_ir_version(ir_version);
-  model_.set_producer_name(mindspore_name);
-  model_.set_model_version(VERSION);
+  model_->set_ir_version(ir_version);
+  model_->set_producer_name(mindspore_name);
+  model_->set_model_version(VERSION);
 }
 
-void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tensor_data) {
+bool IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tensor_data) {
   MS_EXCEPTION_IF_NULL(func_graph);
-  mind_ir::GraphProto *graph_proto = model_.mutable_graph();
+  mind_ir::GraphProto *graph_proto = model_->mutable_graph();
   graph_proto->set_name(func_graph->ToString());
   graph_proto->set_bprop_hash(func_graph->bprop_hash());
   ResetNodeIndex();
@@ -188,7 +193,11 @@ void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tenso
   // Build the main funcGraph
   nodeName_.insert(func_graph->ToString());
   top_graph = true;
-  BuildFuncGraph(func_graph, graph_proto, save_tensor_data);
+  if (!BuildFuncGraph(func_graph, graph_proto, save_tensor_data)) {
+    MS_LOG(ERROR) << "Build func_graph " << func_graph->ToString() << " failed.";
+    return false;
+  }
+
   std::set<FuncGraphPtr> graphVisited;
   graphVisited.insert(func_graph);
   top_graph = false;
@@ -199,32 +208,40 @@ void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tenso
       continue;
     }
     if (nodeName_.count(fg->ToString()) > 0) {
-      MS_LOG(EXCEPTION) << "There is a duplicate name: " << fg->ToString();
+      MS_LOG(ERROR) << "There is a duplicate name: " << fg->ToString();
+      return false;
    }
     nodeName_.insert(fg->ToString());
     graphVisited.insert(fg);
-    auto graph = model_.add_functions();
-    BuildFuncGraph(fg, graph, save_tensor_data);
+    auto graph = model_->add_functions();
+    if (!BuildFuncGraph(fg, graph, save_tensor_data)) {
+      MS_LOG(ERROR) << "Build func_graph " << fg->ToString() << " failed.";
+      return false;
+    }
   }
   // Release resource
   nodeName_.clear();
   node_index_map_.clear();
+  return true;
 }
 
-void IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
+bool IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
                                      bool save_tensor_data) {
   // Export funcGraph name.
   graph_proto->set_name(func_graph->ToString());
 
   // Export parameters
   // 1. parameters should be mapped to ValueInfoProto
   // 2. parameters with default value should be mapped to Initializer
-  BuildParameters(func_graph, graph_proto, save_tensor_data);
+  if (!BuildParameters(func_graph, graph_proto, save_tensor_data)) {
+    MS_LOG(ERROR) << "Build parameters failed.";
+    return false;
+  }
 
   // Export operator nodes(include output)
-  BuildNodes(func_graph, graph_proto);
+  return BuildNodes(func_graph, graph_proto);
 }
 
-void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
+bool IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
                                       bool save_tensor_data) {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(graph_proto);
@@ -232,14 +249,18 @@ void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::G
     MS_EXCEPTION_IF_NULL(item);
     auto param = item->cast<ParameterPtr>();
     if (param == nullptr) {
-      MS_LOG(EXCEPTION) << "Parameter: '" << item->ToString() << "' could not cast to parameter.";
+      MS_LOG(ERROR) << "Parameter: '" << item->ToString() << "' could not cast to parameter.";
+      return false;
     }
     std::string param_name = GetUniqueNodeName(param);
     if (top_graph && param->has_default()) {
       MS_LOG(DEBUG) << "Parameter: '" << item->DebugString() << "' has default. address: " << (size_t)param.get();
       mind_ir::TensorProto *parameter_proto = graph_proto->add_parameter();
       parameter_proto->set_name(param_name);
-      SetParamToTensorProto(param, parameter_proto);
+      if (!SetParamToTensorProto(param, parameter_proto)) {
+        MS_LOG(ERROR) << "Set parameter " << param->DebugString() << " to TensorProto failed.";
+        return false;
+      }
       auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param->default_param());
       if (tensor && save_tensor_data) {
         parameter_proto->set_raw_data(tensor->data_c(), tensor->data().nbytes());
@@ -247,19 +268,25 @@ void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::G
     } else {
       mind_ir::ValueInfoProto *input_proto = graph_proto->add_input();
       input_proto->set_name(param_name);
-      SetValueInfoProto(param, input_proto);
+      if (!SetValueInfoProto(param, input_proto)) {
+        MS_LOG(ERROR) << "Set parameter " << param->DebugString() << " to TensorProto failed.";
+        return false;
+      }
     }
     if (nodeName_.count(param_name) > 0) {
-      MS_LOG(EXCEPTION) << "parameter name is duplicate:" << param_name;
+      MS_LOG(ERROR) << "parameter name is duplicate:" << param_name;
+      return false;
    }
     nodeName_.insert(param_name);
   }
+  return true;
 }
 
 mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataType(TypeId type_id) {
   auto iter = g_data_type_map.find(type_id);
   if (iter == g_data_type_map.end()) {
-    MS_LOG(EXCEPTION) << "Convert type error, unsupported type! " << type_id;
+    MS_LOG(ERROR) << "Convert type error, unsupported type! " << type_id;
+    return mind_ir::TensorProto_DataType_UNDEFINED;
   }
   return iter->second;
 }
@@ -267,7 +294,8 @@ mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataType(TypeId type_id)
 mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsIntType(int bits) {
   auto iter = g_data_bits_int_map.find(bits);
   if (iter == g_data_bits_int_map.end()) {
-    MS_LOG(EXCEPTION) << "Convert bits int error, unsupported bits! " << bits;
+    MS_LOG(ERROR) << "Convert bits int error, unsupported bits! " << bits;
" << bits; + return mind_ir::TensorProto_DataType_UNDEFINED; } return iter->second; } @@ -275,7 +303,8 @@ mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsIntType(int bits mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsUIntType(int bits) { auto iter = g_data_bits_uint_map.find(bits); if (iter == g_data_bits_uint_map.end()) { - MS_LOG(EXCEPTION) << "Convert bits uint error, unsupported bits! " << bits; + MS_LOG(ERROR) << "Convert bits uint error, unsupported bits! " << bits; + return mind_ir::TensorProto_DataType_UNDEFINED; } return iter->second; } @@ -283,20 +312,22 @@ mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsUIntType(int bit mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsFloatType(int bits) { auto iter = g_data_bits_float_map.find(bits); if (iter == g_data_bits_float_map.end()) { - MS_LOG(EXCEPTION) << "Convert bits float error, unsupported bits! " << bits; + MS_LOG(ERROR) << "Convert bits float error, unsupported bits! " << bits; + return mind_ir::TensorProto_DataType_UNDEFINED; } return iter->second; } -void IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueInfoProto *const value_proto) { +bool IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueInfoProto *const value_proto) { if (node == nullptr || value_proto == nullptr) { MS_LOG(EXCEPTION) << "AnfNode or ValueInfo is null!"; } MS_LOG(DEBUG) << "SetValueInfoProto: " << node->DebugString(); const TypePtr &type = node->Type(); const BaseShapePtr &shape = node->Shape(); + // For the bprop fg which has not been renormalized. if (type == nullptr || shape == nullptr) { - return; + return true; } if (type->isa() && shape->isa()) { auto tensor = type->cast(); @@ -304,7 +335,11 @@ void IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueIn auto elem_type = tensor->element(); const auto &dims = shape->cast()->shape(); mind_ir::TensorProto *tensor_proto = value_proto->add_tensor(); - tensor_proto->set_data_type(GetMindirDataType(elem_type->type_id())); + auto data_type = GetMindirDataType(elem_type->type_id()); + if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) { + return false; + } + tensor_proto->set_data_type(data_type); if (dims.size() == 0) { MS_LOG(DEBUG) << "The dim of ValueInfoProto is 0."; } else { @@ -320,9 +355,10 @@ void IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueIn value_proto->set_denotation(type->type_name()); } MS_LOG(DEBUG) << "Value type: " << type->type_name(); + return true; } -void IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) { +bool IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) { if (value == nullptr || attr_proto == nullptr) { MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; } @@ -335,34 +371,45 @@ void IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, mind_ir:: tensor_proto->set_raw_data(data->data_c(), static_cast(data->data().nbytes())); auto dtype = data->data_type(); auto shape = data->shape_c(); - tensor_proto->set_data_type(GetMindirDataType(dtype)); + auto data_type = GetMindirDataType(dtype); + if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) { + return false; + } + tensor_proto->set_data_type(data_type); for (const auto &dim : shape) { tensor_proto->add_dims(dim); } + return true; } -void IrExportBuilder::SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, +bool 
                                      mind_ir::TensorProto *const tensor_proto) {
   if (!type->isa<TensorType>() || !shape->isa<abstract::Shape>()) {
-    MS_LOG(EXCEPTION) << "Type or shape is not supported! " << type->ToString();
+    MS_LOG(ERROR) << "Type or shape is not supported! " << type->ToString();
+    return false;
   }
   auto tensor = type->cast<TensorTypePtr>();
   const auto &dims = shape->cast<abstract::ShapePtr>()->shape();
-  tensor_proto->set_data_type(GetMindirDataType(tensor->element()->type_id()));
+  auto data_type = GetMindirDataType(tensor->element()->type_id());
+  if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
+    return false;
+  }
+  tensor_proto->set_data_type(data_type);
   for (const auto &dim : dims) {
     tensor_proto->add_dims(dim);
   }
+  return true;
 }
 
-void IrExportBuilder::SetParamToTensorProto(const ParameterPtr &param, mind_ir::TensorProto *const tensor_proto) {
+bool IrExportBuilder::SetParamToTensorProto(const ParameterPtr &param, mind_ir::TensorProto *const tensor_proto) {
   if (param == nullptr || tensor_proto == nullptr) {
     MS_LOG(EXCEPTION) << "Parameter or TensorProto is null!";
   }
   MS_LOG(DEBUG) << "SetParamToTensorProto: " << param->DebugString();
-  SetTensorProto(param->Type(), param->Shape(), tensor_proto);
+  return SetTensorProto(param->Type(), param->Shape(), tensor_proto);
 }
 
-void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto) {
+bool IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto) {
   std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
   for (const AnfNodePtr &node : nodes) {
     MS_EXCEPTION_IF_NULL(node);
@@ -372,24 +419,36 @@ void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphP
     }
     auto cnode = node->cast<CNodePtr>();
     if (cnode == func_graph->get_return()) {
-      BuildOutput(cnode, graph_proto);
+      if (!BuildOutput(cnode, graph_proto)) {
+        MS_LOG(ERROR) << "Build output for graph " << func_graph->ToString() << " failed.";
+        return false;
+      }
     } else {
-      BuildCNode(cnode, graph_proto);
+      if (!BuildCNode(cnode, graph_proto)) {
+        MS_LOG(ERROR) << "Build proto for cnode " << cnode->DebugString() << " failed.";
+        return false;
+      }
     }
   }
+  return true;
 }
 
-void IrExportBuilder::BuildOutput(const CNodePtr &node, mind_ir::GraphProto *const graph_proto) {
+bool IrExportBuilder::BuildOutput(const CNodePtr &node, mind_ir::GraphProto *const graph_proto) {
   MS_EXCEPTION_IF_NULL(node);
   const int OutputSize = 2;
   if (node->size() != OutputSize) {
-    MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2.";
+    MS_LOG(ERROR) << "Number of inputs of return node is not equal to 2.";
+    return false;
   }
   AnfNodePtr arg = node->input(1);
   std::string node_name = BuildInputNode(arg, graph_proto);
+  if (node_name.empty()) {
+    MS_LOG(ERROR) << "Build input node failed for arg " << arg->DebugString();
+    return false;
+  }
   mind_ir::ValueInfoProto *output_proto = graph_proto->add_output();
   output_proto->set_name(node_name);
-  SetValueInfoProto(arg, output_proto);
+  return SetValueInfoProto(arg, output_proto);
 }
 
 std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) {
@@ -408,16 +467,18 @@ std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) {
     auto nodeName = GetUniqueNodeName(node);
     type_name = "REF::" + nodeName;
     if (nodeName_.count(nodeName) == 0) {
-      MS_LOG(EXCEPTION) << "There is not the name: " << nodeName;
+      MS_LOG(ERROR) << "There is not the name: " << nodeName;
+      return "";
    }
   } else {
-    MS_LOG(EXCEPTION) << "Need to support op type: " << node->type_name();
+    MS_LOG(ERROR) << "Need to support op type: " << node->type_name();
+    return "";
   }
   MS_LOG(DEBUG) << "ExportType: " << type_name;
   return type_name;
 }
 
-void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape,
+bool IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape,
                                           mind_ir::AttributeProto *const attr_proto, std::string *const seq_string) {
   MS_EXCEPTION_IF_NULL(type);
   MS_EXCEPTION_IF_NULL(shape);
@@ -427,7 +488,9 @@ void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePt
     auto elements = type->cast<TuplePtr>()->elements();
     auto tuple_shape = shape->cast<abstract::TupleShapePtr>()->shape();
     for (size_t i = 0; i < elements.size(); i++) {
-      SetShapeToNodeProto(elements[i], tuple_shape[i], attr_proto, seq_string);
+      if (!SetShapeToNodeProto(elements[i], tuple_shape[i], attr_proto, seq_string)) {
+        return false;
+      }
     }
     *seq_string += "],";
   } else if (type->isa<TensorType>() && shape->isa<abstract::Shape>()) {
@@ -435,7 +498,7 @@ void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePt
     *seq_string += shape_name + ",";
     mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
     tensor_proto->set_name(shape_name);
-    SetTensorProto(type, shape, tensor_proto);
+    return SetTensorProto(type, shape, tensor_proto);
   } else if (type->isa<Number>()) {
     if (type->isa<Bool>()) {
       attr_proto->set_type(mind_ir::AttributeProto_AttributeType_BOOL);
@@ -453,11 +516,13 @@ void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePt
   } else if (type->isa<Function>() || type->isa<String>() || type->isa<UMonadType>()) {
     *seq_string += type->type_name() + ",";
   } else {
-    MS_LOG(EXCEPTION) << "Type of cnode need to be supported: " << type->type_name();
+    MS_LOG(ERROR) << "Type of cnode need to be supported: " << type->type_name();
+    return false;
   }
+  return true;
 }
 
-void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodeProto *const node_proto) {
+bool IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodeProto *const node_proto) {
   // Get shape of cnode
   // 1. need to get shape from tuple element
   // 2. save shape in TensorProto
@@ -465,21 +530,27 @@ void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodePro
   MS_EXCEPTION_IF_NULL(node);
   auto type = node->Type();
   auto shape = node->Shape();
+  // For the bprop fg which has not been renormalized.
   if (type == nullptr || shape == nullptr) {
-    return;
+    return true;
   }
   ResetTupleIndex();
   std::string seq_string = "shape:";
   mind_ir::AttributeProto *attr_proto = node_proto->add_attribute();
-  SetShapeToNodeProto(type, shape, attr_proto, &seq_string);
+  if (!SetShapeToNodeProto(type, shape, attr_proto, &seq_string)) {
+    MS_LOG(ERROR) << "Set shape to NodeProto for " << node->DebugString() << " failed.";
+    return false;
+  }
   attr_proto->set_ref_attr_name(seq_string);
   MS_LOG(DEBUG) << "CNode shape: " << seq_string;
+  return true;
 }
 
-void IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *const graph_proto) {
+bool IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *const graph_proto) {
   auto inputs_size = node->size();
   if (inputs_size < 1) {
-    MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
+    MS_LOG(ERROR) << "Inputs of node " << node->DebugString() << " is empty";
+    return false;
   }
 
   // Need to build input node before dealing with cnode
@@ -488,7 +559,12 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *cons
   std::vector<AnfNodePtr> op_inputs;
   std::vector<string> input_names;
   for (size_t i = 1; i < inputs_size; i++) {
     auto input = node->input(i);
     op_inputs.push_back(input);
-    input_names.push_back(BuildInputNode(input, graph_proto));
+    std::string node_name = BuildInputNode(input, graph_proto);
+    if (node_name.empty()) {
+      MS_LOG(ERROR) << "Build input node for " << input->DebugString() << " failed.";
+      return false;
+    }
+    input_names.push_back(node_name);
   }
 
   // Build cnode
@@ -503,10 +579,16 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *cons
   node_proto->set_domain(node->fullname_with_scope());
   AnfNodePtr op = node->input(0);
   std::string type_name = GetOpTypeName(op);
+  if (type_name.empty()) {
+    MS_LOG(ERROR) << "Get op type name for " << op->DebugString() << " failed.";
+    return false;
+  }
   node_proto->set_op_type(type_name);
   last_node_ = node_proto;
   // Maybe Tensor or Function or nullptr
-  SetShapeToNodeProto(node, node_proto);
+  if (!SetShapeToNodeProto(node, node_proto)) {
+    return false;
+  }
 
   (void)std::for_each(input_names.begin(), input_names.end(),
                       [&node_proto](const string &name) { node_proto->add_input(name); });
@@ -520,9 +602,13 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *cons
       attr_proto->set_name(attr.first);
       auto attr_value = attr.second;
       CheckAndConvertUtils::ConvertAttrValueInExport(type_name, attr.first, &attr_value);
-      SetValueToAttributeProto(attr_value, attr_proto);
+      if (!SetValueToAttributeProto(attr_value, attr_proto)) {
+        MS_LOG(ERROR) << "Set value to AttributeProto failed.";
+        return false;
+      }
     }
   }
+  return true;
 }
 
 std::string IrExportBuilder::BuildInputNode(const AnfNodePtr &node, mind_ir::GraphProto *const graph_proto) {
@@ -538,7 +624,9 @@ std::string IrExportBuilder::BuildInputNode(const AnfNodePtr &node, mind_ir::Gra
     mind_ir::NodeProto *node_proto = graph_proto->add_node();
     node_proto->set_name(node_name);
     node_proto->add_output(node_name);
-    SetAttributeProto(node, node_proto);
+    if (!SetAttributeProto(node, node_proto)) {
+      return "";
+    }
   }
   return node_name;
 }
@@ -581,7 +669,7 @@ std::string IrExportBuilder::GetNodeName(const AnfNodePtr &node) {
   return node_name;
 }
 
-void IrExportBuilder::SetAttributeProto(const AnfNodePtr &node, mind_ir::NodeProto *const node_proto) {
+bool IrExportBuilder::SetAttributeProto(const AnfNodePtr &node, mind_ir::NodeProto *const node_proto) {
   if (node == nullptr || node_proto == nullptr) {
     MS_LOG(EXCEPTION) << "AnfNode or NodeProto is null!";
   }
@@ -592,10 +680,10 @@ void IrExportBuilder::SetAttributeProto(const AnfNodePtr &node, mind_ir::NodePro
   mind_ir::AttributeProto *attr_proto = node_proto->add_attribute();
   attr_proto->set_name("value");
   MS_LOG(DEBUG) << "Set Constant attribute: " << value->ToString();
-  SetValueToAttributeProto(value, attr_proto);
+  return SetValueToAttributeProto(value, attr_proto);
 }
 
-void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
+bool IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
   if (value == nullptr || attr_proto == nullptr) {
     MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
   }
@@ -605,17 +693,29 @@ void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, mind_ir::At
     attr_proto->set_ref_attr_name("type:value0");
     tensor_proto->set_name("value0");
     auto int_value = value->cast<IntPtr>();
-    tensor_proto->set_data_type(GetMindirDataBitsIntType(int_value->nbits()));
+    auto data_type = GetMindirDataBitsIntType(int_value->nbits());
+    if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
+      return false;
+    }
+    tensor_proto->set_data_type(data_type);
   } else if (value->isa<UInt>()) {
     attr_proto->set_ref_attr_name("type:value0");
     tensor_proto->set_name("value0");
     auto float_value = value->cast<UIntPtr>();
-    tensor_proto->set_data_type(GetMindirDataBitsUIntType(float_value->nbits()));
+    auto data_type = GetMindirDataBitsUIntType(float_value->nbits());
+    if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
+      return false;
+    }
+    tensor_proto->set_data_type(data_type);
   } else if (value->isa<Float>()) {
     attr_proto->set_ref_attr_name("type:value0");
     tensor_proto->set_name("value0");
    auto float_value = value->cast<FloatPtr>();
-    tensor_proto->set_data_type(GetMindirDataBitsFloatType(float_value->nbits()));
+    auto data_type = GetMindirDataBitsFloatType(float_value->nbits());
+    if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
+      return false;
+    }
+    tensor_proto->set_data_type(data_type);
   } else if (value->isa<Bool>()) {
     attr_proto->set_ref_attr_name("type:value0");
     tensor_proto->set_name("value0");
@@ -626,35 +726,49 @@ void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, mind_ir::At
     auto elem_type = value->cast<TensorTypePtr>()->element();
     if (elem_type->isa<Int>()) {
       auto int_value = elem_type->cast<IntPtr>();
-      tensor_proto->set_data_type(GetMindirDataBitsIntType(int_value->nbits()));
+      auto data_type = GetMindirDataBitsIntType(int_value->nbits());
+      if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
+        return false;
+      }
+      tensor_proto->set_data_type(data_type);
     } else if (elem_type->isa<Float>()) {
       auto float_value = elem_type->cast<FloatPtr>();
-      tensor_proto->set_data_type(GetMindirDataBitsFloatType(float_value->nbits()));
+      auto data_type = GetMindirDataBitsFloatType(float_value->nbits());
+      if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
+        return false;
+      }
+      tensor_proto->set_data_type(data_type);
     } else {
-      MS_LOG(EXCEPTION) << "Unsupported type " << elem_type->type_name();
+      MS_LOG(ERROR) << "Unsupported type " << elem_type->type_name();
+      return false;
    }
   } else {
     MS_LOG(EXCEPTION) << "Unsupported type: " << value->type_name();
+    return false;
   }
+  return true;
 }
 
-void IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
+bool IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
   if (value == nullptr || attr_proto == nullptr) {
     MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
   }
   if (value->isa<StringImm>() || value->isa<Scalar>()) {
-    SetScalarToAttributeProto_ir(value, attr_proto);
+    return SetScalarToAttributeProto_ir(value, attr_proto);
   } else if (value->isa<Number>() || value->isa<TensorType>()) {
-    SetTypeToAttributeProto(value, attr_proto);
+    return SetTypeToAttributeProto(value, attr_proto);
   } else if (value->isa<ValueSequeue>()) {
     ResetTupleIndex();
     std::string seq_string = "scalar:";
     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
-    SetSequenceToAttributeProto(value->cast<ValueSequeuePtr>(), attr_proto, &seq_string);
+    if (!SetSequenceToAttributeProto(value->cast<ValueSequeuePtr>(), attr_proto, &seq_string)) {
+      MS_LOG(ERROR) << "Set sequence to AttributeProto failed.";
+      return false;
+    }
     attr_proto->set_ref_attr_name(seq_string);
     MS_LOG(DEBUG) << "Attr string: " << seq_string;
   } else if (value->isa<tensor::Tensor>()) {
-    SetTensorToAttributeProto(value, attr_proto);
+    return SetTensorToAttributeProto(value, attr_proto);
   } else if (value->isa<None>()) {
     attr_proto->set_ref_attr_name("none");
     MS_LOG(DEBUG) << "Attr string: " << value->type_name();
@@ -664,14 +778,17 @@ void IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, mind_ir::A
     } else if (value->isa<IOMonad>()) {
       attr_proto->set_ref_attr_name("Monad:IOMonad");
     } else {
-      MS_LOG(EXCEPTION) << "Unsupported Monad type: " << value->type_name();
+      MS_LOG(ERROR) << "Unsupported Monad type: " << value->type_name();
+      return false;
    }
   } else {
-    MS_LOG(EXCEPTION) << "Unsupported type: " << value->type_name();
+    MS_LOG(ERROR) << "Unsupported type: " << value->type_name();
+    return false;
   }
+  return true;
 }
 
-void IrExportBuilder::SetScalarToAttributeProto_ir(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
+bool IrExportBuilder::SetScalarToAttributeProto_ir(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
   if (value == nullptr || attr_proto == nullptr) {
     MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
   }
@@ -714,13 +831,15 @@ void IrExportBuilder::SetScalarToAttributeProto_ir(const ValuePtr &value, mind_i
     attr_proto->set_d(GetValue<double>(value));
   } else if (value->isa<tensor::Tensor>()) {
     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSOR);
-    SetTensorToAttributeProto(value, attr_proto);
+    return SetTensorToAttributeProto(value, attr_proto);
   } else {
-    MS_LOG(EXCEPTION) << "Unsupported scalar type: " << value->type_name();
+    MS_LOG(ERROR) << "Unsupported scalar type: " << value->type_name();
+    return false;
   }
+  return true;
 }
 
-void IrExportBuilder::SetScalarToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
+bool IrExportBuilder::SetScalarToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
   if (value == nullptr || attr_proto == nullptr) {
     MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
   }
@@ -728,12 +847,20 @@ void IrExportBuilder::SetScalarToAttributeProto_irs(const ValuePtr &value, mind_
     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
     mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
     auto int_value = value->cast<IntPtr>();
-    tensor_proto->set_data_type(GetMindirDataBitsIntType(int_value->nbits()));
+    auto data_type = GetMindirDataBitsIntType(int_value->nbits());
+    if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
+      return false;
+    }
+    tensor_proto->set_data_type(data_type);
   } else if (value->isa<Float>()) {
     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
     mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
     auto float_value = value->cast<FloatPtr>();
-    tensor_proto->set_data_type(GetMindirDataBitsFloatType(float_value->nbits()));
+    auto data_type = GetMindirDataBitsFloatType(float_value->nbits());
+    if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
+      return false;
+    }
+    tensor_proto->set_data_type(data_type);
   } else if (value->isa<StringImm>()) {
     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_STRING);
     attr_proto->add_strings(GetValue<std::string>(value));
@@ -772,22 +899,24 @@ void IrExportBuilder::SetScalarToAttributeProto_irs(const ValuePtr &value, mind_
     attr_proto->add_doubles(GetValue<double>(value));
   } else if (value->isa<tensor::Tensor>()) {
     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSOR);
-    SetTensorToAttributeProto(value, attr_proto);
+    return SetTensorToAttributeProto(value, attr_proto);
   } else {
-    MS_LOG(EXCEPTION) << "Unsupported scalar type: " << value->type_name();
+    MS_LOG(ERROR) << "Unsupported scalar type: " << value->type_name();
+    return false;
   }
+  return true;
 }
 
-void IrExportBuilder::SetSeqElemToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto,
+bool IrExportBuilder::SetSeqElemToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto,
                                                  std::string *const seq_string) {
   string value_name = "value" + std::to_string(GetTupleIndex());
   if (seq_string != nullptr) {
     *seq_string += value_name + ",";
   }
-  SetScalarToAttributeProto_irs(value, attr_proto);
+  return SetScalarToAttributeProto_irs(value, attr_proto);
 }
 
-void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value,
+bool IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value,
                                                   mind_ir::AttributeProto *const attr_proto,
                                                   std::string *const seq_string) {
   if (value == nullptr || attr_proto == nullptr) {
@@ -799,13 +928,19 @@ void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value,
     if (tuple_value->value().size() == 0) {
       *seq_string += "],";
       MS_LOG(DEBUG) << "SetSequenceToAttributeProto tuple size is 0";
-      return;
+      return true;
    }
     for (const auto &item : tuple_value->value()) {
       if (item->isa<ValueSequeue>()) {
-        SetSequenceToAttributeProto(item->cast<ValueSequeuePtr>(), attr_proto, seq_string);
+        if (!SetSequenceToAttributeProto(item->cast<ValueSequeuePtr>(), attr_proto, seq_string)) {
+          MS_LOG(ERROR) << "Set sequence to AttributeProto failed.";
+          return false;
+        }
      } else {
-        SetSeqElemToAttributeProto(item, attr_proto, seq_string);
+        if (!SetSeqElemToAttributeProto(item, attr_proto, seq_string)) {
+          MS_LOG(ERROR) << "Set seq elem to AttributeProto failed.";
+          return false;
+        }
      }
     }
     *seq_string += "],";
@@ -815,18 +950,25 @@ void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value,
     if (list_value->value().size() == 0) {
       *seq_string += "],";
       MS_LOG(DEBUG) << "SetSequenceToAttributeProto list size is 0.";
-      return;
+      return true;
    }
     for (const auto &item : list_value->value()) {
       MS_EXCEPTION_IF_NULL(item);
       if (item->isa<ValueSequeue>()) {
-        SetSequenceToAttributeProto(item->cast<ValueSequeuePtr>(), attr_proto, seq_string);
+        if (!SetSequenceToAttributeProto(item->cast<ValueSequeuePtr>(), attr_proto, seq_string)) {
+          MS_LOG(ERROR) << "Set sequence to AttributeProto failed.";
+          return false;
+        }
      } else {
-        SetSeqElemToAttributeProto(item, attr_proto, seq_string);
+        if (!SetSeqElemToAttributeProto(item, attr_proto, seq_string)) {
+          MS_LOG(ERROR) << "Set seq elem to AttributeProto failed.";
+          return false;
+        }
      }
     }
     *seq_string += "],";
   }
+  return true;
 }
 
 std::string GetBinaryProtoString(const FuncGraphPtr &func_graph) {
@@ -842,7 +984,7 @@ std::string GetBinaryProtoString(const FuncGraphPtr &func_graph) {
   return exporter->GetDumpString(func_graph);
 }
 
-mind_ir::ModelProto GetBinaryProto(const FuncGraphPtr &func_graph, bool save_tensor_data) {
+ModelProtoPtr GetBinaryProto(const FuncGraphPtr &func_graph, bool save_tensor_data) {
   auto exporter = std::make_shared<IrExporter>(std::make_shared<IrExportBuilder>());
   auto result = exporter->GetDumpProto(func_graph, save_tensor_data);
   return result;
diff --git a/mindspore/core/load_mindir/anf_model_parser.cc b/mindspore/core/load_mindir/anf_model_parser.cc
index d6b075e0cf0..61cff350bde 100644
--- a/mindspore/core/load_mindir/anf_model_parser.cc
+++ b/mindspore/core/load_mindir/anf_model_parser.cc
@@ -232,11 +232,13 @@ tensor::TensorPtr MSANFModelParser::BuildTensorInfoForFuncGraph(const mind_ir::T
 
   if (!tensor_proto.has_data_type()) {
     MS_LOG(ERROR) << "mind_ir build tensor: " << tensor_proto.name() << " failed";
-    MS_LOG(EXCEPTION) << "mind_ir TensorProto has no data_type.";
+    MS_LOG(ERROR) << "mind_ir TensorProto has no data_type.";
+    return nullptr;
   }
   if (kDefaultValueSwitchMap.find(tensor_proto.data_type()) == kDefaultValueSwitchMap.end()) {
     MS_LOG(ERROR) << "mind_ir build tensor: " << tensor_proto.name() << " failed";
-    MS_LOG(EXCEPTION) << "mind_ir TensorProto data_type: " << tensor_proto.data_type() << " is not support yet!";
+    MS_LOG(ERROR) << "mind_ir TensorProto data_type: " << tensor_proto.data_type() << " is not support yet!";
+    return nullptr;
   }
 
   tensor::TensorPtr tensor_info =
@@ -258,7 +260,9 @@ bool MSANFModelParser::BuildParameterForFuncGraph(const ParameterPtr &node,
   node->set_name(debug_info_name);
 
   tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(parameter_proto);
-  MS_EXCEPTION_IF_NULL(tensor_info);
+  if (tensor_info == nullptr) {
+    return false;
+  }
   MS_LOG(DEBUG) << "Load parameter name: " << debug_info_name;
   if (!IsIncLoad() || load_tensor_map_.find(debug_info_name) == load_tensor_map_.end()) {
     load_tensor_map_[debug_info_name] = tensor_info;
@@ -311,7 +315,9 @@ bool MSANFModelParser::BuildInputForFuncGraph(const ParameterPtr &node, const mi
   if (value_proto.tensor_size() > 0) {
     const mind_ir::TensorProto &tensor_proto = value_proto.tensor(0);
     tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(tensor_proto);
-    MS_EXCEPTION_IF_NULL(tensor_info);
+    if (tensor_info == nullptr) {
+      return false;
+    }
     auto tensor_abstract = tensor_info->ToAbstract();
     node->set_abstract(tensor_abstract);
   } else if (value_proto.has_denotation()) {
@@ -797,7 +803,8 @@ AnfNodePtr MSANFModelParser::BuildOperatorNode(const mind_ir::NodeProto &node_pr
   if (node_type.size() > kOpTypeFlagSize && node_type.substr(0, kOpTypeFlagSize) == kOperatorTypeFlag) {
     auto anfNode = GetAnfNode(node_type.substr(kOpTypeFlagSize));
     if (anfNode == nullptr) {
-      MS_LOG(EXCEPTION) << "Can't find the ref:" << node_type;
+      MS_LOG(ERROR) << "Can't find the ref:" << node_type;
+      return nullptr;
    }
     return anfNode;
   }
@@ -828,7 +835,8 @@ AnfNodePtr MSANFModelParser::BuildOperatorNode(const mind_ir::NodeProto &node_pr
       continue;
     }
     if (!GetAttrValueForCNode(prim, attr_proto)) {
-      MS_LOG(EXCEPTION) << "Parser prim: " << node_type << " attributes error : " << attr_proto.DebugString();
+      MS_LOG(ERROR) << "Parser prim: " << node_type << " attributes error : " << attr_proto.DebugString();
+      return nullptr;
    }
   }
   prim->set_attr("is_load", MakeValue(true));
@@ -919,7 +927,12 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc
   MS_LOG(DEBUG) << "Process CNode: " << node_name;
   // Build inputs.
   std::vector<AnfNodePtr> inputs;
-  inputs.push_back(BuildOperatorNode(node_proto));
+  auto operator_node = BuildOperatorNode(node_proto);
+  if (operator_node == nullptr) {
+    MS_LOG(ERROR) << "Build operator node " << node_name << " failed!";
+    return nullptr;
+  }
+  inputs.push_back(operator_node);
   for (int i = 0; i < node_proto.input_size(); ++i) {
     auto anfNode = GetAnfNode(node_proto.input(i));
     if (anfNode == nullptr) {
@@ -949,7 +962,7 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc
 bool MSANFModelParser::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph,
                                                const mind_ir::GraphProto &importProto) {
   MS_EXCEPTION_IF_NULL(outputFuncGraph);
-  if (importProto.output_size() < 0 || importProto.output_size() > INT_MAX) {
+  if (importProto.output_size() <= 0 || importProto.output_size() > INT_MAX) {
     MS_LOG(ERROR) << "importProto.output_size is : " << importProto.output_size();
     return false;
   }
@@ -1089,10 +1102,12 @@ FuncGraphPtr MSANFModelParser::Parse(const mind_ir::ModelProto &model_proto) {
     FuncGraphPtr graph = std::make_shared<FuncGraph>();
     const auto &graph_proto = model_proto.functions(i);
     if (!graph_proto.has_name()) {
-      MS_LOG(EXCEPTION) << "The function has not a name. Please export mindIR again. ";
+      MS_LOG(ERROR) << "The function has not a name. Please export mindIR again. ";
+      return nullptr;
    }
     if (anfnode_build_map_.count(graph_proto.name()) > 0) {
-      MS_LOG(EXCEPTION) << "There is a duplication function graph name: " << graph_proto.name();
+      MS_LOG(ERROR) << "There is a duplication function graph name: " << graph_proto.name();
+      return nullptr;
    }
     anfnode_build_map_[graph_proto.name()] = std::make_shared<ValueNode>(graph);
   }
diff --git a/tests/st/compile_cache/test_ascend_lenet.py b/tests/st/compile_cache/test_ascend_lenet.py
index aedaef288c0..b610657e699 100644
--- a/tests/st/compile_cache/test_ascend_lenet.py
+++ b/tests/st/compile_cache/test_ascend_lenet.py
@@ -13,6 +13,7 @@
 # limitations under the License.
# ============================================================================ import os +import pathlib import numpy as np import pytest @@ -93,3 +94,27 @@ def test_lenet(): net1 = LeNet() train(net1, data1, label1) context.set_context(save_compile_cache=False, load_compile_cache=False) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.env_onecard +def test_lenet_load_failed(): + path = "compile_cache.mindir" + if os.path.exists(path): + os.remove(path) + assert not os.path.exists(path) + data = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01) + label = Tensor(np.ones([32]).astype(np.int32)) + net = LeNet() + train(net, data, label) + assert os.path.exists(path) + os.remove(path) + pathlib.Path(path).touch() + + data1 = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01) + label1 = Tensor(np.ones([32]).astype(np.int32)) + net1 = LeNet() + train(net1, data1, label1) + context.set_context(save_compile_cache=False, load_compile_cache=False) diff --git a/tests/ut/cpp/stub/anf_ir/dump_proto_stub.cc b/tests/ut/cpp/stub/anf_ir/dump_proto_stub.cc index 57f64f7dc48..64ab105d04d 100644 --- a/tests/ut/cpp/stub/anf_ir/dump_proto_stub.cc +++ b/tests/ut/cpp/stub/anf_ir/dump_proto_stub.cc @@ -26,8 +26,8 @@ std::string GetOnnxProtoString(const FuncGraphPtr &func_graph) { return ""; } std::string GetBinaryProtoString(const FuncGraphPtr &func_graph) { return ""; } -mind_ir::ModelProto GetBinaryProto(const FuncGraphPtr &func_graph, bool save_tensor_data) { - mind_ir::ModelProto empty_model; +ModelProtoPtr GetBinaryProto(const FuncGraphPtr &func_graph, bool save_tensor_data) { + ModelProtoPtr empty_model; return empty_model; } } // namespace mindspore
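
Illustrative usage (not part of the patch): the sketch below shows the error-handling contract these changes establish for callers. GetBinaryProto() now reports export failure by returning a null ModelProtoPtr instead of throwing, and protobuf's SerializeToOstream() reports write failure through its bool return, so a caller can log the problem and fall back to a normal compilation. The helper name SaveMindIRToFile is hypothetical, and the exact utils header for ChangeFileMode may differ:

    // Sketch only: mirrors the pattern used by CacheFuncGraph/ExportBpropToMindIR above.
    #include <sys/stat.h>
    #include <fstream>
    #include <string>
    #include "debug/dump_proto.h"    // GetBinaryProto, ModelProtoPtr (as modified above)
    #include "utils/log_adapter.h"   // MS_LOG

    namespace mindspore {
    bool SaveMindIRToFile(const FuncGraphPtr &func_graph, const std::string &path) {
      std::ofstream fout(path);
      if (!fout.is_open()) {
        MS_LOG(ERROR) << "Open cache file '" << path << "' failed!";
        return false;
      }
      // A null proto now signals that the exporter hit an unsupported construct.
      ModelProtoPtr model = GetBinaryProto(func_graph, true);
      if (model == nullptr) {
        fout.close();
        return false;
      }
      if (!model->SerializeToOstream(&fout)) {
        MS_LOG(ERROR) << "Failed to cache the graph to file " << path;
        fout.close();
        return false;
      }
      fout.close();
      ChangeFileMode(path, S_IRUSR);  // keep the cache read-only, as CacheFuncGraph does
      return true;
    }
    }  // namespace mindspore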