Continue execution when saving or loading MindIR fails

yujianfeng 2021-10-09 11:02:37 +08:00
parent ab5c15074a
commit d384db6c01
7 changed files with 342 additions and 137 deletions
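Across all seven files the change follows one pattern: fatal MS_LOG(EXCEPTION) calls on the MindIR save/load path become MS_LOG(ERROR) plus an early return (a bool, an empty string, or a null pointer, depending on the function), so a broken compile cache degrades to a logged error instead of aborting the process. A minimal, self-contained sketch of that pattern follows; the names (Model, BuildModel, SaveCache) are stand-ins, not MindSpore APIs.

// Minimal sketch (not MindSpore source) of the error-handling pattern this
// commit applies: log the failure and return a failure value instead of
// throwing, so the caller can skip the cache and keep compiling.
#include <fstream>
#include <iostream>
#include <memory>
#include <string>

struct Model {  // stand-in for mind_ir::ModelProto
  bool SerializeToOstream(std::ostream *out) const { return static_cast<bool>(*out); }
};
using ModelPtr = std::shared_ptr<Model>;

// Hypothetical exporter: returns nullptr on failure instead of throwing.
ModelPtr BuildModel(bool ok) { return ok ? std::make_shared<Model>() : nullptr; }

bool SaveCache(const std::string &path, bool ok) {
  std::ofstream fout(path);
  if (!fout.is_open()) {
    std::cerr << "Open cache file '" << path << "' failed!\n";
    return false;  // previously a fatal exception
  }
  ModelPtr model = BuildModel(ok);
  if (model == nullptr) {
    std::cerr << "Get binary proto failed.\n";
    return false;
  }
  if (!model->SerializeToOstream(&fout)) {
    std::cerr << "Failed to cache the graph.\n";
    return false;
  }
  return true;
}

int main() {
  if (!SaveCache("compile_cache.mindir", /*ok=*/false)) {
    std::cout << "Cache skipped; compilation continues.\n";
  }
}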

View File

@@ -17,19 +17,21 @@
#define MINDSPORE_CCSRC_DEBUG_DUMP_PROTO_H_
#include <string>
#include <memory>
#include "ir/func_graph.h"
#include "proto/mind_ir.pb.h"
#include "debug/common.h"
namespace mindspore {
using ModelProtoPtr = std::shared_ptr<mind_ir::ModelProto>;
std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph);
std::string GetOnnxProtoString(const FuncGraphPtr &func_graph);
std::string GetBinaryProtoString(const FuncGraphPtr &func_graph);
mind_ir::ModelProto GetBinaryProto(const FuncGraphPtr &func_graph, bool save_tensor_data = false);
ModelProtoPtr GetBinaryProto(const FuncGraphPtr &func_graph, bool save_tensor_data = false);
void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix);
} // namespace mindspore
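Note the signature change above: GetBinaryProto now returns ModelProtoPtr (an alias for std::shared_ptr<mind_ir::ModelProto>) instead of a mind_ir::ModelProto by value. Besides avoiding a copy of a potentially large message, the pointer gives callers an in-band failure signal. A sketch with a stand-in message type:

// Sketch of the new return convention; ModelProto here is a stand-in for the
// generated protobuf message, not MindSpore's actual type.
#include <memory>

struct ModelProto {};
using ModelProtoPtr = std::shared_ptr<ModelProto>;

ModelProtoPtr GetBinaryProtoSketch(bool export_ok) {
  if (!export_ok) {
    return nullptr;  // callers check for nullptr and continue without caching
  }
  return std::make_shared<ModelProto>();  // no copy of the message on return
}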

View File

@@ -158,10 +158,21 @@ void ExportBpropToMindIR(const PrimitivePtr &prim, const FuncGraphPtr &func_grap
return;
}
std::ofstream fout(bprop_mindir_realpath.value());
mind_ir::ModelProto fg_model = GetBinaryProto(func_graph, false);
if (!fg_model.SerializeToOstream(&fout)) {
MS_LOG(WARNING) << "Failed to cache the bprop of op \"" << prim->name() << "\" to file \""
<< bprop_mindir_realpath.value() << "\".";
if (!fout.is_open()) {
MS_LOG(ERROR) << "Open cache file '" << bprop_mindir_realpath.value() << "' failed!" << ErrnoToString(errno);
return;
}
ModelProtoPtr fg_model = GetBinaryProto(func_graph, false);
if (fg_model == nullptr) {
MS_LOG(ERROR) << "Get binary proto for graph " << func_graph->ToString() << " failed.";
fout.close();
return;
}
if (!fg_model->SerializeToOstream(&fout)) {
MS_LOG(ERROR) << "Failed to cache the bprop of op \"" << prim->name() << "\" to file \""
<< bprop_mindir_realpath.value() << "\".";
fout.close();
return;
}
fout.close();
ChangeFileMode(bprop_mindir_realpath.value(), S_IRUSR | S_IWUSR);
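A design note on the save paths in this commit: std::ofstream closes itself when it goes out of scope, so the explicit fout.close() calls on each failure branch are defensive rather than required. A minimal RAII-only equivalent (hypothetical names):

#include <fstream>
#include <string>

bool WriteCacheSketch(const std::string &path, const std::string &payload) {
  std::ofstream fout(path);
  if (!fout.is_open()) {
    return false;  // fout is destroyed (and closed) on every return path
  }
  fout << payload;
  return static_cast<bool>(fout);  // stream state reports whether writes succeeded
}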

View File

@@ -189,7 +189,8 @@ void GetCachedFuncGraph(const ResourcePtr &resource, const std::string &queue_na
MS_LOG(INFO) << "Use the compilation cache \"" << realpath.value() << "\" and execute the backend actions only.";
FuncGraphPtr fg = mindspore::LoadMindIR(realpath.value());
if (fg == nullptr) {
MS_LOG(EXCEPTION) << "Failed to load the compilation cache file: " << realpath.value();
MS_LOG(ERROR) << "Failed to load the compilation cache file: " << realpath.value();
return;
}
FuncGraphManagerPtr mng = fg->manager();
if (mng == nullptr) {
@@ -213,18 +214,27 @@ void CacheFuncGraph(const ResourcePtr &resource) {
MS_EXCEPTION_IF_NULL(resource);
auto realpath = Common::CreatePrefixPath(kCompileCacheFilePath);
if (!realpath.has_value()) {
MS_LOG(EXCEPTION) << "Get real path failed. filename=" << kCompileCacheFilePath;
MS_LOG(ERROR) << "Get real path failed. filename=" << kCompileCacheFilePath;
return;
}
ChangeFileMode(realpath.value(), S_IRWXU);
std::ofstream fout(realpath.value());
if (!fout.is_open()) {
MS_LOG(EXCEPTION) << "Open cache file '" << realpath.value() << "' failed!" << ErrnoToString(errno);
MS_LOG(ERROR) << "Open cache file '" << realpath.value() << "' failed!" << ErrnoToString(errno);
return;
}
FuncGraphPtr fg = resource->func_graph();
mind_ir::ModelProto fg_model = GetBinaryProto(fg, true);
if (!fg_model.SerializeToOstream(&fout)) {
MS_LOG(EXCEPTION) << "Failed to cache the graph to file " << realpath.value();
ModelProtoPtr fg_model = GetBinaryProto(fg, true);
if (fg_model == nullptr) {
MS_LOG(ERROR) << "Get binary proto for graph " << fg->ToString() << " failed.";
fout.close();
return;
}
if (!fg_model->SerializeToOstream(&fout)) {
MS_LOG(ERROR) << "Failed to cache the graph to file " << realpath.value();
fout.close();
return;
}
fout.close();
ChangeFileMode(realpath.value(), S_IRUSR);

View File

@@ -27,6 +27,7 @@
#include "base/core_ops.h"
#include "proto/mind_ir.pb.h"
#include "utils/check_convert_utils.h"
#include "debug/dump_proto.h"
namespace mindspore {
using FloatPtr = std::shared_ptr<Float>;
@@ -78,7 +79,7 @@ class IrExporter {
explicit IrExporter(IrExportBuilderPtr builder) : builder_(builder) {}
virtual ~IrExporter() = default;
std::string GetDumpString(const FuncGraphPtr &func_graph);
mind_ir::ModelProto GetDumpProto(const FuncGraphPtr &func_graph, bool save_tensor_data = false);
ModelProtoPtr GetDumpProto(const FuncGraphPtr &func_graph, bool save_tensor_data = false);
private:
IrExportBuilderPtr builder_;
@@ -86,39 +87,38 @@
class IrExportBuilder {
public:
IrExportBuilder() = default;
IrExportBuilder() : model_(std::make_shared<mind_ir::ModelProto>()) {}
~IrExportBuilder() { google::protobuf::ShutdownProtobufLibrary(); }
std::string GetProtoString(const FuncGraphPtr &func_graph);
void BuildModelInfo();
void BuildModel(const FuncGraphPtr &func_graph, bool save_tensor_data = false);
mind_ir::ModelProto Model() { return model_; }
bool BuildModel(const FuncGraphPtr &func_graph, bool save_tensor_data = false);
ModelProtoPtr Model() { return model_; }
private:
void BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
bool BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
bool save_tensor_data = false);
void BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
bool BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
bool save_tensor_data = false);
void BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto);
void BuildOutput(const CNodePtr &node, mind_ir::GraphProto *const graph_proto);
void BuildCNode(const CNodePtr &node, mind_ir::GraphProto *const graph_proto);
bool BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto);
bool BuildOutput(const CNodePtr &node, mind_ir::GraphProto *const graph_proto);
bool BuildCNode(const CNodePtr &node, mind_ir::GraphProto *const graph_proto);
std::string BuildInputNode(const AnfNodePtr &node, mind_ir::GraphProto *const graph_proto);
void SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueInfoProto *const value_proto);
void SetValueInfoProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::ValueInfoProto *const value_proto);
void SetParamToTensorProto(const ParameterPtr &param, mind_ir::TensorProto *const tensor_proto);
void SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::TensorProto *const tensor_proto);
void SetAttributeProto(const AnfNodePtr &node, mind_ir::NodeProto *const node_proto);
void SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodeProto *const node_proto);
void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::AttributeProto *const attr_proto,
bool SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueInfoProto *const value_proto);
bool SetParamToTensorProto(const ParameterPtr &param, mind_ir::TensorProto *const tensor_proto);
bool SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::TensorProto *const tensor_proto);
bool SetAttributeProto(const AnfNodePtr &node, mind_ir::NodeProto *const node_proto);
bool SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodeProto *const node_proto);
bool SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::AttributeProto *const attr_proto,
std::string *const seq_string);
void SetValueToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
void SetTypeToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
void SetScalarToAttributeProto_ir(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
void SetScalarToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
void SetTensorToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
void SetSequenceToAttributeProto(const ValueSequeuePtr &value, mind_ir::AttributeProto *const attr_proto,
bool SetValueToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
bool SetTypeToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
bool SetScalarToAttributeProto_ir(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
bool SetScalarToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
bool SetTensorToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
bool SetSequenceToAttributeProto(const ValueSequeuePtr &value, mind_ir::AttributeProto *const attr_proto,
std::string *const seq_string);
void SetSeqElemToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto,
bool SetSeqElemToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto,
std::string *const seq_string);
mind_ir::TensorProto_DataType GetMindirDataType(TypeId type_id);
@@ -134,7 +134,7 @@ class IrExportBuilder {
void ResetTupleIndex() { shape_index_ = 0; }
private:
mind_ir::ModelProto model_;
ModelProtoPtr model_;
mind_ir::NodeProto *last_node_{nullptr};
std::list<FuncGraphPtr> todo_;
std::map<AnfNodePtr, std::string> node_index_map_;
@@ -147,11 +147,14 @@
using IrExporterPtr = std::shared_ptr<IrExporter>;
std::string IrExporter::GetDumpString(const FuncGraphPtr &func_graph) {
(void)GetDumpProto(func_graph);
auto dump_proto = GetDumpProto(func_graph);
if (dump_proto == nullptr) {
MS_LOG(EXCEPTION) << "Get dump proto for graph " << func_graph->ToString() << " failed.";
}
return builder_->GetProtoString(func_graph);
}
mind_ir::ModelProto IrExporter::GetDumpProto(const FuncGraphPtr &func_graph, bool save_tensor_data) {
ModelProtoPtr IrExporter::GetDumpProto(const FuncGraphPtr &func_graph, bool save_tensor_data) {
if ((builder_ == nullptr) || (func_graph == nullptr)) {
MS_LOG(EXCEPTION) << "Input params is null.";
}
@@ -160,26 +163,28 @@ mind_ir::ModelProto IrExporter::GetDumpProto(const FuncGraphPtr &func_graph, boo
builder_->BuildModelInfo();
// Export model and return string
builder_->BuildModel(func_graph, save_tensor_data);
if (!builder_->BuildModel(func_graph, save_tensor_data)) {
return nullptr;
}
return builder_->Model();
}
std::string IrExportBuilder::GetProtoString(const FuncGraphPtr &func_graph) {
MS_LOG(DEBUG) << "BuildModel complete!";
return model_.SerializeAsString();
return model_->SerializeAsString();
}
void IrExportBuilder::BuildModelInfo() {
constexpr auto ir_version = "0.1.0";
constexpr auto mindspore_name = "MindSpore";
model_.set_ir_version(ir_version);
model_.set_producer_name(mindspore_name);
model_.set_model_version(VERSION);
model_->set_ir_version(ir_version);
model_->set_producer_name(mindspore_name);
model_->set_model_version(VERSION);
}
void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tensor_data) {
bool IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tensor_data) {
MS_EXCEPTION_IF_NULL(func_graph);
mind_ir::GraphProto *graph_proto = model_.mutable_graph();
mind_ir::GraphProto *graph_proto = model_->mutable_graph();
graph_proto->set_name(func_graph->ToString());
graph_proto->set_bprop_hash(func_graph->bprop_hash());
ResetNodeIndex();
@@ -188,7 +193,11 @@ void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tenso
// Build the main funcGraph
nodeName_.insert(func_graph->ToString());
top_graph = true;
BuildFuncGraph(func_graph, graph_proto, save_tensor_data);
if (!BuildFuncGraph(func_graph, graph_proto, save_tensor_data)) {
MS_LOG(ERROR) << "Build func_graph " << func_graph->ToString() << " failed.";
return false;
}
std::set<FuncGraphPtr> graphVisited;
graphVisited.insert(func_graph);
top_graph = false;
@@ -199,32 +208,40 @@ void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tenso
continue;
}
if (nodeName_.count(fg->ToString()) > 0) {
MS_LOG(EXCEPTION) << "There is a duplicate name: " << fg->ToString();
MS_LOG(ERROR) << "There is a duplicate name: " << fg->ToString();
return false;
}
nodeName_.insert(fg->ToString());
graphVisited.insert(fg);
auto graph = model_.add_functions();
BuildFuncGraph(fg, graph, save_tensor_data);
auto graph = model_->add_functions();
if (!BuildFuncGraph(fg, graph, save_tensor_data)) {
MS_LOG(ERROR) << "Build func_graph " << fg->ToString() << " failed.";
return false;
}
}
// Release resource
nodeName_.clear();
node_index_map_.clear();
return true;
}
void IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
bool IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
bool save_tensor_data) {
// Export funcGraph name.
graph_proto->set_name(func_graph->ToString());
// Export parameters
// 1. parameters should be mapped to ValueInfoProto
// 2. parameters with default value should be mapped to Initializer
BuildParameters(func_graph, graph_proto, save_tensor_data);
if (!BuildParameters(func_graph, graph_proto, save_tensor_data)) {
MS_LOG(ERROR) << "Build parameters failed.";
return false;
}
// Export operator nodes(include output)
BuildNodes(func_graph, graph_proto);
return BuildNodes(func_graph, graph_proto);
}
void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
bool IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
bool save_tensor_data) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(graph_proto);
@@ -232,14 +249,18 @@ void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::G
MS_EXCEPTION_IF_NULL(item);
auto param = item->cast<ParameterPtr>();
if (param == nullptr) {
MS_LOG(EXCEPTION) << "Parameter: '" << item->ToString() << "' could not cast to parameter.";
MS_LOG(ERROR) << "Parameter: '" << item->ToString() << "' could not cast to parameter.";
return false;
}
std::string param_name = GetUniqueNodeName(param);
if (top_graph && param->has_default()) {
MS_LOG(DEBUG) << "Parameter: '" << item->DebugString() << "' has default. address: " << (size_t)param.get();
mind_ir::TensorProto *parameter_proto = graph_proto->add_parameter();
parameter_proto->set_name(param_name);
SetParamToTensorProto(param, parameter_proto);
if (!SetParamToTensorProto(param, parameter_proto)) {
MS_LOG(ERROR) << "Set parameter " << param->DebugString() << " to TensorProto failed.";
return false;
}
auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param->default_param());
if (tensor && save_tensor_data) {
parameter_proto->set_raw_data(tensor->data_c(), tensor->data().nbytes());
@@ -247,19 +268,25 @@ void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::G
} else {
mind_ir::ValueInfoProto *input_proto = graph_proto->add_input();
input_proto->set_name(param_name);
SetValueInfoProto(param, input_proto);
if (!SetValueInfoProto(param, input_proto)) {
MS_LOG(ERROR) << "Set parameter " << param->DebugString() << " to TensorProto failed.";
return false;
}
}
if (nodeName_.count(param_name) > 0) {
MS_LOG(EXCEPTION) << "parameter name is duplicate:" << param_name;
MS_LOG(ERROR) << "parameter name is duplicate:" << param_name;
return false;
}
nodeName_.insert(param_name);
}
return true;
}
mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataType(TypeId type_id) {
auto iter = g_data_type_map.find(type_id);
if (iter == g_data_type_map.end()) {
MS_LOG(EXCEPTION) << "Convert type error, unsupported type! " << type_id;
MS_LOG(ERROR) << "Convert type error, unsupported type! " << type_id;
return mind_ir::TensorProto_DataType_UNDEFINED;
}
return iter->second;
}
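The four GetMindirData* lookups now return mind_ir::TensorProto_DataType_UNDEFINED as an in-band sentinel instead of throwing, and every caller compares against it before calling set_data_type. A self-contained sketch of that convention (enum values and map contents are illustrative, not MindSpore's):

#include <iostream>
#include <map>

enum DataType { DT_UNDEFINED = 0, DT_FLOAT32, DT_INT32 };

DataType GetDataTypeSketch(int type_id) {
  static const std::map<int, DataType> kTypeMap = {{1, DT_FLOAT32}, {2, DT_INT32}};
  auto iter = kTypeMap.find(type_id);
  if (iter == kTypeMap.end()) {
    std::cerr << "Convert type error, unsupported type! " << type_id << "\n";
    return DT_UNDEFINED;  // sentinel; the old code threw an exception here
  }
  return iter->second;
}

int main() {
  if (GetDataTypeSketch(99) == DT_UNDEFINED) {
    std::cout << "caller returns false and the export is abandoned\n";
  }
}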
@@ -267,7 +294,8 @@ mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataType(TypeId type_id)
mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsIntType(int bits) {
auto iter = g_data_bits_int_map.find(bits);
if (iter == g_data_bits_int_map.end()) {
MS_LOG(EXCEPTION) << "Convert bits int error, unsupported bits! " << bits;
MS_LOG(ERROR) << "Convert bits int error, unsupported bits! " << bits;
return mind_ir::TensorProto_DataType_UNDEFINED;
}
return iter->second;
}
@@ -275,7 +303,8 @@ mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsIntType(int bits
mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsUIntType(int bits) {
auto iter = g_data_bits_uint_map.find(bits);
if (iter == g_data_bits_uint_map.end()) {
MS_LOG(EXCEPTION) << "Convert bits uint error, unsupported bits! " << bits;
MS_LOG(ERROR) << "Convert bits uint error, unsupported bits! " << bits;
return mind_ir::TensorProto_DataType_UNDEFINED;
}
return iter->second;
}
@@ -283,20 +312,22 @@ mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsUIntType(int bit
mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsFloatType(int bits) {
auto iter = g_data_bits_float_map.find(bits);
if (iter == g_data_bits_float_map.end()) {
MS_LOG(EXCEPTION) << "Convert bits float error, unsupported bits! " << bits;
MS_LOG(ERROR) << "Convert bits float error, unsupported bits! " << bits;
return mind_ir::TensorProto_DataType_UNDEFINED;
}
return iter->second;
}
void IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueInfoProto *const value_proto) {
bool IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueInfoProto *const value_proto) {
if (node == nullptr || value_proto == nullptr) {
MS_LOG(EXCEPTION) << "AnfNode or ValueInfo is null!";
}
MS_LOG(DEBUG) << "SetValueInfoProto: " << node->DebugString();
const TypePtr &type = node->Type();
const BaseShapePtr &shape = node->Shape();
// For the bprop fg which has not been renormalized.
if (type == nullptr || shape == nullptr) {
return;
return true;
}
if (type->isa<TensorType>() && shape->isa<abstract::Shape>()) {
auto tensor = type->cast<TensorTypePtr>();
@@ -304,7 +335,11 @@ void IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueIn
auto elem_type = tensor->element();
const auto &dims = shape->cast<abstract::ShapePtr>()->shape();
mind_ir::TensorProto *tensor_proto = value_proto->add_tensor();
tensor_proto->set_data_type(GetMindirDataType(elem_type->type_id()));
auto data_type = GetMindirDataType(elem_type->type_id());
if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
return false;
}
tensor_proto->set_data_type(data_type);
if (dims.size() == 0) {
MS_LOG(DEBUG) << "The dim of ValueInfoProto is 0.";
} else {
@@ -320,9 +355,10 @@ void IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueIn
value_proto->set_denotation(type->type_name());
}
MS_LOG(DEBUG) << "Value type: " << type->type_name();
return true;
}
void IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
bool IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
if (value == nullptr || attr_proto == nullptr) {
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
}
@@ -335,34 +371,45 @@ void IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, mind_ir::
tensor_proto->set_raw_data(data->data_c(), static_cast<size_t>(data->data().nbytes()));
auto dtype = data->data_type();
auto shape = data->shape_c();
tensor_proto->set_data_type(GetMindirDataType(dtype));
auto data_type = GetMindirDataType(dtype);
if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
return false;
}
tensor_proto->set_data_type(data_type);
for (const auto &dim : shape) {
tensor_proto->add_dims(dim);
}
return true;
}
void IrExportBuilder::SetTensorProto(const TypePtr &type, const BaseShapePtr &shape,
bool IrExportBuilder::SetTensorProto(const TypePtr &type, const BaseShapePtr &shape,
mind_ir::TensorProto *const tensor_proto) {
if (!type->isa<TensorType>() || !shape->isa<abstract::Shape>()) {
MS_LOG(EXCEPTION) << "Type or shape is not supported! " << type->ToString();
MS_LOG(ERROR) << "Type or shape is not supported! " << type->ToString();
return false;
}
auto tensor = type->cast<TensorTypePtr>();
const auto &dims = shape->cast<abstract::ShapePtr>()->shape();
tensor_proto->set_data_type(GetMindirDataType(tensor->element()->type_id()));
auto data_type = GetMindirDataType(tensor->element()->type_id());
if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
return false;
}
tensor_proto->set_data_type(data_type);
for (const auto &dim : dims) {
tensor_proto->add_dims(dim);
}
return true;
}
void IrExportBuilder::SetParamToTensorProto(const ParameterPtr &param, mind_ir::TensorProto *const tensor_proto) {
bool IrExportBuilder::SetParamToTensorProto(const ParameterPtr &param, mind_ir::TensorProto *const tensor_proto) {
if (param == nullptr || tensor_proto == nullptr) {
MS_LOG(EXCEPTION) << "Parameter or TensorProto is null!";
}
MS_LOG(DEBUG) << "SetParamToTensorProto: " << param->DebugString();
SetTensorProto(param->Type(), param->Shape(), tensor_proto);
return SetTensorProto(param->Type(), param->Shape(), tensor_proto);
}
void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto) {
bool IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto) {
std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
for (const AnfNodePtr &node : nodes) {
MS_EXCEPTION_IF_NULL(node);
@@ -372,24 +419,36 @@ void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphP
}
auto cnode = node->cast<CNodePtr>();
if (cnode == func_graph->get_return()) {
BuildOutput(cnode, graph_proto);
if (!BuildOutput(cnode, graph_proto)) {
MS_LOG(ERROR) << "Build output for graph " << func_graph->ToString() << " failed.";
return false;
}
} else {
BuildCNode(cnode, graph_proto);
if (!BuildCNode(cnode, graph_proto)) {
MS_LOG(ERROR) << "Build proto for cnode " << cnode->DebugString() << " failed.";
return false;
}
}
}
return true;
}
void IrExportBuilder::BuildOutput(const CNodePtr &node, mind_ir::GraphProto *const graph_proto) {
bool IrExportBuilder::BuildOutput(const CNodePtr &node, mind_ir::GraphProto *const graph_proto) {
MS_EXCEPTION_IF_NULL(node);
const int OutputSize = 2;
if (node->size() != OutputSize) {
MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2.";
MS_LOG(ERROR) << "Number of inputs of return node is not equal to 2.";
return false;
}
AnfNodePtr arg = node->input(1);
std::string node_name = BuildInputNode(arg, graph_proto);
if (node_name.empty()) {
MS_LOG(ERROR) << "Build input node failed for arg " << arg->DebugString();
return false;
}
mind_ir::ValueInfoProto *output_proto = graph_proto->add_output();
output_proto->set_name(node_name);
SetValueInfoProto(arg, output_proto);
return SetValueInfoProto(arg, output_proto);
}
std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) {
@@ -408,16 +467,18 @@ std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) {
auto nodeName = GetUniqueNodeName(node);
type_name = "REF::" + nodeName;
if (nodeName_.count(nodeName) == 0) {
MS_LOG(EXCEPTION) << "There is not the name: " << nodeName;
MS_LOG(ERROR) << "There is not the name: " << nodeName;
return "";
}
} else {
MS_LOG(EXCEPTION) << "Need to support op type: " << node->type_name();
MS_LOG(ERROR) << "Need to support op type: " << node->type_name();
return "";
}
MS_LOG(DEBUG) << "ExportType: " << type_name;
return type_name;
}
void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape,
bool IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape,
mind_ir::AttributeProto *const attr_proto, std::string *const seq_string) {
MS_EXCEPTION_IF_NULL(type);
MS_EXCEPTION_IF_NULL(shape);
@@ -427,7 +488,9 @@ void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePt
auto elements = type->cast<TuplePtr>()->elements();
auto tuple_shape = shape->cast<abstract::TupleShapePtr>()->shape();
for (size_t i = 0; i < elements.size(); i++) {
SetShapeToNodeProto(elements[i], tuple_shape[i], attr_proto, seq_string);
if (!SetShapeToNodeProto(elements[i], tuple_shape[i], attr_proto, seq_string)) {
return false;
}
}
*seq_string += "],";
} else if (type->isa<TensorType>() && shape->isa<abstract::Shape>()) {
@@ -435,7 +498,7 @@ void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePt
*seq_string += shape_name + ",";
mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
tensor_proto->set_name(shape_name);
SetTensorProto(type, shape, tensor_proto);
return SetTensorProto(type, shape, tensor_proto);
} else if (type->isa<Number>()) {
if (type->isa<Bool>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_BOOL);
@@ -453,11 +516,13 @@ void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePt
} else if (type->isa<String>() || type->isa<UMonadType>() || type->isa<IOMonadType>()) {
*seq_string += type->type_name() + ",";
} else {
MS_LOG(EXCEPTION) << "Type of cnode need to be supported: " << type->type_name();
MS_LOG(ERROR) << "Type of cnode need to be supported: " << type->type_name();
return false;
}
return true;
}
void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodeProto *const node_proto) {
bool IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodeProto *const node_proto) {
// Get shape of cnode
// 1. need to get shape from tuple element
// 2. save shape in TensorProto
@@ -465,21 +530,27 @@ void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodePro
MS_EXCEPTION_IF_NULL(node);
auto type = node->Type();
auto shape = node->Shape();
// For the bprop fg which has not been renormalized.
if (type == nullptr || shape == nullptr) {
return;
return true;
}
ResetTupleIndex();
std::string seq_string = "shape:";
mind_ir::AttributeProto *attr_proto = node_proto->add_attribute();
SetShapeToNodeProto(type, shape, attr_proto, &seq_string);
if (!SetShapeToNodeProto(type, shape, attr_proto, &seq_string)) {
MS_LOG(ERROR) << "Set shape to NodeProto for " << node->DebugString() << " failed.";
return false;
}
attr_proto->set_ref_attr_name(seq_string);
MS_LOG(DEBUG) << "CNode shape: " << seq_string;
return true;
}
void IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *const graph_proto) {
bool IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *const graph_proto) {
auto inputs_size = node->size();
if (inputs_size < 1) {
MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
MS_LOG(ERROR) << "Inputs of node " << node->DebugString() << " is empty";
return false;
}
// Need to build input node before dealing with cnode
@@ -488,7 +559,12 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *cons
for (size_t i = 1; i < inputs_size; i++) {
auto input = node->input(i);
op_inputs.push_back(input);
input_names.push_back(BuildInputNode(input, graph_proto));
std::string node_name = BuildInputNode(input, graph_proto);
if (node_name.empty()) {
MS_LOG(ERROR) << "Build input node for " << input->DebugString() << " failed.";
return false;
}
input_names.push_back(node_name);
}
// Build cnode
@@ -503,10 +579,16 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *cons
node_proto->set_domain(node->fullname_with_scope());
AnfNodePtr op = node->input(0);
std::string type_name = GetOpTypeName(op);
if (type_name.empty()) {
MS_LOG(ERROR) << "Get op type name for " << op->DebugString() << " failed.";
return false;
}
node_proto->set_op_type(type_name);
last_node_ = node_proto;
// Maybe Tensor or Function or nullptr
SetShapeToNodeProto(node, node_proto);
if (!SetShapeToNodeProto(node, node_proto)) {
return false;
}
(void)std::for_each(input_names.begin(), input_names.end(),
[&node_proto](const string &name) { node_proto->add_input(name); });
@@ -520,9 +602,13 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *cons
attr_proto->set_name(attr.first);
auto attr_value = attr.second;
CheckAndConvertUtils::ConvertAttrValueInExport(type_name, attr.first, &attr_value);
SetValueToAttributeProto(attr_value, attr_proto);
if (!SetValueToAttributeProto(attr_value, attr_proto)) {
MS_LOG(ERROR) << "Set value to AttributeProto failed.";
return false;
}
}
}
return true;
}
std::string IrExportBuilder::BuildInputNode(const AnfNodePtr &node, mind_ir::GraphProto *const graph_proto) {
@@ -538,7 +624,9 @@ std::string IrExportBuilder::BuildInputNode(const AnfNodePtr &node, mind_ir::Gra
mind_ir::NodeProto *node_proto = graph_proto->add_node();
node_proto->set_name(node_name);
node_proto->add_output(node_name);
SetAttributeProto(node, node_proto);
if (!SetAttributeProto(node, node_proto)) {
return "";
}
}
return node_name;
}
@@ -581,7 +669,7 @@ std::string IrExportBuilder::GetNodeName(const AnfNodePtr &node) {
return node_name;
}
void IrExportBuilder::SetAttributeProto(const AnfNodePtr &node, mind_ir::NodeProto *const node_proto) {
bool IrExportBuilder::SetAttributeProto(const AnfNodePtr &node, mind_ir::NodeProto *const node_proto) {
if (node == nullptr || node_proto == nullptr) {
MS_LOG(EXCEPTION) << "AnfNode or NodeProto is null!";
}
@@ -592,10 +680,10 @@ void IrExportBuilder::SetAttributeProto(const AnfNodePtr &node, mind_ir::NodePro
mind_ir::AttributeProto *attr_proto = node_proto->add_attribute();
attr_proto->set_name("value");
MS_LOG(DEBUG) << "Set Constant attribute: " << value->ToString();
SetValueToAttributeProto(value, attr_proto);
return SetValueToAttributeProto(value, attr_proto);
}
void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
bool IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
if (value == nullptr || attr_proto == nullptr) {
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
}
@@ -605,17 +693,29 @@ void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, mind_ir::At
attr_proto->set_ref_attr_name("type:value0");
tensor_proto->set_name("value0");
auto int_value = value->cast<IntPtr>();
tensor_proto->set_data_type(GetMindirDataBitsIntType(int_value->nbits()));
auto data_type = GetMindirDataBitsIntType(int_value->nbits());
if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
return false;
}
tensor_proto->set_data_type(data_type);
} else if (value->isa<UInt>()) {
attr_proto->set_ref_attr_name("type:value0");
tensor_proto->set_name("value0");
auto float_value = value->cast<UIntPtr>();
tensor_proto->set_data_type(GetMindirDataBitsUIntType(float_value->nbits()));
auto data_type = GetMindirDataBitsUIntType(float_value->nbits());
if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
return false;
}
tensor_proto->set_data_type(data_type);
} else if (value->isa<Float>()) {
attr_proto->set_ref_attr_name("type:value0");
tensor_proto->set_name("value0");
auto float_value = value->cast<FloatPtr>();
tensor_proto->set_data_type(GetMindirDataBitsFloatType(float_value->nbits()));
auto data_type = GetMindirDataBitsFloatType(float_value->nbits());
if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
return false;
}
tensor_proto->set_data_type(data_type);
} else if (value->isa<Bool>()) {
attr_proto->set_ref_attr_name("type:value0");
tensor_proto->set_name("value0");
@@ -626,35 +726,49 @@ void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, mind_ir::At
auto elem_type = value->cast<TensorTypePtr>()->element();
if (elem_type->isa<Int>()) {
auto int_value = elem_type->cast<IntPtr>();
tensor_proto->set_data_type(GetMindirDataBitsIntType(int_value->nbits()));
auto data_type = GetMindirDataBitsIntType(int_value->nbits());
if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
return false;
}
tensor_proto->set_data_type(data_type);
} else if (elem_type->isa<Float>()) {
auto float_value = elem_type->cast<FloatPtr>();
tensor_proto->set_data_type(GetMindirDataBitsFloatType(float_value->nbits()));
auto data_type = GetMindirDataBitsFloatType(float_value->nbits());
if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
return false;
}
tensor_proto->set_data_type(data_type);
} else {
MS_LOG(EXCEPTION) << "Unsupported type " << elem_type->type_name();
MS_LOG(ERROR) << "Unsupported type " << elem_type->type_name();
return false;
}
} else {
MS_LOG(EXCEPTION) << "Unsupported type: " << value->type_name();
return false;
}
return true;
}
void IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
bool IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
if (value == nullptr || attr_proto == nullptr) {
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
}
if (value->isa<StringImm>() || value->isa<Scalar>()) {
SetScalarToAttributeProto_ir(value, attr_proto);
return SetScalarToAttributeProto_ir(value, attr_proto);
} else if (value->isa<Number>() || value->isa<TensorType>()) {
SetTypeToAttributeProto(value, attr_proto);
return SetTypeToAttributeProto(value, attr_proto);
} else if (value->isa<ValueSequeue>()) {
ResetTupleIndex();
std::string seq_string = "scalar:";
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
SetSequenceToAttributeProto(value->cast<ValueSequeuePtr>(), attr_proto, &seq_string);
if (!SetSequenceToAttributeProto(value->cast<ValueSequeuePtr>(), attr_proto, &seq_string)) {
MS_LOG(ERROR) << "Set sequence to AttributeProto failed.";
return false;
}
attr_proto->set_ref_attr_name(seq_string);
MS_LOG(DEBUG) << "Attr string: " << seq_string;
} else if (value->isa<tensor::Tensor>()) {
SetTensorToAttributeProto(value, attr_proto);
return SetTensorToAttributeProto(value, attr_proto);
} else if (value->isa<None>()) {
attr_proto->set_ref_attr_name("none");
MS_LOG(DEBUG) << "Attr string: " << value->type_name();
@@ -664,14 +778,17 @@ void IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, mind_ir::A
} else if (value->isa<IOMonad>()) {
attr_proto->set_ref_attr_name("Monad:IOMonad");
} else {
MS_LOG(EXCEPTION) << "Unsupported Monad type: " << value->type_name();
MS_LOG(ERROR) << "Unsupported Monad type: " << value->type_name();
return false;
}
} else {
MS_LOG(EXCEPTION) << "Unsupported type: " << value->type_name();
MS_LOG(ERROR) << "Unsupported type: " << value->type_name();
return false;
}
return true;
}
void IrExportBuilder::SetScalarToAttributeProto_ir(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
bool IrExportBuilder::SetScalarToAttributeProto_ir(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
if (value == nullptr || attr_proto == nullptr) {
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
}
@@ -714,13 +831,15 @@ void IrExportBuilder::SetScalarToAttributeProto_ir(const ValuePtr &value, mind_i
attr_proto->set_d(GetValue<double>(value));
} else if (value->isa<tensor::Tensor>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSOR);
SetTensorToAttributeProto(value, attr_proto);
return SetTensorToAttributeProto(value, attr_proto);
} else {
MS_LOG(EXCEPTION) << "Unsupported scalar type: " << value->type_name();
MS_LOG(ERROR) << "Unsupported scalar type: " << value->type_name();
return false;
}
return true;
}
void IrExportBuilder::SetScalarToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
bool IrExportBuilder::SetScalarToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
if (value == nullptr || attr_proto == nullptr) {
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
}
@@ -728,12 +847,20 @@ void IrExportBuilder::SetScalarToAttributeProto_irs(const ValuePtr &value, mind_
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
auto int_value = value->cast<IntPtr>();
tensor_proto->set_data_type(GetMindirDataBitsIntType(int_value->nbits()));
auto data_type = GetMindirDataBitsIntType(int_value->nbits());
if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
return false;
}
tensor_proto->set_data_type(data_type);
} else if (value->isa<Float>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
auto float_value = value->cast<FloatPtr>();
tensor_proto->set_data_type(GetMindirDataBitsFloatType(float_value->nbits()));
auto data_type = GetMindirDataBitsFloatType(float_value->nbits());
if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
return false;
}
tensor_proto->set_data_type(data_type);
} else if (value->isa<StringImm>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_STRING);
attr_proto->add_strings(GetValue<std::string>(value));
@@ -772,22 +899,24 @@ void IrExportBuilder::SetScalarToAttributeProto_irs(const ValuePtr &value, mind_
attr_proto->add_doubles(GetValue<double>(value));
} else if (value->isa<tensor::Tensor>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSOR);
SetTensorToAttributeProto(value, attr_proto);
return SetTensorToAttributeProto(value, attr_proto);
} else {
MS_LOG(EXCEPTION) << "Unsupported scalar type: " << value->type_name();
MS_LOG(ERROR) << "Unsupported scalar type: " << value->type_name();
return false;
}
return true;
}
void IrExportBuilder::SetSeqElemToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto,
bool IrExportBuilder::SetSeqElemToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto,
std::string *const seq_string) {
string value_name = "value" + std::to_string(GetTupleIndex());
if (seq_string != nullptr) {
*seq_string += value_name + ",";
}
SetScalarToAttributeProto_irs(value, attr_proto);
return SetScalarToAttributeProto_irs(value, attr_proto);
}
void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value,
bool IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value,
mind_ir::AttributeProto *const attr_proto,
std::string *const seq_string) {
if (value == nullptr || attr_proto == nullptr) {
@@ -799,13 +928,19 @@ void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value,
if (tuple_value->value().size() == 0) {
*seq_string += "],";
MS_LOG(DEBUG) << "SetSequenceToAttributeProto tuple size is 0";
return;
return true;
}
for (const auto &item : tuple_value->value()) {
if (item->isa<ValueTuple>()) {
SetSequenceToAttributeProto(item->cast<ValueTuplePtr>(), attr_proto, seq_string);
if (!SetSequenceToAttributeProto(item->cast<ValueTuplePtr>(), attr_proto, seq_string)) {
MS_LOG(ERROR) << "Set sequence to AttributeProto failed.";
return false;
}
} else {
SetSeqElemToAttributeProto(item, attr_proto, seq_string);
if (!SetSeqElemToAttributeProto(item, attr_proto, seq_string)) {
MS_LOG(ERROR) << "Set seq elem to AttributeProto failed.";
return false;
}
}
}
*seq_string += "],";
@@ -815,18 +950,25 @@ void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value,
if (list_value->value().size() == 0) {
*seq_string += "],";
MS_LOG(DEBUG) << "SetSequenceToAttributeProto list size is 0.";
return;
return true;
}
for (const auto &item : list_value->value()) {
MS_EXCEPTION_IF_NULL(item);
if (item->isa<ValueList>()) {
SetSequenceToAttributeProto(item->cast<ValueListPtr>(), attr_proto, seq_string);
if (!SetSequenceToAttributeProto(item->cast<ValueListPtr>(), attr_proto, seq_string)) {
MS_LOG(ERROR) << "Set sequence to AttributeProto failed.";
return false;
}
} else {
SetSeqElemToAttributeProto(item, attr_proto, seq_string);
if (!SetSeqElemToAttributeProto(item, attr_proto, seq_string)) {
MS_LOG(ERROR) << "Set seq elem to AttributeProto failed.";
return false;
}
}
}
*seq_string += "],";
}
return true;
}
std::string GetBinaryProtoString(const FuncGraphPtr &func_graph) {
@@ -842,7 +984,7 @@ std::string GetBinaryProtoString(const FuncGraphPtr &func_graph) {
return exporter->GetDumpString(func_graph);
}
mind_ir::ModelProto GetBinaryProto(const FuncGraphPtr &func_graph, bool save_tensor_data) {
ModelProtoPtr GetBinaryProto(const FuncGraphPtr &func_graph, bool save_tensor_data) {
auto exporter = std::make_shared<IrExporter>(std::make_shared<IrExportBuilder>());
auto result = exporter->GetDumpProto(func_graph, save_tensor_data);
return result;

View File

@@ -232,11 +232,13 @@ tensor::TensorPtr MSANFModelParser::BuildTensorInfoForFuncGraph(const mind_ir::T
if (!tensor_proto.has_data_type()) {
MS_LOG(ERROR) << "mind_ir build tensor: " << tensor_proto.name() << " failed";
MS_LOG(EXCEPTION) << "mind_ir TensorProto has no data_type.";
MS_LOG(ERROR) << "mind_ir TensorProto has no data_type.";
return nullptr;
}
if (kDefaultValueSwitchMap.find(tensor_proto.data_type()) == kDefaultValueSwitchMap.end()) {
MS_LOG(ERROR) << "mind_ir build tensor: " << tensor_proto.name() << " failed";
MS_LOG(EXCEPTION) << "mind_ir TensorProto data_type: " << tensor_proto.data_type() << " is not support yet!";
MS_LOG(ERROR) << "mind_ir TensorProto data_type: " << tensor_proto.data_type() << " is not support yet!";
return nullptr;
}
tensor::TensorPtr tensor_info =
@@ -258,7 +260,9 @@ bool MSANFModelParser::BuildParameterForFuncGraph(const ParameterPtr &node,
node->set_name(debug_info_name);
tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(parameter_proto);
MS_EXCEPTION_IF_NULL(tensor_info);
if (tensor_info == nullptr) {
return false;
}
MS_LOG(DEBUG) << "Load parameter name: " << debug_info_name;
if (!IsIncLoad() || load_tensor_map_.find(debug_info_name) == load_tensor_map_.end()) {
load_tensor_map_[debug_info_name] = tensor_info;
@@ -311,7 +315,9 @@ bool MSANFModelParser::BuildInputForFuncGraph(const ParameterPtr &node, const mi
if (value_proto.tensor_size() > 0) {
const mind_ir::TensorProto &tensor_proto = value_proto.tensor(0);
tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(tensor_proto);
MS_EXCEPTION_IF_NULL(tensor_info);
if (tensor_info == nullptr) {
return false;
}
auto tensor_abstract = tensor_info->ToAbstract();
node->set_abstract(tensor_abstract);
} else if (value_proto.has_denotation()) {
@@ -797,7 +803,8 @@ AnfNodePtr MSANFModelParser::BuildOperatorNode(const mind_ir::NodeProto &node_pr
if (node_type.size() > kOpTypeFlagSize && node_type.substr(0, kOpTypeFlagSize) == kOperatorTypeFlag) {
auto anfNode = GetAnfNode(node_type.substr(kOpTypeFlagSize));
if (anfNode == nullptr) {
MS_LOG(EXCEPTION) << "Can't find the ref:" << node_type;
MS_LOG(ERROR) << "Can't find the ref:" << node_type;
return nullptr;
}
return anfNode;
}
@@ -828,7 +835,8 @@ AnfNodePtr MSANFModelParser::BuildOperatorNode(const mind_ir::NodeProto &node_pr
continue;
}
if (!GetAttrValueForCNode(prim, attr_proto)) {
MS_LOG(EXCEPTION) << "Parser prim: " << node_type << " attributes error : " << attr_proto.DebugString();
MS_LOG(ERROR) << "Parser prim: " << node_type << " attributes error : " << attr_proto.DebugString();
return nullptr;
}
}
prim->set_attr("is_load", MakeValue(true));
@@ -919,7 +927,12 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc
MS_LOG(DEBUG) << "Process CNode: " << node_name;
// Build inputs.
std::vector<AnfNodePtr> inputs;
inputs.push_back(BuildOperatorNode(node_proto));
auto operator_node = BuildOperatorNode(node_proto);
if (operator_node == nullptr) {
MS_LOG(ERROR) << "Build operator node " << node_name << " failed!";
return nullptr;
}
inputs.push_back(operator_node);
for (int i = 0; i < node_proto.input_size(); ++i) {
auto anfNode = GetAnfNode(node_proto.input(i));
if (anfNode == nullptr) {
@@ -949,7 +962,7 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc
bool MSANFModelParser::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph,
const mind_ir::GraphProto &importProto) {
MS_EXCEPTION_IF_NULL(outputFuncGraph);
if (importProto.output_size() < 0 || importProto.output_size() > INT_MAX) {
if (importProto.output_size() <= 0 || importProto.output_size() > INT_MAX) {
MS_LOG(ERROR) << "importProto.output_size is : " << importProto.output_size();
return false;
}
@@ -1089,10 +1102,12 @@ FuncGraphPtr MSANFModelParser::Parse(const mind_ir::ModelProto &model_proto) {
FuncGraphPtr graph = std::make_shared<FuncGraph>();
const auto &graph_proto = model_proto.functions(i);
if (!graph_proto.has_name()) {
MS_LOG(EXCEPTION) << "The function has not a name. Please export mindIR again. ";
MS_LOG(ERROR) << "The function has not a name. Please export mindIR again. ";
return nullptr;
}
if (anfnode_build_map_.count(graph_proto.name()) > 0) {
MS_LOG(EXCEPTION) << "There is a duplication function graph name: " << graph_proto.name();
MS_LOG(ERROR) << "There is a duplication function graph name: " << graph_proto.name();
return nullptr;
}
anfnode_build_map_[graph_proto.name()] = std::make_shared<ValueNode>(graph);
}
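The load side mirrors the export side: MSANFModelParser::Parse now returns nullptr for malformed models, and GetCachedFuncGraph (earlier in this diff) logs an error and falls through to a normal compile. A load-side sketch with hypothetical names:

#include <fstream>
#include <iostream>
#include <memory>
#include <string>

struct FuncGraph {};
using FuncGraphPtr = std::shared_ptr<FuncGraph>;

// Hypothetical loader: nullptr means the cached file was missing or unreadable.
FuncGraphPtr LoadMindIRSketch(const std::string &path) {
  std::ifstream fin(path);
  if (!fin.is_open() || fin.peek() == std::ifstream::traits_type::eof()) {
    return nullptr;  // e.g. an empty or truncated cache file
  }
  return std::make_shared<FuncGraph>();  // real code would deserialize here
}

void GetCachedFuncGraphSketch(const std::string &path) {
  FuncGraphPtr fg = LoadMindIRSketch(path);
  if (fg == nullptr) {
    std::cerr << "Failed to load the compilation cache file: " << path << "\n";
    return;  // fall back to a full compile instead of aborting
  }
  // ... hand the cached graph to the backend ...
}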

View File

@@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
import os
import pathlib
import numpy as np
import pytest
@@ -93,3 +94,27 @@ def test_lenet():
    net1 = LeNet()
    train(net1, data1, label1)
    context.set_context(save_compile_cache=False, load_compile_cache=False)


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_lenet_load_failed():
    path = "compile_cache.mindir"
    if os.path.exists(path):
        os.remove(path)
    assert not os.path.exists(path)
    context.set_context(save_compile_cache=True, load_compile_cache=True)
    data = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
    label = Tensor(np.ones([32]).astype(np.int32))
    net = LeNet()
    train(net, data, label)
    assert os.path.exists(path)
    os.remove(path)
    pathlib.Path(path).touch()
    data1 = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
    label1 = Tensor(np.ones([32]).astype(np.int32))
    net1 = LeNet()
    train(net1, data1, label1)
    context.set_context(save_compile_cache=False, load_compile_cache=False)

View File

@@ -26,8 +26,8 @@ std::string GetOnnxProtoString(const FuncGraphPtr &func_graph) { return ""; }
std::string GetBinaryProtoString(const FuncGraphPtr &func_graph) { return ""; }
mind_ir::ModelProto GetBinaryProto(const FuncGraphPtr &func_graph, bool save_tensor_data) {
mind_ir::ModelProto empty_model;
ModelProtoPtr GetBinaryProto(const FuncGraphPtr &func_graph, bool save_tensor_data) {
ModelProtoPtr empty_model;
return empty_model;
}
} // namespace mindspore
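In this stub (used by builds that compile out MindIR export), ModelProtoPtr empty_model; is a default-constructed std::shared_ptr, i.e. nullptr, so the stub reports failure through the same null-pointer convention the callers above already check. A quick demonstration with a stand-in type:

#include <cassert>
#include <memory>

struct ModelProto {};
using ModelProtoPtr = std::shared_ptr<ModelProto>;

int main() {
  ModelProtoPtr empty_model;       // no allocation; holds nullptr
  assert(empty_model == nullptr);  // callers treat this as "export failed"
  return 0;
}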