parse tuple value and shape in prim attr

This commit is contained in:
leopz 2020-07-21 14:47:33 +08:00
parent 57252dee24
commit ae1d948792
1 changed files with 144 additions and 92 deletions

View File

@ -92,14 +92,17 @@ class IrExportBuilder {
void SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, onnx::TensorProto *const tensor_proto);
void SetAttributeProto(const AnfNodePtr &node, onnx::NodeProto *const node_proto);
void SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto);
void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, onnx::NodeProto *const node_proto,
std::string suffix = "0");
void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, onnx::AttributeProto *const attr_proto,
std::string *const seq_string);
void SetValueToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
void SetTypeToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
void SetScalarToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
void SetTensorToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
void SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto);
void SetSequenceToAttributeProto(const ValueSequeuePtr &value, onnx::AttributeProto *const attr_proto);
void SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto, const std::string &value_name);
void SetSequenceToAttributeProto(const ValueSequeuePtr &value, onnx::AttributeProto *const attr_proto,
std::string *const seq_string);
void SetSeqElemToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto,
std::string *const seq_string);
onnx::TensorProto_DataType GetOnnxDataType(TypeId type_id);
onnx::TensorProto_DataType GetOnnxDataBitsIntType(int bits);
@ -107,8 +110,10 @@ class IrExportBuilder {
std::string GetNodeName(const AnfNodePtr &node);
std::string GetUniqueNodeName(const AnfNodePtr &node);
std::string GetOpTypeName(const AnfNodePtr &node);
size_t AllocateIndex() { return ++node_index_; }
void ResetIndex() { node_index_ = 0; }
size_t GetNodeIndex() { return ++node_index_; }
void ResetNodeIndex() { node_index_ = 0; }
size_t GetTupleIndex() { return ++shape_index_; }
void ResetTupleIndex() { shape_index_ = 0; }
private:
onnx::ModelProto model_;
@ -116,6 +121,7 @@ class IrExportBuilder {
std::list<FuncGraphPtr> todo_;
std::map<AnfNodePtr, size_t> node_index_map_;
size_t node_index_{0};
size_t shape_index_{0};
};
using IrExporterPtr = std::shared_ptr<IrExporter>;
@ -148,7 +154,7 @@ void IrExportBuilder::BuildModelInfo() {
void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph) {
onnx::GraphProto *graph_proto = model_.mutable_graph();
graph_proto->set_name(func_graph->ToString());
ResetIndex();
ResetNodeIndex();
todo_.clear();
todo_.push_back(func_graph);
while (!todo_.empty()) {
@ -179,7 +185,7 @@ void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, onnx::Grap
input_proto->set_name(param_name);
SetValueInfoProto(param, input_proto);
if (!param->has_default()) {
MS_LOG(DEBUG) << "Parameter: '" << item->ToString() << "' has no default";
MS_LOG(DEBUG) << "Parameter: '" << item->ToString() << "' has no default.";
continue;
}
@ -234,13 +240,20 @@ void IrExportBuilder::SetValueInfoProto(const TypePtr &type, const BaseShapePtr
auto elem_type = tensor->element();
const auto &dims = shape->cast<abstract::ShapePtr>()->shape();
type_proto->mutable_tensor_type()->set_elem_type(GetOnnxDataType(elem_type->type_id()));
for (const auto &dim : dims) {
MS_LOG(DEBUG) << "SetValueInfoProto dim: " << dim;
type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim);
if (dims.size() == 0) {
MS_LOG(DEBUG) << "SetValueInfoProto set default dim 1.";
type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
} else {
for (const auto &dim : dims) {
MS_LOG(DEBUG) << "SetValueInfoProto dim: " << dim;
type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim);
}
}
} else if (type->isa<Tuple>()) {
auto tup_shape = shape->cast<abstract::TupleShapePtr>();
type_proto->set_denotation(std::to_string(tup_shape->shape().size()));
type_proto->set_denotation(type->type_name() + ":" + std::to_string(tup_shape->shape().size()));
} else if (type->isa<Number>() || type->isa<String>()) {
type_proto->set_denotation(type->type_name());
} else {
MS_LOG(EXCEPTION) << "Value type: " << type->type_name() << " is not supported!";
}
@ -250,9 +263,10 @@ void IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, onnx::Att
if (value == nullptr || attr_proto == nullptr) {
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
}
attr_proto->set_ref_attr_name("tensor");
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
attr_proto->set_ref_attr_name("tensor:value0");
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS);
onnx::TensorProto *tensor_proto = attr_proto->add_tensors();
tensor_proto->set_name("value0");
auto data = value->cast<tensor::TensorPtr>();
tensor_proto->set_raw_data(data->data_c(), static_cast<size_t>(data->data().nbytes()));
auto dtype = data->data_type();
@ -286,6 +300,7 @@ void IrExportBuilder::SetParamToTensorProto(const ParameterPtr &param, onnx::Ten
void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) {
std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
bool is_only_return = true;
for (const AnfNodePtr &node : nodes) {
if (!node->isa<CNode>()) {
MS_LOG(DEBUG) << "Node: '" << node->ToString() << "' is not cnode";
@ -293,9 +308,13 @@ void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProt
}
auto cnode = node->cast<CNodePtr>();
if (cnode == func_graph->get_return()) {
if (is_only_return) {
MS_LOG(EXCEPTION) << "Only has return node, can't convert to binary model!";
}
BuildOutput(cnode, graph_proto);
} else {
BuildCNode(cnode, graph_proto);
is_only_return = false;
}
}
}
@ -305,24 +324,11 @@ void IrExportBuilder::BuildOutput(const CNodePtr &node, onnx::GraphProto *const
MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2.";
}
AnfNodePtr arg = node->input(1);
// Using make_tuple to set multi-output
if (IsPrimitiveCNode(arg, prim::kPrimMakeTuple)) {
auto tuple_node = arg->cast<CNodePtr>();
for (size_t i = 1; i < tuple_node->size(); i++) {
auto input_node = arg->cast<CNodePtr>()->input(i);
onnx::ValueInfoProto *output_proto = graph_proto->add_output();
auto output_name = GetUniqueNodeName(tuple_node->input(i));
output_proto->set_name(output_name);
last_node_->add_output(output_name);
SetValueInfoProto(tuple_node->input(i), output_proto);
}
} else {
onnx::ValueInfoProto *output_proto = graph_proto->add_output();
std::string output_name = GetUniqueNodeName(node);
output_proto->set_name(output_name);
last_node_->add_output(output_name);
SetValueInfoProto(arg, output_proto);
}
onnx::ValueInfoProto *output_proto = graph_proto->add_output();
std::string output_name = GetUniqueNodeName(node);
output_proto->set_name(output_name);
last_node_->set_output(0, output_name);
SetValueInfoProto(arg, output_proto);
}
std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) {
@ -345,43 +351,41 @@ std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) {
}
void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape,
onnx::NodeProto *const node_proto, std::string suffix) {
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
attr_proto->set_ref_attr_name("shape");
if (suffix.compare("0") != 0) {
attr_proto->set_name("shape" + suffix);
onnx::AttributeProto *const attr_proto, std::string *const seq_string) {
if (type->isa<Tuple>() && seq_string != nullptr) {
*seq_string += "Tuple[";
auto elements = type->cast<TuplePtr>()->elements();
auto tuple_shape = shape->cast<abstract::TupleShapePtr>()->shape();
for (size_t i = 0; i < elements.size(); i++) {
SetShapeToNodeProto(elements[i], tuple_shape[i], attr_proto, seq_string);
}
*seq_string += "],";
} else if (type->isa<TensorType>() && shape->isa<abstract::Shape>() && seq_string != nullptr) {
string shape_name = "shape" + std::to_string(GetTupleIndex());
*seq_string += shape_name + ",";
onnx::TensorProto *tensor_proto = attr_proto->add_tensors();
tensor_proto->set_name(shape_name);
SetTensorProto(type, shape, tensor_proto);
} else if ((type->isa<Number>() || type->isa<String>()) && seq_string != nullptr) {
*seq_string += type->type_name() + ",";
} else {
attr_proto->set_name("shape");
MS_LOG(EXCEPTION) << "Type of cnode need to be supported: " << type->type_name();
}
onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
SetTensorProto(type, shape, tensor_proto);
}
void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto) {
// Get shape of cnode
// 1. prim ArgMaxWithValue need to get shape from tuple element
// 2. some cnode doesn't has shape, such as LayerNorm
// 3. other cnodes have shape
if (node->IsApply(prim::kPrimArgMaxWithValue) || node->IsApply(prim::kPrimLayerNorm)) {
auto type = node->Type();
auto shape = node->Shape();
if (!type->isa<Tuple>()) {
MS_LOG(EXCEPTION) << "Output data of ArgMaxWithValue cnode must be tuple: " << type->type_name();
}
auto elements = type->cast<TuplePtr>()->elements();
auto tuple_shape = shape->cast<abstract::TupleShapePtr>()->shape();
for (size_t i = 0; i < elements.size(); i++) {
SetShapeToNodeProto(elements[i], tuple_shape[i], node_proto, std::to_string(i));
}
} else {
auto type = node->Type();
auto shape = node->Shape();
if (!type->isa<TensorType>() || !shape->isa<abstract::Shape>()) {
MS_LOG(DEBUG) << "Cnode has no shape: " << node->ToString();
return;
}
SetShapeToNodeProto(type, shape, node_proto);
}
// 1. need to get shape from tuple element
// 2. save shape in TensorProto
// 3. save tuple string in ref_attr_name
auto type = node->Type();
auto shape = node->Shape();
ResetTupleIndex();
std::string seq_string = "shape:";
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
SetShapeToNodeProto(type, shape, attr_proto, &seq_string);
attr_proto->set_ref_attr_name(seq_string);
MS_LOG(DEBUG) << "CNode shape: " << seq_string;
}
void IrExportBuilder::BuildCNode(const CNodePtr &node, onnx::GraphProto *const graph_proto) {
@ -445,15 +449,19 @@ std::string IrExportBuilder::GetUniqueNodeName(const AnfNodePtr &node) {
std::string node_name = "";
if (node->isa<Parameter>()) {
node_name = GetNodeName(node);
} else if (node->isa<CNode>() || node->isa<ValueNode>()) {
} else if (node->isa<CNode>()) {
auto iter = node_index_map_.find(node);
if (iter != node_index_map_.end()) {
node_name = GetNodeName(node) + ":" + std::to_string(iter->second);
} else {
auto node_idx = AllocateIndex();
auto node_idx = GetNodeIndex();
node_index_map_[node] = node_idx;
node_name = GetNodeName(node) + ":" + std::to_string(node_idx);
}
} else if (node->isa<ValueNode>()) {
auto node_idx = GetNodeIndex();
node_index_map_[node] = node_idx;
node_name = GetNodeName(node) + ":" + std::to_string(node_idx);
} else {
MS_LOG(EXCEPTION) << "Can not support type of node:" << node->ToString();
}
@ -487,17 +495,21 @@ void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, onnx::Attri
if (value == nullptr || attr_proto == nullptr) {
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
}
attr_proto->set_ref_attr_name("type");
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS);
onnx::TensorProto *tensor_proto = attr_proto->add_tensors();
if (value->isa<Int>()) {
attr_proto->set_ref_attr_name("type:value0");
tensor_proto->set_name("value0");
auto int_value = value->cast<IntPtr>();
tensor_proto->set_data_type(GetOnnxDataBitsIntType(int_value->nbits()));
} else if (value->isa<Float>()) {
attr_proto->set_ref_attr_name("type:value0");
tensor_proto->set_name("value0");
auto float_value = value->cast<FloatPtr>();
tensor_proto->set_data_type(GetOnnxDataBitsFloatType(float_value->nbits()));
} else if (value->isa<TensorType>()) {
tensor_proto->set_name("tensor");
attr_proto->set_ref_attr_name("type:tensor0");
tensor_proto->set_name("tensor0");
auto elem_type = value->cast<TensorTypePtr>()->element();
if (elem_type->isa<Int>()) {
auto int_value = elem_type->cast<IntPtr>();
@ -521,10 +533,18 @@ void IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, onnx::Attr
SetScalarToAttributeProto(value, attr_proto);
} else if (value->isa<Number>() || value->isa<TensorType>()) {
SetTypeToAttributeProto(value, attr_proto);
} else if (value->isa<ValueSequeue>()) {
SetSequenceToAttributeProto(value->cast<ValueSequeuePtr>(), attr_proto);
} else if (value->isa<ValueSequeue>() || value->isa<ValueSequeue>()) {
ResetTupleIndex();
std::string seq_string = "scalar:";
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS);
SetSequenceToAttributeProto(value->cast<ValueSequeuePtr>(), attr_proto, &seq_string);
attr_proto->set_ref_attr_name(seq_string);
MS_LOG(DEBUG) << "Attr string: " << seq_string;
} else if (value->isa<tensor::Tensor>()) {
SetTensorToAttributeProto(value, attr_proto);
} else if (value->isa<None>()) {
attr_proto->set_ref_attr_name("none");
MS_LOG(DEBUG) << "Attr string: " << value->type_name();
} else {
MS_LOG(EXCEPTION) << "Unsupported type: " << value->type_name();
}
@ -534,16 +554,18 @@ void IrExportBuilder::SetScalarToAttributeProto(const ValuePtr &value, onnx::Att
if (value == nullptr || attr_proto == nullptr) {
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
}
attr_proto->set_ref_attr_name("scalar");
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
SetScalarToProto(value, tensor_proto);
attr_proto->set_ref_attr_name("scalar:value0");
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS);
onnx::TensorProto *tensor_proto = attr_proto->add_tensors();
SetScalarToProto(value, tensor_proto, "value0");
}
void IrExportBuilder::SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto) {
void IrExportBuilder::SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto,
const std::string &value_name) {
if (value == nullptr || tensor_proto == nullptr) {
MS_LOG(EXCEPTION) << "ValuePtr or TensorProto is null!";
}
tensor_proto->set_name(value_name);
if (value->isa<StringImm>()) {
tensor_proto->set_data_type(onnx::TensorProto_DataType_STRING);
tensor_proto->add_string_data(GetValue<std::string>(value));
@ -562,44 +584,74 @@ void IrExportBuilder::SetScalarToProto(const ValuePtr &value, onnx::TensorProto
} else if (value->isa<Int64Imm>()) {
tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64);
tensor_proto->add_int64_data(value->cast<Int64ImmPtr>()->value());
} else if (value->isa<FloatImm>()) {
} else if (value->isa<UInt8Imm>()) {
tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT8);
tensor_proto->add_int32_data(value->cast<UInt8ImmPtr>()->value());
} else if (value->isa<UInt16Imm>()) {
tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT16);
tensor_proto->add_int32_data(value->cast<UInt16ImmPtr>()->value());
} else if (value->isa<UInt32Imm>()) {
tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT32);
tensor_proto->add_uint64_data(value->cast<UInt32ImmPtr>()->value());
} else if (value->isa<UInt64Imm>()) {
tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT64);
tensor_proto->add_uint64_data(value->cast<UInt64ImmPtr>()->value());
} else if (value->isa<FP32Imm>()) {
tensor_proto->set_data_type(onnx::TensorProto_DataType_FLOAT);
tensor_proto->add_float_data(GetValue<float>(value));
} else if (value->isa<FP64Imm>()) {
tensor_proto->set_data_type(onnx::TensorProto_DataType_DOUBLE);
tensor_proto->add_double_data(GetValue<double>(value));
} else {
MS_LOG(EXCEPTION) << "Unsupported scalar type: " << value->type_name();
}
}
void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value,
onnx::AttributeProto *const attr_proto) {
void IrExportBuilder::SetSeqElemToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto,
std::string *const seq_string) {
string value_name = "value" + std::to_string(GetTupleIndex());
if (seq_string != nullptr) {
*seq_string += value_name + ",";
}
onnx::TensorProto *tensor_proto = attr_proto->add_tensors();
SetScalarToProto(value, tensor_proto, value_name);
}
void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value, onnx::AttributeProto *const attr_proto,
std::string *const seq_string) {
if (value == nullptr || attr_proto == nullptr) {
MS_LOG(EXCEPTION) << "ValueSequeuePtr or AttributeProto is null!";
}
attr_proto->set_ref_attr_name("scalar");
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
if (value->isa<ValueTuple>()) {
if (value->isa<ValueTuple>() && seq_string != nullptr) {
*seq_string += "Tuple[";
const ValueTuplePtr &tuple_value = value->cast<ValueTuplePtr>();
if (tuple_value->value().size() == 0) {
MS_LOG(DEBUG) << "SetSequenceToAttributeProto tuple size is 0";
return;
}
auto type_id = tuple_value->value()[0]->type()->type_id();
tensor_proto->set_data_type(GetOnnxDataType(type_id));
for (const auto &item : tuple_value->value()) {
SetScalarToProto(item, tensor_proto);
if (item->isa<ValueTuple>()) {
SetSequenceToAttributeProto(item->cast<ValueTuplePtr>(), attr_proto, seq_string);
} else {
SetSeqElemToAttributeProto(item, attr_proto, seq_string);
}
}
} else if (value->isa<ValueList>()) {
*seq_string += "],";
} else if (value->isa<ValueList>() && seq_string != nullptr) {
*seq_string += "List[";
const ValueListPtr &list_value = value->cast<ValueListPtr>();
if (list_value->value().size() == 0) {
MS_LOG(DEBUG) << "SetSequenceToAttributeProto list size is 0";
MS_LOG(DEBUG) << "SetSequenceToAttributeProto list size is 0.";
return;
}
auto type_id = list_value->value()[0]->type()->type_id();
tensor_proto->set_data_type(GetOnnxDataType(type_id));
for (const auto &item : list_value->value()) {
SetScalarToProto(item, tensor_proto);
if (item->isa<ValueList>()) {
SetSequenceToAttributeProto(item->cast<ValueListPtr>(), attr_proto, seq_string);
} else {
SetSeqElemToAttributeProto(item, attr_proto, seq_string);
}
}
*seq_string += "],";
}
}