|
|
|
@ -27,6 +27,7 @@
|
|
|
|
|
#include "base/core_ops.h"
|
|
|
|
|
#include "proto/mind_ir.pb.h"
|
|
|
|
|
#include "utils/check_convert_utils.h"
|
|
|
|
|
#include "debug/dump_proto.h"
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
using FloatPtr = std::shared_ptr<Float>;
|
|
|
|
@ -78,7 +79,7 @@ class IrExporter {
|
|
|
|
|
explicit IrExporter(IrExportBuilderPtr builder) : builder_(builder) {}
|
|
|
|
|
virtual ~IrExporter() = default;
|
|
|
|
|
std::string GetDumpString(const FuncGraphPtr &func_graph);
|
|
|
|
|
mind_ir::ModelProto GetDumpProto(const FuncGraphPtr &func_graph, bool save_tensor_data = false);
|
|
|
|
|
ModelProtoPtr GetDumpProto(const FuncGraphPtr &func_graph, bool save_tensor_data = false);
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
IrExportBuilderPtr builder_;
|
|
|
|
@ -86,39 +87,38 @@ class IrExporter {
|
|
|
|
|
|
|
|
|
|
class IrExportBuilder {
|
|
|
|
|
public:
|
|
|
|
|
IrExportBuilder() = default;
|
|
|
|
|
IrExportBuilder() : model_(std::make_shared<mind_ir::ModelProto>()) {}
|
|
|
|
|
~IrExportBuilder() { google::protobuf::ShutdownProtobufLibrary(); }
|
|
|
|
|
std::string GetProtoString(const FuncGraphPtr &func_graph);
|
|
|
|
|
void BuildModelInfo();
|
|
|
|
|
void BuildModel(const FuncGraphPtr &func_graph, bool save_tensor_data = false);
|
|
|
|
|
mind_ir::ModelProto Model() { return model_; }
|
|
|
|
|
bool BuildModel(const FuncGraphPtr &func_graph, bool save_tensor_data = false);
|
|
|
|
|
ModelProtoPtr Model() { return model_; }
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
void BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
|
|
|
|
|
bool BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
|
|
|
|
|
bool save_tensor_data = false);
|
|
|
|
|
void BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
|
|
|
|
|
bool BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
|
|
|
|
|
bool save_tensor_data = false);
|
|
|
|
|
void BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto);
|
|
|
|
|
void BuildOutput(const CNodePtr &node, mind_ir::GraphProto *const graph_proto);
|
|
|
|
|
void BuildCNode(const CNodePtr &node, mind_ir::GraphProto *const graph_proto);
|
|
|
|
|
bool BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto);
|
|
|
|
|
bool BuildOutput(const CNodePtr &node, mind_ir::GraphProto *const graph_proto);
|
|
|
|
|
bool BuildCNode(const CNodePtr &node, mind_ir::GraphProto *const graph_proto);
|
|
|
|
|
std::string BuildInputNode(const AnfNodePtr &node, mind_ir::GraphProto *const graph_proto);
|
|
|
|
|
|
|
|
|
|
void SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueInfoProto *const value_proto);
|
|
|
|
|
void SetValueInfoProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::ValueInfoProto *const value_proto);
|
|
|
|
|
void SetParamToTensorProto(const ParameterPtr ¶m, mind_ir::TensorProto *const tensor_proto);
|
|
|
|
|
void SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::TensorProto *const tensor_proto);
|
|
|
|
|
void SetAttributeProto(const AnfNodePtr &node, mind_ir::NodeProto *const node_proto);
|
|
|
|
|
void SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodeProto *const node_proto);
|
|
|
|
|
void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::AttributeProto *const attr_proto,
|
|
|
|
|
bool SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueInfoProto *const value_proto);
|
|
|
|
|
bool SetParamToTensorProto(const ParameterPtr ¶m, mind_ir::TensorProto *const tensor_proto);
|
|
|
|
|
bool SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::TensorProto *const tensor_proto);
|
|
|
|
|
bool SetAttributeProto(const AnfNodePtr &node, mind_ir::NodeProto *const node_proto);
|
|
|
|
|
bool SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodeProto *const node_proto);
|
|
|
|
|
bool SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::AttributeProto *const attr_proto,
|
|
|
|
|
std::string *const seq_string);
|
|
|
|
|
void SetValueToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
|
|
|
|
|
void SetTypeToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
|
|
|
|
|
void SetScalarToAttributeProto_ir(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
|
|
|
|
|
void SetScalarToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
|
|
|
|
|
void SetTensorToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
|
|
|
|
|
void SetSequenceToAttributeProto(const ValueSequeuePtr &value, mind_ir::AttributeProto *const attr_proto,
|
|
|
|
|
bool SetValueToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
|
|
|
|
|
bool SetTypeToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
|
|
|
|
|
bool SetScalarToAttributeProto_ir(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
|
|
|
|
|
bool SetScalarToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
|
|
|
|
|
bool SetTensorToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
|
|
|
|
|
bool SetSequenceToAttributeProto(const ValueSequeuePtr &value, mind_ir::AttributeProto *const attr_proto,
|
|
|
|
|
std::string *const seq_string);
|
|
|
|
|
void SetSeqElemToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto,
|
|
|
|
|
bool SetSeqElemToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto,
|
|
|
|
|
std::string *const seq_string);
|
|
|
|
|
|
|
|
|
|
mind_ir::TensorProto_DataType GetMindirDataType(TypeId type_id);
|
|
|
|
@ -134,7 +134,7 @@ class IrExportBuilder {
|
|
|
|
|
void ResetTupleIndex() { shape_index_ = 0; }
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
mind_ir::ModelProto model_;
|
|
|
|
|
ModelProtoPtr model_;
|
|
|
|
|
mind_ir::NodeProto *last_node_{nullptr};
|
|
|
|
|
std::list<FuncGraphPtr> todo_;
|
|
|
|
|
std::map<AnfNodePtr, std::string> node_index_map_;
|
|
|
|
@ -147,11 +147,14 @@ class IrExportBuilder {
|
|
|
|
|
using IrExporterPtr = std::shared_ptr<IrExporter>;
|
|
|
|
|
|
|
|
|
|
std::string IrExporter::GetDumpString(const FuncGraphPtr &func_graph) {
|
|
|
|
|
(void)GetDumpProto(func_graph);
|
|
|
|
|
auto dump_proto = GetDumpProto(func_graph);
|
|
|
|
|
if (dump_proto == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Get dump proto for graph " << func_graph->ToString() << " failed.";
|
|
|
|
|
}
|
|
|
|
|
return builder_->GetProtoString(func_graph);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
mind_ir::ModelProto IrExporter::GetDumpProto(const FuncGraphPtr &func_graph, bool save_tensor_data) {
|
|
|
|
|
ModelProtoPtr IrExporter::GetDumpProto(const FuncGraphPtr &func_graph, bool save_tensor_data) {
|
|
|
|
|
if ((builder_ == nullptr) || (func_graph == nullptr)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Input params is null.";
|
|
|
|
|
}
|
|
|
|
@ -160,26 +163,28 @@ mind_ir::ModelProto IrExporter::GetDumpProto(const FuncGraphPtr &func_graph, boo
|
|
|
|
|
builder_->BuildModelInfo();
|
|
|
|
|
|
|
|
|
|
// Export model and return string
|
|
|
|
|
builder_->BuildModel(func_graph, save_tensor_data);
|
|
|
|
|
if (!builder_->BuildModel(func_graph, save_tensor_data)) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
return builder_->Model();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string IrExportBuilder::GetProtoString(const FuncGraphPtr &func_graph) {
|
|
|
|
|
MS_LOG(DEBUG) << "BuildModel complete!";
|
|
|
|
|
return model_.SerializeAsString();
|
|
|
|
|
return model_->SerializeAsString();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void IrExportBuilder::BuildModelInfo() {
|
|
|
|
|
constexpr auto ir_version = "0.1.0";
|
|
|
|
|
constexpr auto mindspore_name = "MindSpore";
|
|
|
|
|
model_.set_ir_version(ir_version);
|
|
|
|
|
model_.set_producer_name(mindspore_name);
|
|
|
|
|
model_.set_model_version(VERSION);
|
|
|
|
|
model_->set_ir_version(ir_version);
|
|
|
|
|
model_->set_producer_name(mindspore_name);
|
|
|
|
|
model_->set_model_version(VERSION);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tensor_data) {
|
|
|
|
|
bool IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tensor_data) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(func_graph);
|
|
|
|
|
mind_ir::GraphProto *graph_proto = model_.mutable_graph();
|
|
|
|
|
mind_ir::GraphProto *graph_proto = model_->mutable_graph();
|
|
|
|
|
graph_proto->set_name(func_graph->ToString());
|
|
|
|
|
graph_proto->set_bprop_hash(func_graph->bprop_hash());
|
|
|
|
|
ResetNodeIndex();
|
|
|
|
@ -188,7 +193,11 @@ void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tenso
|
|
|
|
|
// Build the main funcGraph
|
|
|
|
|
nodeName_.insert(func_graph->ToString());
|
|
|
|
|
top_graph = true;
|
|
|
|
|
BuildFuncGraph(func_graph, graph_proto, save_tensor_data);
|
|
|
|
|
if (!BuildFuncGraph(func_graph, graph_proto, save_tensor_data)) {
|
|
|
|
|
MS_LOG(ERROR) << "Build func_graph " << func_graph->ToString() << " failed.";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::set<FuncGraphPtr> graphVisited;
|
|
|
|
|
graphVisited.insert(func_graph);
|
|
|
|
|
top_graph = false;
|
|
|
|
@ -199,32 +208,40 @@ void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tenso
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (nodeName_.count(fg->ToString()) > 0) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "There is a duplicate name: " << fg->ToString();
|
|
|
|
|
MS_LOG(ERROR) << "There is a duplicate name: " << fg->ToString();
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
nodeName_.insert(fg->ToString());
|
|
|
|
|
graphVisited.insert(fg);
|
|
|
|
|
auto graph = model_.add_functions();
|
|
|
|
|
BuildFuncGraph(fg, graph, save_tensor_data);
|
|
|
|
|
auto graph = model_->add_functions();
|
|
|
|
|
if (!BuildFuncGraph(fg, graph, save_tensor_data)) {
|
|
|
|
|
MS_LOG(ERROR) << "Build func_graph " << fg->ToString() << " failed.";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// Release resource
|
|
|
|
|
nodeName_.clear();
|
|
|
|
|
node_index_map_.clear();
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
|
|
|
|
|
bool IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
|
|
|
|
|
bool save_tensor_data) {
|
|
|
|
|
// Export funcGraph name.
|
|
|
|
|
graph_proto->set_name(func_graph->ToString());
|
|
|
|
|
// Export parameters
|
|
|
|
|
// 1. parameters should be mapped to ValueInfoProto
|
|
|
|
|
// 2. parameters with default value should be mapped to Initializer
|
|
|
|
|
BuildParameters(func_graph, graph_proto, save_tensor_data);
|
|
|
|
|
if (!BuildParameters(func_graph, graph_proto, save_tensor_data)) {
|
|
|
|
|
MS_LOG(ERROR) << "Build parameters failed.";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Export operator nodes(include output)
|
|
|
|
|
BuildNodes(func_graph, graph_proto);
|
|
|
|
|
return BuildNodes(func_graph, graph_proto);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
|
|
|
|
|
bool IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
|
|
|
|
|
bool save_tensor_data) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(func_graph);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph_proto);
|
|
|
|
@ -232,14 +249,18 @@ void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::G
|
|
|
|
|
MS_EXCEPTION_IF_NULL(item);
|
|
|
|
|
auto param = item->cast<ParameterPtr>();
|
|
|
|
|
if (param == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Parameter: '" << item->ToString() << "' could not cast to parameter.";
|
|
|
|
|
MS_LOG(ERROR) << "Parameter: '" << item->ToString() << "' could not cast to parameter.";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
std::string param_name = GetUniqueNodeName(param);
|
|
|
|
|
if (top_graph && param->has_default()) {
|
|
|
|
|
MS_LOG(DEBUG) << "Parameter: '" << item->DebugString() << "' has default. address: " << (size_t)param.get();
|
|
|
|
|
mind_ir::TensorProto *parameter_proto = graph_proto->add_parameter();
|
|
|
|
|
parameter_proto->set_name(param_name);
|
|
|
|
|
SetParamToTensorProto(param, parameter_proto);
|
|
|
|
|
if (!SetParamToTensorProto(param, parameter_proto)) {
|
|
|
|
|
MS_LOG(ERROR) << "Set parameter " << param->DebugString() << " to TensorProto failed.";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param->default_param());
|
|
|
|
|
if (tensor && save_tensor_data) {
|
|
|
|
|
parameter_proto->set_raw_data(tensor->data_c(), tensor->data().nbytes());
|
|
|
|
@ -247,19 +268,25 @@ void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::G
|
|
|
|
|
} else {
|
|
|
|
|
mind_ir::ValueInfoProto *input_proto = graph_proto->add_input();
|
|
|
|
|
input_proto->set_name(param_name);
|
|
|
|
|
SetValueInfoProto(param, input_proto);
|
|
|
|
|
if (!SetValueInfoProto(param, input_proto)) {
|
|
|
|
|
MS_LOG(ERROR) << "Set parameter " << param->DebugString() << " to TensorProto failed.";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (nodeName_.count(param_name) > 0) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "parameter name is duplicate:" << param_name;
|
|
|
|
|
MS_LOG(ERROR) << "parameter name is duplicate:" << param_name;
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
nodeName_.insert(param_name);
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataType(TypeId type_id) {
|
|
|
|
|
auto iter = g_data_type_map.find(type_id);
|
|
|
|
|
if (iter == g_data_type_map.end()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Convert type error, unsupported type! " << type_id;
|
|
|
|
|
MS_LOG(ERROR) << "Convert type error, unsupported type! " << type_id;
|
|
|
|
|
return mind_ir::TensorProto_DataType_UNDEFINED;
|
|
|
|
|
}
|
|
|
|
|
return iter->second;
|
|
|
|
|
}
|
|
|
|
@ -267,7 +294,8 @@ mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataType(TypeId type_id)
|
|
|
|
|
mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsIntType(int bits) {
|
|
|
|
|
auto iter = g_data_bits_int_map.find(bits);
|
|
|
|
|
if (iter == g_data_bits_int_map.end()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Convert bits int error, unsupported bits! " << bits;
|
|
|
|
|
MS_LOG(ERROR) << "Convert bits int error, unsupported bits! " << bits;
|
|
|
|
|
return mind_ir::TensorProto_DataType_UNDEFINED;
|
|
|
|
|
}
|
|
|
|
|
return iter->second;
|
|
|
|
|
}
|
|
|
|
@ -275,7 +303,8 @@ mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsIntType(int bits
|
|
|
|
|
mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsUIntType(int bits) {
|
|
|
|
|
auto iter = g_data_bits_uint_map.find(bits);
|
|
|
|
|
if (iter == g_data_bits_uint_map.end()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Convert bits uint error, unsupported bits! " << bits;
|
|
|
|
|
MS_LOG(ERROR) << "Convert bits uint error, unsupported bits! " << bits;
|
|
|
|
|
return mind_ir::TensorProto_DataType_UNDEFINED;
|
|
|
|
|
}
|
|
|
|
|
return iter->second;
|
|
|
|
|
}
|
|
|
|
@ -283,20 +312,22 @@ mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsUIntType(int bit
|
|
|
|
|
mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsFloatType(int bits) {
|
|
|
|
|
auto iter = g_data_bits_float_map.find(bits);
|
|
|
|
|
if (iter == g_data_bits_float_map.end()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Convert bits float error, unsupported bits! " << bits;
|
|
|
|
|
MS_LOG(ERROR) << "Convert bits float error, unsupported bits! " << bits;
|
|
|
|
|
return mind_ir::TensorProto_DataType_UNDEFINED;
|
|
|
|
|
}
|
|
|
|
|
return iter->second;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueInfoProto *const value_proto) {
|
|
|
|
|
bool IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueInfoProto *const value_proto) {
|
|
|
|
|
if (node == nullptr || value_proto == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "AnfNode or ValueInfo is null!";
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << "SetValueInfoProto: " << node->DebugString();
|
|
|
|
|
const TypePtr &type = node->Type();
|
|
|
|
|
const BaseShapePtr &shape = node->Shape();
|
|
|
|
|
// For the bprop fg which has not been renormalized.
|
|
|
|
|
if (type == nullptr || shape == nullptr) {
|
|
|
|
|
return;
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
if (type->isa<TensorType>() && shape->isa<abstract::Shape>()) {
|
|
|
|
|
auto tensor = type->cast<TensorTypePtr>();
|
|
|
|
@ -304,7 +335,11 @@ void IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueIn
|
|
|
|
|
auto elem_type = tensor->element();
|
|
|
|
|
const auto &dims = shape->cast<abstract::ShapePtr>()->shape();
|
|
|
|
|
mind_ir::TensorProto *tensor_proto = value_proto->add_tensor();
|
|
|
|
|
tensor_proto->set_data_type(GetMindirDataType(elem_type->type_id()));
|
|
|
|
|
auto data_type = GetMindirDataType(elem_type->type_id());
|
|
|
|
|
if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
tensor_proto->set_data_type(data_type);
|
|
|
|
|
if (dims.size() == 0) {
|
|
|
|
|
MS_LOG(DEBUG) << "The dim of ValueInfoProto is 0.";
|
|
|
|
|
} else {
|
|
|
|
@ -320,9 +355,10 @@ void IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueIn
|
|
|
|
|
value_proto->set_denotation(type->type_name());
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << "Value type: " << type->type_name();
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
|
|
|
|
|
bool IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
|
|
|
|
|
if (value == nullptr || attr_proto == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
|
|
|
|
|
}
|
|
|
|
@ -335,34 +371,45 @@ void IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, mind_ir::
|
|
|
|
|
tensor_proto->set_raw_data(data->data_c(), static_cast<size_t>(data->data().nbytes()));
|
|
|
|
|
auto dtype = data->data_type();
|
|
|
|
|
auto shape = data->shape_c();
|
|
|
|
|
tensor_proto->set_data_type(GetMindirDataType(dtype));
|
|
|
|
|
auto data_type = GetMindirDataType(dtype);
|
|
|
|
|
if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
tensor_proto->set_data_type(data_type);
|
|
|
|
|
for (const auto &dim : shape) {
|
|
|
|
|
tensor_proto->add_dims(dim);
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void IrExportBuilder::SetTensorProto(const TypePtr &type, const BaseShapePtr &shape,
|
|
|
|
|
bool IrExportBuilder::SetTensorProto(const TypePtr &type, const BaseShapePtr &shape,
|
|
|
|
|
mind_ir::TensorProto *const tensor_proto) {
|
|
|
|
|
if (!type->isa<TensorType>() || !shape->isa<abstract::Shape>()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Type or shape is not supported! " << type->ToString();
|
|
|
|
|
MS_LOG(ERROR) << "Type or shape is not supported! " << type->ToString();
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto tensor = type->cast<TensorTypePtr>();
|
|
|
|
|
const auto &dims = shape->cast<abstract::ShapePtr>()->shape();
|
|
|
|
|
tensor_proto->set_data_type(GetMindirDataType(tensor->element()->type_id()));
|
|
|
|
|
auto data_type = GetMindirDataType(tensor->element()->type_id());
|
|
|
|
|
if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
tensor_proto->set_data_type(data_type);
|
|
|
|
|
for (const auto &dim : dims) {
|
|
|
|
|
tensor_proto->add_dims(dim);
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void IrExportBuilder::SetParamToTensorProto(const ParameterPtr ¶m, mind_ir::TensorProto *const tensor_proto) {
|
|
|
|
|
bool IrExportBuilder::SetParamToTensorProto(const ParameterPtr ¶m, mind_ir::TensorProto *const tensor_proto) {
|
|
|
|
|
if (param == nullptr || tensor_proto == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Parameter or TensorProto is null!";
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << "SetParamToTensorProto: " << param->DebugString();
|
|
|
|
|
SetTensorProto(param->Type(), param->Shape(), tensor_proto);
|
|
|
|
|
return SetTensorProto(param->Type(), param->Shape(), tensor_proto);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto) {
|
|
|
|
|
bool IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto) {
|
|
|
|
|
std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
|
|
|
|
|
for (const AnfNodePtr &node : nodes) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
@ -372,24 +419,36 @@ void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphP
|
|
|
|
|
}
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
if (cnode == func_graph->get_return()) {
|
|
|
|
|
BuildOutput(cnode, graph_proto);
|
|
|
|
|
if (!BuildOutput(cnode, graph_proto)) {
|
|
|
|
|
MS_LOG(ERROR) << "Build output for graph " << func_graph->ToString() << " failed.";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
BuildCNode(cnode, graph_proto);
|
|
|
|
|
if (!BuildCNode(cnode, graph_proto)) {
|
|
|
|
|
MS_LOG(ERROR) << "Build proto for cnode " << cnode->DebugString() << " failed.";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void IrExportBuilder::BuildOutput(const CNodePtr &node, mind_ir::GraphProto *const graph_proto) {
|
|
|
|
|
bool IrExportBuilder::BuildOutput(const CNodePtr &node, mind_ir::GraphProto *const graph_proto) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
const int OutputSize = 2;
|
|
|
|
|
if (node->size() != OutputSize) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2.";
|
|
|
|
|
MS_LOG(ERROR) << "Number of inputs of return node is not equal to 2.";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
AnfNodePtr arg = node->input(1);
|
|
|
|
|
std::string node_name = BuildInputNode(arg, graph_proto);
|
|
|
|
|
if (node_name.empty()) {
|
|
|
|
|
MS_LOG(ERROR) << "Build input node failed for arg " << arg->DebugString();
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
mind_ir::ValueInfoProto *output_proto = graph_proto->add_output();
|
|
|
|
|
output_proto->set_name(node_name);
|
|
|
|
|
SetValueInfoProto(arg, output_proto);
|
|
|
|
|
return SetValueInfoProto(arg, output_proto);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) {
|
|
|
|
@ -408,16 +467,18 @@ std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) {
|
|
|
|
|
auto nodeName = GetUniqueNodeName(node);
|
|
|
|
|
type_name = "REF::" + nodeName;
|
|
|
|
|
if (nodeName_.count(nodeName) == 0) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "There is not the name: " << nodeName;
|
|
|
|
|
MS_LOG(ERROR) << "There is not the name: " << nodeName;
|
|
|
|
|
return "";
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Need to support op type: " << node->type_name();
|
|
|
|
|
MS_LOG(ERROR) << "Need to support op type: " << node->type_name();
|
|
|
|
|
return "";
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << "ExportType: " << type_name;
|
|
|
|
|
return type_name;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape,
|
|
|
|
|
bool IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape,
|
|
|
|
|
mind_ir::AttributeProto *const attr_proto, std::string *const seq_string) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(type);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(shape);
|
|
|
|
@ -427,7 +488,9 @@ void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePt
|
|
|
|
|
auto elements = type->cast<TuplePtr>()->elements();
|
|
|
|
|
auto tuple_shape = shape->cast<abstract::TupleShapePtr>()->shape();
|
|
|
|
|
for (size_t i = 0; i < elements.size(); i++) {
|
|
|
|
|
SetShapeToNodeProto(elements[i], tuple_shape[i], attr_proto, seq_string);
|
|
|
|
|
if (!SetShapeToNodeProto(elements[i], tuple_shape[i], attr_proto, seq_string)) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
*seq_string += "],";
|
|
|
|
|
} else if (type->isa<TensorType>() && shape->isa<abstract::Shape>()) {
|
|
|
|
@ -435,7 +498,7 @@ void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePt
|
|
|
|
|
*seq_string += shape_name + ",";
|
|
|
|
|
mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
|
|
|
|
|
tensor_proto->set_name(shape_name);
|
|
|
|
|
SetTensorProto(type, shape, tensor_proto);
|
|
|
|
|
return SetTensorProto(type, shape, tensor_proto);
|
|
|
|
|
} else if (type->isa<Number>()) {
|
|
|
|
|
if (type->isa<Bool>()) {
|
|
|
|
|
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_BOOL);
|
|
|
|
@ -453,11 +516,13 @@ void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePt
|
|
|
|
|
} else if (type->isa<String>() || type->isa<UMonadType>() || type->isa<IOMonadType>()) {
|
|
|
|
|
*seq_string += type->type_name() + ",";
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Type of cnode need to be supported: " << type->type_name();
|
|
|
|
|
MS_LOG(ERROR) << "Type of cnode need to be supported: " << type->type_name();
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodeProto *const node_proto) {
|
|
|
|
|
bool IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodeProto *const node_proto) {
|
|
|
|
|
// Get shape of cnode
|
|
|
|
|
// 1. need to get shape from tuple element
|
|
|
|
|
// 2. save shape in TensorProto
|
|
|
|
@ -465,21 +530,27 @@ void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodePro
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
auto type = node->Type();
|
|
|
|
|
auto shape = node->Shape();
|
|
|
|
|
// For the bprop fg which has not been renormalized.
|
|
|
|
|
if (type == nullptr || shape == nullptr) {
|
|
|
|
|
return;
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
ResetTupleIndex();
|
|
|
|
|
std::string seq_string = "shape:";
|
|
|
|
|
mind_ir::AttributeProto *attr_proto = node_proto->add_attribute();
|
|
|
|
|
SetShapeToNodeProto(type, shape, attr_proto, &seq_string);
|
|
|
|
|
if (!SetShapeToNodeProto(type, shape, attr_proto, &seq_string)) {
|
|
|
|
|
MS_LOG(ERROR) << "Set shape to NodeProto for " << node->DebugString() << " failed.";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
attr_proto->set_ref_attr_name(seq_string);
|
|
|
|
|
MS_LOG(DEBUG) << "CNode shape: " << seq_string;
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *const graph_proto) {
|
|
|
|
|
bool IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *const graph_proto) {
|
|
|
|
|
auto inputs_size = node->size();
|
|
|
|
|
if (inputs_size < 1) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
|
|
|
|
|
MS_LOG(ERROR) << "Inputs of node " << node->DebugString() << " is empty";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Need to build input node before dealing with cnode
|
|
|
|
@ -488,7 +559,12 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *cons
|
|
|
|
|
for (size_t i = 1; i < inputs_size; i++) {
|
|
|
|
|
auto input = node->input(i);
|
|
|
|
|
op_inputs.push_back(input);
|
|
|
|
|
input_names.push_back(BuildInputNode(input, graph_proto));
|
|
|
|
|
std::string node_name = BuildInputNode(input, graph_proto);
|
|
|
|
|
if (node_name.empty()) {
|
|
|
|
|
MS_LOG(ERROR) << "Build input node for " << input->DebugString() << " failed.";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
input_names.push_back(node_name);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Build cnode
|
|
|
|
@ -503,10 +579,16 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *cons
|
|
|
|
|
node_proto->set_domain(node->fullname_with_scope());
|
|
|
|
|
AnfNodePtr op = node->input(0);
|
|
|
|
|
std::string type_name = GetOpTypeName(op);
|
|
|
|
|
if (type_name.empty()) {
|
|
|
|
|
MS_LOG(ERROR) << "Get op type name for " << op->DebugString() << " failed.";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
node_proto->set_op_type(type_name);
|
|
|
|
|
last_node_ = node_proto;
|
|
|
|
|
// Maybe Tensor or Function or nullptr
|
|
|
|
|
SetShapeToNodeProto(node, node_proto);
|
|
|
|
|
if (!SetShapeToNodeProto(node, node_proto)) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
(void)std::for_each(input_names.begin(), input_names.end(),
|
|
|
|
|
[&node_proto](const string &name) { node_proto->add_input(name); });
|
|
|
|
@ -520,9 +602,13 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *cons
|
|
|
|
|
attr_proto->set_name(attr.first);
|
|
|
|
|
auto attr_value = attr.second;
|
|
|
|
|
CheckAndConvertUtils::ConvertAttrValueInExport(type_name, attr.first, &attr_value);
|
|
|
|
|
SetValueToAttributeProto(attr_value, attr_proto);
|
|
|
|
|
if (!SetValueToAttributeProto(attr_value, attr_proto)) {
|
|
|
|
|
MS_LOG(ERROR) << "Set value to AttributeProto failed.";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string IrExportBuilder::BuildInputNode(const AnfNodePtr &node, mind_ir::GraphProto *const graph_proto) {
|
|
|
|
@ -538,7 +624,9 @@ std::string IrExportBuilder::BuildInputNode(const AnfNodePtr &node, mind_ir::Gra
|
|
|
|
|
mind_ir::NodeProto *node_proto = graph_proto->add_node();
|
|
|
|
|
node_proto->set_name(node_name);
|
|
|
|
|
node_proto->add_output(node_name);
|
|
|
|
|
SetAttributeProto(node, node_proto);
|
|
|
|
|
if (!SetAttributeProto(node, node_proto)) {
|
|
|
|
|
return "";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return node_name;
|
|
|
|
|
}
|
|
|
|
@ -581,7 +669,7 @@ std::string IrExportBuilder::GetNodeName(const AnfNodePtr &node) {
|
|
|
|
|
return node_name;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void IrExportBuilder::SetAttributeProto(const AnfNodePtr &node, mind_ir::NodeProto *const node_proto) {
|
|
|
|
|
bool IrExportBuilder::SetAttributeProto(const AnfNodePtr &node, mind_ir::NodeProto *const node_proto) {
|
|
|
|
|
if (node == nullptr || node_proto == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "AnfNode or NodeProto is null!";
|
|
|
|
|
}
|
|
|
|
@ -592,10 +680,10 @@ void IrExportBuilder::SetAttributeProto(const AnfNodePtr &node, mind_ir::NodePro
|
|
|
|
|
mind_ir::AttributeProto *attr_proto = node_proto->add_attribute();
|
|
|
|
|
attr_proto->set_name("value");
|
|
|
|
|
MS_LOG(DEBUG) << "Set Constant attribute: " << value->ToString();
|
|
|
|
|
SetValueToAttributeProto(value, attr_proto);
|
|
|
|
|
return SetValueToAttributeProto(value, attr_proto);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
|
|
|
|
|
bool IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
|
|
|
|
|
if (value == nullptr || attr_proto == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
|
|
|
|
|
}
|
|
|
|
@ -605,17 +693,29 @@ void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, mind_ir::At
|
|
|
|
|
attr_proto->set_ref_attr_name("type:value0");
|
|
|
|
|
tensor_proto->set_name("value0");
|
|
|
|
|
auto int_value = value->cast<IntPtr>();
|
|
|
|
|
tensor_proto->set_data_type(GetMindirDataBitsIntType(int_value->nbits()));
|
|
|
|
|
auto data_type = GetMindirDataBitsIntType(int_value->nbits());
|
|
|
|
|
if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
tensor_proto->set_data_type(data_type);
|
|
|
|
|
} else if (value->isa<UInt>()) {
|
|
|
|
|
attr_proto->set_ref_attr_name("type:value0");
|
|
|
|
|
tensor_proto->set_name("value0");
|
|
|
|
|
auto float_value = value->cast<UIntPtr>();
|
|
|
|
|
tensor_proto->set_data_type(GetMindirDataBitsUIntType(float_value->nbits()));
|
|
|
|
|
auto data_type = GetMindirDataBitsUIntType(float_value->nbits());
|
|
|
|
|
if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
tensor_proto->set_data_type(data_type);
|
|
|
|
|
} else if (value->isa<Float>()) {
|
|
|
|
|
attr_proto->set_ref_attr_name("type:value0");
|
|
|
|
|
tensor_proto->set_name("value0");
|
|
|
|
|
auto float_value = value->cast<FloatPtr>();
|
|
|
|
|
tensor_proto->set_data_type(GetMindirDataBitsFloatType(float_value->nbits()));
|
|
|
|
|
auto data_type = GetMindirDataBitsFloatType(float_value->nbits());
|
|
|
|
|
if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
tensor_proto->set_data_type(data_type);
|
|
|
|
|
} else if (value->isa<Bool>()) {
|
|
|
|
|
attr_proto->set_ref_attr_name("type:value0");
|
|
|
|
|
tensor_proto->set_name("value0");
|
|
|
|
@ -626,35 +726,49 @@ void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, mind_ir::At
|
|
|
|
|
auto elem_type = value->cast<TensorTypePtr>()->element();
|
|
|
|
|
if (elem_type->isa<Int>()) {
|
|
|
|
|
auto int_value = elem_type->cast<IntPtr>();
|
|
|
|
|
tensor_proto->set_data_type(GetMindirDataBitsIntType(int_value->nbits()));
|
|
|
|
|
auto data_type = GetMindirDataBitsIntType(int_value->nbits());
|
|
|
|
|
if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
tensor_proto->set_data_type(data_type);
|
|
|
|
|
} else if (elem_type->isa<Float>()) {
|
|
|
|
|
auto float_value = elem_type->cast<FloatPtr>();
|
|
|
|
|
tensor_proto->set_data_type(GetMindirDataBitsFloatType(float_value->nbits()));
|
|
|
|
|
auto data_type = GetMindirDataBitsFloatType(float_value->nbits());
|
|
|
|
|
if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
tensor_proto->set_data_type(data_type);
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Unsupported type " << elem_type->type_name();
|
|
|
|
|
MS_LOG(ERROR) << "Unsupported type " << elem_type->type_name();
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Unsupported type: " << value->type_name();
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
|
|
|
|
|
bool IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
|
|
|
|
|
if (value == nullptr || attr_proto == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
|
|
|
|
|
}
|
|
|
|
|
if (value->isa<StringImm>() || value->isa<Scalar>()) {
|
|
|
|
|
SetScalarToAttributeProto_ir(value, attr_proto);
|
|
|
|
|
return SetScalarToAttributeProto_ir(value, attr_proto);
|
|
|
|
|
} else if (value->isa<Number>() || value->isa<TensorType>()) {
|
|
|
|
|
SetTypeToAttributeProto(value, attr_proto);
|
|
|
|
|
return SetTypeToAttributeProto(value, attr_proto);
|
|
|
|
|
} else if (value->isa<ValueSequeue>()) {
|
|
|
|
|
ResetTupleIndex();
|
|
|
|
|
std::string seq_string = "scalar:";
|
|
|
|
|
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
|
|
|
|
|
SetSequenceToAttributeProto(value->cast<ValueSequeuePtr>(), attr_proto, &seq_string);
|
|
|
|
|
if (!SetSequenceToAttributeProto(value->cast<ValueSequeuePtr>(), attr_proto, &seq_string)) {
|
|
|
|
|
MS_LOG(ERROR) << "Set sequence to AttributeProto failed.";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
attr_proto->set_ref_attr_name(seq_string);
|
|
|
|
|
MS_LOG(DEBUG) << "Attr string: " << seq_string;
|
|
|
|
|
} else if (value->isa<tensor::Tensor>()) {
|
|
|
|
|
SetTensorToAttributeProto(value, attr_proto);
|
|
|
|
|
return SetTensorToAttributeProto(value, attr_proto);
|
|
|
|
|
} else if (value->isa<None>()) {
|
|
|
|
|
attr_proto->set_ref_attr_name("none");
|
|
|
|
|
MS_LOG(DEBUG) << "Attr string: " << value->type_name();
|
|
|
|
@ -664,14 +778,17 @@ void IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, mind_ir::A
|
|
|
|
|
} else if (value->isa<IOMonad>()) {
|
|
|
|
|
attr_proto->set_ref_attr_name("Monad:IOMonad");
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Unsupported Monad type: " << value->type_name();
|
|
|
|
|
MS_LOG(ERROR) << "Unsupported Monad type: " << value->type_name();
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Unsupported type: " << value->type_name();
|
|
|
|
|
MS_LOG(ERROR) << "Unsupported type: " << value->type_name();
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void IrExportBuilder::SetScalarToAttributeProto_ir(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
|
|
|
|
|
bool IrExportBuilder::SetScalarToAttributeProto_ir(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
|
|
|
|
|
if (value == nullptr || attr_proto == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
|
|
|
|
|
}
|
|
|
|
@ -714,13 +831,15 @@ void IrExportBuilder::SetScalarToAttributeProto_ir(const ValuePtr &value, mind_i
|
|
|
|
|
attr_proto->set_d(GetValue<double>(value));
|
|
|
|
|
} else if (value->isa<tensor::Tensor>()) {
|
|
|
|
|
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSOR);
|
|
|
|
|
SetTensorToAttributeProto(value, attr_proto);
|
|
|
|
|
return SetTensorToAttributeProto(value, attr_proto);
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Unsupported scalar type: " << value->type_name();
|
|
|
|
|
MS_LOG(ERROR) << "Unsupported scalar type: " << value->type_name();
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void IrExportBuilder::SetScalarToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
|
|
|
|
|
bool IrExportBuilder::SetScalarToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
|
|
|
|
|
if (value == nullptr || attr_proto == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
|
|
|
|
|
}
|
|
|
|
@ -728,12 +847,20 @@ void IrExportBuilder::SetScalarToAttributeProto_irs(const ValuePtr &value, mind_
|
|
|
|
|
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
|
|
|
|
|
mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
|
|
|
|
|
auto int_value = value->cast<IntPtr>();
|
|
|
|
|
tensor_proto->set_data_type(GetMindirDataBitsIntType(int_value->nbits()));
|
|
|
|
|
auto data_type = GetMindirDataBitsIntType(int_value->nbits());
|
|
|
|
|
if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
tensor_proto->set_data_type(data_type);
|
|
|
|
|
} else if (value->isa<Float>()) {
|
|
|
|
|
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
|
|
|
|
|
mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
|
|
|
|
|
auto float_value = value->cast<FloatPtr>();
|
|
|
|
|
tensor_proto->set_data_type(GetMindirDataBitsFloatType(float_value->nbits()));
|
|
|
|
|
auto data_type = GetMindirDataBitsFloatType(float_value->nbits());
|
|
|
|
|
if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
tensor_proto->set_data_type(data_type);
|
|
|
|
|
} else if (value->isa<StringImm>()) {
|
|
|
|
|
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_STRING);
|
|
|
|
|
attr_proto->add_strings(GetValue<std::string>(value));
|
|
|
|
@ -772,22 +899,24 @@ void IrExportBuilder::SetScalarToAttributeProto_irs(const ValuePtr &value, mind_
|
|
|
|
|
attr_proto->add_doubles(GetValue<double>(value));
|
|
|
|
|
} else if (value->isa<tensor::Tensor>()) {
|
|
|
|
|
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSOR);
|
|
|
|
|
SetTensorToAttributeProto(value, attr_proto);
|
|
|
|
|
return SetTensorToAttributeProto(value, attr_proto);
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Unsupported scalar type: " << value->type_name();
|
|
|
|
|
MS_LOG(ERROR) << "Unsupported scalar type: " << value->type_name();
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void IrExportBuilder::SetSeqElemToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto,
|
|
|
|
|
bool IrExportBuilder::SetSeqElemToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto,
|
|
|
|
|
std::string *const seq_string) {
|
|
|
|
|
string value_name = "value" + std::to_string(GetTupleIndex());
|
|
|
|
|
if (seq_string != nullptr) {
|
|
|
|
|
*seq_string += value_name + ",";
|
|
|
|
|
}
|
|
|
|
|
SetScalarToAttributeProto_irs(value, attr_proto);
|
|
|
|
|
return SetScalarToAttributeProto_irs(value, attr_proto);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value,
|
|
|
|
|
bool IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value,
|
|
|
|
|
mind_ir::AttributeProto *const attr_proto,
|
|
|
|
|
std::string *const seq_string) {
|
|
|
|
|
if (value == nullptr || attr_proto == nullptr) {
|
|
|
|
@ -799,13 +928,19 @@ void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value,
|
|
|
|
|
if (tuple_value->value().size() == 0) {
|
|
|
|
|
*seq_string += "],";
|
|
|
|
|
MS_LOG(DEBUG) << "SetSequenceToAttributeProto tuple size is 0";
|
|
|
|
|
return;
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
for (const auto &item : tuple_value->value()) {
|
|
|
|
|
if (item->isa<ValueTuple>()) {
|
|
|
|
|
SetSequenceToAttributeProto(item->cast<ValueTuplePtr>(), attr_proto, seq_string);
|
|
|
|
|
if (!SetSequenceToAttributeProto(item->cast<ValueTuplePtr>(), attr_proto, seq_string)) {
|
|
|
|
|
MS_LOG(ERROR) << "Set sequence to AttributeProto failed.";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
SetSeqElemToAttributeProto(item, attr_proto, seq_string);
|
|
|
|
|
if (!SetSeqElemToAttributeProto(item, attr_proto, seq_string)) {
|
|
|
|
|
MS_LOG(ERROR) << "Set seq elem to AttributeProto failed.";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
*seq_string += "],";
|
|
|
|
@ -815,18 +950,25 @@ void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value,
|
|
|
|
|
if (list_value->value().size() == 0) {
|
|
|
|
|
*seq_string += "],";
|
|
|
|
|
MS_LOG(DEBUG) << "SetSequenceToAttributeProto list size is 0.";
|
|
|
|
|
return;
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
for (const auto &item : list_value->value()) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(item);
|
|
|
|
|
if (item->isa<ValueList>()) {
|
|
|
|
|
SetSequenceToAttributeProto(item->cast<ValueListPtr>(), attr_proto, seq_string);
|
|
|
|
|
if (!SetSequenceToAttributeProto(item->cast<ValueListPtr>(), attr_proto, seq_string)) {
|
|
|
|
|
MS_LOG(ERROR) << "Set sequence to AttributeProto failed.";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
SetSeqElemToAttributeProto(item, attr_proto, seq_string);
|
|
|
|
|
if (!SetSeqElemToAttributeProto(item, attr_proto, seq_string)) {
|
|
|
|
|
MS_LOG(ERROR) << "Set seq elem to AttributeProto failed.";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
*seq_string += "],";
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string GetBinaryProtoString(const FuncGraphPtr &func_graph) {
|
|
|
|
@ -842,7 +984,7 @@ std::string GetBinaryProtoString(const FuncGraphPtr &func_graph) {
|
|
|
|
|
return exporter->GetDumpString(func_graph);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
mind_ir::ModelProto GetBinaryProto(const FuncGraphPtr &func_graph, bool save_tensor_data) {
|
|
|
|
|
ModelProtoPtr GetBinaryProto(const FuncGraphPtr &func_graph, bool save_tensor_data) {
|
|
|
|
|
auto exporter = std::make_shared<IrExporter>(std::make_shared<IrExportBuilder>());
|
|
|
|
|
auto result = exporter->GetDumpProto(func_graph, save_tensor_data);
|
|
|
|
|
return result;
|
|
|
|
|