forked from mindspore-Ecosystem/mindspore
!3322 revert anf export modify
Merge pull request !3322 from dinghao/master
This commit is contained in:
commit
78443578a5
|
@ -92,17 +92,14 @@ class IrExportBuilder {
|
|||
void SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, onnx::TensorProto *const tensor_proto);
|
||||
void SetAttributeProto(const AnfNodePtr &node, onnx::NodeProto *const node_proto);
|
||||
void SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto);
|
||||
void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, onnx::AttributeProto *const attr_proto,
|
||||
std::string *const seq_string);
|
||||
void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, onnx::NodeProto *const node_proto,
|
||||
std::string suffix = "0");
|
||||
void SetValueToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
|
||||
void SetTypeToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
|
||||
void SetScalarToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
|
||||
void SetTensorToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
|
||||
void SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto, const std::string &value_name);
|
||||
void SetSequenceToAttributeProto(const ValueSequeuePtr &value, onnx::AttributeProto *const attr_proto,
|
||||
std::string *const seq_string);
|
||||
void SetSeqElemToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto,
|
||||
std::string *const seq_string);
|
||||
void SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto);
|
||||
void SetSequenceToAttributeProto(const ValueSequeuePtr &value, onnx::AttributeProto *const attr_proto);
|
||||
|
||||
onnx::TensorProto_DataType GetOnnxDataType(TypeId type_id);
|
||||
onnx::TensorProto_DataType GetOnnxDataBitsIntType(int bits);
|
||||
|
@ -110,10 +107,8 @@ class IrExportBuilder {
|
|||
std::string GetNodeName(const AnfNodePtr &node);
|
||||
std::string GetUniqueNodeName(const AnfNodePtr &node);
|
||||
std::string GetOpTypeName(const AnfNodePtr &node);
|
||||
size_t GetNodeIndex() { return ++node_index_; }
|
||||
void ResetNodeIndex() { node_index_ = 0; }
|
||||
size_t GetTupleIndex() { return ++shape_index_; }
|
||||
void ResetTupleIndex() { shape_index_ = 0; }
|
||||
size_t AllocateIndex() { return ++node_index_; }
|
||||
void ResetIndex() { node_index_ = 0; }
|
||||
|
||||
private:
|
||||
onnx::ModelProto model_;
|
||||
|
@ -121,7 +116,6 @@ class IrExportBuilder {
|
|||
std::list<FuncGraphPtr> todo_;
|
||||
std::map<AnfNodePtr, size_t> node_index_map_;
|
||||
size_t node_index_{0};
|
||||
size_t shape_index_{0};
|
||||
};
|
||||
|
||||
using IrExporterPtr = std::shared_ptr<IrExporter>;
|
||||
|
@ -154,7 +148,7 @@ void IrExportBuilder::BuildModelInfo() {
|
|||
void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph) {
|
||||
onnx::GraphProto *graph_proto = model_.mutable_graph();
|
||||
graph_proto->set_name(func_graph->ToString());
|
||||
ResetNodeIndex();
|
||||
ResetIndex();
|
||||
todo_.clear();
|
||||
todo_.push_back(func_graph);
|
||||
while (!todo_.empty()) {
|
||||
|
@ -185,7 +179,7 @@ void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, onnx::Grap
|
|||
input_proto->set_name(param_name);
|
||||
SetValueInfoProto(param, input_proto);
|
||||
if (!param->has_default()) {
|
||||
MS_LOG(DEBUG) << "Parameter: '" << item->ToString() << "' has no default.";
|
||||
MS_LOG(DEBUG) << "Parameter: '" << item->ToString() << "' has no default";
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -240,20 +234,13 @@ void IrExportBuilder::SetValueInfoProto(const TypePtr &type, const BaseShapePtr
|
|||
auto elem_type = tensor->element();
|
||||
const auto &dims = shape->cast<abstract::ShapePtr>()->shape();
|
||||
type_proto->mutable_tensor_type()->set_elem_type(GetOnnxDataType(elem_type->type_id()));
|
||||
if (dims.size() == 0) {
|
||||
MS_LOG(DEBUG) << "SetValueInfoProto set default dim 1.";
|
||||
type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
|
||||
} else {
|
||||
for (const auto &dim : dims) {
|
||||
MS_LOG(DEBUG) << "SetValueInfoProto dim: " << dim;
|
||||
type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim);
|
||||
}
|
||||
for (const auto &dim : dims) {
|
||||
MS_LOG(DEBUG) << "SetValueInfoProto dim: " << dim;
|
||||
type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim);
|
||||
}
|
||||
} else if (type->isa<Tuple>()) {
|
||||
auto tup_shape = shape->cast<abstract::TupleShapePtr>();
|
||||
type_proto->set_denotation(type->type_name() + ":" + std::to_string(tup_shape->shape().size()));
|
||||
} else if (type->isa<Number>() || type->isa<String>()) {
|
||||
type_proto->set_denotation(type->type_name());
|
||||
type_proto->set_denotation(std::to_string(tup_shape->shape().size()));
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Value type: " << type->type_name() << " is not supported!";
|
||||
}
|
||||
|
@ -263,10 +250,9 @@ void IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, onnx::Att
|
|||
if (value == nullptr || attr_proto == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
|
||||
}
|
||||
attr_proto->set_ref_attr_name("tensor:value0");
|
||||
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS);
|
||||
onnx::TensorProto *tensor_proto = attr_proto->add_tensors();
|
||||
tensor_proto->set_name("value0");
|
||||
attr_proto->set_ref_attr_name("tensor");
|
||||
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
|
||||
onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
|
||||
auto data = value->cast<tensor::TensorPtr>();
|
||||
tensor_proto->set_raw_data(data->data_c(), static_cast<size_t>(data->data().nbytes()));
|
||||
auto dtype = data->data_type();
|
||||
|
@ -300,7 +286,6 @@ void IrExportBuilder::SetParamToTensorProto(const ParameterPtr ¶m, onnx::Ten
|
|||
|
||||
void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) {
|
||||
std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
|
||||
bool is_only_return = true;
|
||||
for (const AnfNodePtr &node : nodes) {
|
||||
if (!node->isa<CNode>()) {
|
||||
MS_LOG(DEBUG) << "Node: '" << node->ToString() << "' is not cnode";
|
||||
|
@ -308,13 +293,9 @@ void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProt
|
|||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (cnode == func_graph->get_return()) {
|
||||
if (is_only_return) {
|
||||
MS_LOG(EXCEPTION) << "Only has return node, can't convert to binary model!";
|
||||
}
|
||||
BuildOutput(cnode, graph_proto);
|
||||
} else {
|
||||
BuildCNode(cnode, graph_proto);
|
||||
is_only_return = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -324,11 +305,24 @@ void IrExportBuilder::BuildOutput(const CNodePtr &node, onnx::GraphProto *const
|
|||
MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2.";
|
||||
}
|
||||
AnfNodePtr arg = node->input(1);
|
||||
onnx::ValueInfoProto *output_proto = graph_proto->add_output();
|
||||
std::string output_name = GetUniqueNodeName(node);
|
||||
output_proto->set_name(output_name);
|
||||
last_node_->set_output(0, output_name);
|
||||
SetValueInfoProto(arg, output_proto);
|
||||
// Using make_tuple to set multi-output
|
||||
if (IsPrimitiveCNode(arg, prim::kPrimMakeTuple)) {
|
||||
auto tuple_node = arg->cast<CNodePtr>();
|
||||
for (size_t i = 1; i < tuple_node->size(); i++) {
|
||||
auto input_node = arg->cast<CNodePtr>()->input(i);
|
||||
onnx::ValueInfoProto *output_proto = graph_proto->add_output();
|
||||
auto output_name = GetUniqueNodeName(tuple_node->input(i));
|
||||
output_proto->set_name(output_name);
|
||||
last_node_->add_output(output_name);
|
||||
SetValueInfoProto(tuple_node->input(i), output_proto);
|
||||
}
|
||||
} else {
|
||||
onnx::ValueInfoProto *output_proto = graph_proto->add_output();
|
||||
std::string output_name = GetUniqueNodeName(node);
|
||||
output_proto->set_name(output_name);
|
||||
last_node_->add_output(output_name);
|
||||
SetValueInfoProto(arg, output_proto);
|
||||
}
|
||||
}
|
||||
|
||||
std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) {
|
||||
|
@ -351,41 +345,43 @@ std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) {
|
|||
}
|
||||
|
||||
void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape,
|
||||
onnx::AttributeProto *const attr_proto, std::string *const seq_string) {
|
||||
if (type->isa<Tuple>() && seq_string != nullptr) {
|
||||
*seq_string += "Tuple[";
|
||||
auto elements = type->cast<TuplePtr>()->elements();
|
||||
auto tuple_shape = shape->cast<abstract::TupleShapePtr>()->shape();
|
||||
for (size_t i = 0; i < elements.size(); i++) {
|
||||
SetShapeToNodeProto(elements[i], tuple_shape[i], attr_proto, seq_string);
|
||||
}
|
||||
*seq_string += "],";
|
||||
} else if (type->isa<TensorType>() && shape->isa<abstract::Shape>() && seq_string != nullptr) {
|
||||
string shape_name = "shape" + std::to_string(GetTupleIndex());
|
||||
*seq_string += shape_name + ",";
|
||||
onnx::TensorProto *tensor_proto = attr_proto->add_tensors();
|
||||
tensor_proto->set_name(shape_name);
|
||||
SetTensorProto(type, shape, tensor_proto);
|
||||
} else if ((type->isa<Number>() || type->isa<String>()) && seq_string != nullptr) {
|
||||
*seq_string += type->type_name() + ",";
|
||||
onnx::NodeProto *const node_proto, std::string suffix) {
|
||||
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
|
||||
attr_proto->set_ref_attr_name("shape");
|
||||
if (suffix.compare("0") != 0) {
|
||||
attr_proto->set_name("shape" + suffix);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Type of cnode need to be supported: " << type->type_name();
|
||||
attr_proto->set_name("shape");
|
||||
}
|
||||
onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
|
||||
SetTensorProto(type, shape, tensor_proto);
|
||||
}
|
||||
|
||||
void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto) {
|
||||
// Get shape of cnode
|
||||
// 1. need to get shape from tuple element
|
||||
// 2. save shape in TensorProto
|
||||
// 3. save tuple string in ref_attr_name
|
||||
auto type = node->Type();
|
||||
auto shape = node->Shape();
|
||||
ResetTupleIndex();
|
||||
std::string seq_string = "shape:";
|
||||
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
|
||||
SetShapeToNodeProto(type, shape, attr_proto, &seq_string);
|
||||
attr_proto->set_ref_attr_name(seq_string);
|
||||
MS_LOG(DEBUG) << "CNode shape: " << seq_string;
|
||||
// 1. prim ArgMaxWithValue need to get shape from tuple element
|
||||
// 2. some cnode doesn't has shape, such as LayerNorm
|
||||
// 3. other cnodes have shape
|
||||
if (node->IsApply(prim::kPrimArgMaxWithValue) || node->IsApply(prim::kPrimLayerNorm)) {
|
||||
auto type = node->Type();
|
||||
auto shape = node->Shape();
|
||||
if (!type->isa<Tuple>()) {
|
||||
MS_LOG(EXCEPTION) << "Output data of ArgMaxWithValue cnode must be tuple: " << type->type_name();
|
||||
}
|
||||
auto elements = type->cast<TuplePtr>()->elements();
|
||||
auto tuple_shape = shape->cast<abstract::TupleShapePtr>()->shape();
|
||||
for (size_t i = 0; i < elements.size(); i++) {
|
||||
SetShapeToNodeProto(elements[i], tuple_shape[i], node_proto, std::to_string(i));
|
||||
}
|
||||
} else {
|
||||
auto type = node->Type();
|
||||
auto shape = node->Shape();
|
||||
if (!type->isa<TensorType>() || !shape->isa<abstract::Shape>()) {
|
||||
MS_LOG(DEBUG) << "Cnode has no shape: " << node->ToString();
|
||||
return;
|
||||
}
|
||||
SetShapeToNodeProto(type, shape, node_proto);
|
||||
}
|
||||
}
|
||||
|
||||
void IrExportBuilder::BuildCNode(const CNodePtr &node, onnx::GraphProto *const graph_proto) {
|
||||
|
@ -449,19 +445,15 @@ std::string IrExportBuilder::GetUniqueNodeName(const AnfNodePtr &node) {
|
|||
std::string node_name = "";
|
||||
if (node->isa<Parameter>()) {
|
||||
node_name = GetNodeName(node);
|
||||
} else if (node->isa<CNode>()) {
|
||||
} else if (node->isa<CNode>() || node->isa<ValueNode>()) {
|
||||
auto iter = node_index_map_.find(node);
|
||||
if (iter != node_index_map_.end()) {
|
||||
node_name = GetNodeName(node) + ":" + std::to_string(iter->second);
|
||||
} else {
|
||||
auto node_idx = GetNodeIndex();
|
||||
auto node_idx = AllocateIndex();
|
||||
node_index_map_[node] = node_idx;
|
||||
node_name = GetNodeName(node) + ":" + std::to_string(node_idx);
|
||||
}
|
||||
} else if (node->isa<ValueNode>()) {
|
||||
auto node_idx = GetNodeIndex();
|
||||
node_index_map_[node] = node_idx;
|
||||
node_name = GetNodeName(node) + ":" + std::to_string(node_idx);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Can not support type of node:" << node->ToString();
|
||||
}
|
||||
|
@ -495,21 +487,17 @@ void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, onnx::Attri
|
|||
if (value == nullptr || attr_proto == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
|
||||
}
|
||||
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS);
|
||||
onnx::TensorProto *tensor_proto = attr_proto->add_tensors();
|
||||
attr_proto->set_ref_attr_name("type");
|
||||
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
|
||||
onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
|
||||
if (value->isa<Int>()) {
|
||||
attr_proto->set_ref_attr_name("type:value0");
|
||||
tensor_proto->set_name("value0");
|
||||
auto int_value = value->cast<IntPtr>();
|
||||
tensor_proto->set_data_type(GetOnnxDataBitsIntType(int_value->nbits()));
|
||||
} else if (value->isa<Float>()) {
|
||||
attr_proto->set_ref_attr_name("type:value0");
|
||||
tensor_proto->set_name("value0");
|
||||
auto float_value = value->cast<FloatPtr>();
|
||||
tensor_proto->set_data_type(GetOnnxDataBitsFloatType(float_value->nbits()));
|
||||
} else if (value->isa<TensorType>()) {
|
||||
attr_proto->set_ref_attr_name("type:tensor0");
|
||||
tensor_proto->set_name("tensor0");
|
||||
tensor_proto->set_name("tensor");
|
||||
auto elem_type = value->cast<TensorTypePtr>()->element();
|
||||
if (elem_type->isa<Int>()) {
|
||||
auto int_value = elem_type->cast<IntPtr>();
|
||||
|
@ -533,18 +521,10 @@ void IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, onnx::Attr
|
|||
SetScalarToAttributeProto(value, attr_proto);
|
||||
} else if (value->isa<Number>() || value->isa<TensorType>()) {
|
||||
SetTypeToAttributeProto(value, attr_proto);
|
||||
} else if (value->isa<ValueSequeue>() || value->isa<ValueSequeue>()) {
|
||||
ResetTupleIndex();
|
||||
std::string seq_string = "scalar:";
|
||||
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS);
|
||||
SetSequenceToAttributeProto(value->cast<ValueSequeuePtr>(), attr_proto, &seq_string);
|
||||
attr_proto->set_ref_attr_name(seq_string);
|
||||
MS_LOG(DEBUG) << "Attr string: " << seq_string;
|
||||
} else if (value->isa<ValueSequeue>()) {
|
||||
SetSequenceToAttributeProto(value->cast<ValueSequeuePtr>(), attr_proto);
|
||||
} else if (value->isa<tensor::Tensor>()) {
|
||||
SetTensorToAttributeProto(value, attr_proto);
|
||||
} else if (value->isa<None>()) {
|
||||
attr_proto->set_ref_attr_name("none");
|
||||
MS_LOG(DEBUG) << "Attr string: " << value->type_name();
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Unsupported type: " << value->type_name();
|
||||
}
|
||||
|
@ -554,18 +534,16 @@ void IrExportBuilder::SetScalarToAttributeProto(const ValuePtr &value, onnx::Att
|
|||
if (value == nullptr || attr_proto == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
|
||||
}
|
||||
attr_proto->set_ref_attr_name("scalar:value0");
|
||||
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS);
|
||||
onnx::TensorProto *tensor_proto = attr_proto->add_tensors();
|
||||
SetScalarToProto(value, tensor_proto, "value0");
|
||||
attr_proto->set_ref_attr_name("scalar");
|
||||
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
|
||||
onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
|
||||
SetScalarToProto(value, tensor_proto);
|
||||
}
|
||||
|
||||
void IrExportBuilder::SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto,
|
||||
const std::string &value_name) {
|
||||
void IrExportBuilder::SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto) {
|
||||
if (value == nullptr || tensor_proto == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "ValuePtr or TensorProto is null!";
|
||||
}
|
||||
tensor_proto->set_name(value_name);
|
||||
if (value->isa<StringImm>()) {
|
||||
tensor_proto->set_data_type(onnx::TensorProto_DataType_STRING);
|
||||
tensor_proto->add_string_data(GetValue<std::string>(value));
|
||||
|
@ -584,74 +562,44 @@ void IrExportBuilder::SetScalarToProto(const ValuePtr &value, onnx::TensorProto
|
|||
} else if (value->isa<Int64Imm>()) {
|
||||
tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64);
|
||||
tensor_proto->add_int64_data(value->cast<Int64ImmPtr>()->value());
|
||||
} else if (value->isa<UInt8Imm>()) {
|
||||
tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT8);
|
||||
tensor_proto->add_int32_data(value->cast<UInt8ImmPtr>()->value());
|
||||
} else if (value->isa<UInt16Imm>()) {
|
||||
tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT16);
|
||||
tensor_proto->add_int32_data(value->cast<UInt16ImmPtr>()->value());
|
||||
} else if (value->isa<UInt32Imm>()) {
|
||||
tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT32);
|
||||
tensor_proto->add_uint64_data(value->cast<UInt32ImmPtr>()->value());
|
||||
} else if (value->isa<UInt64Imm>()) {
|
||||
tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT64);
|
||||
tensor_proto->add_uint64_data(value->cast<UInt64ImmPtr>()->value());
|
||||
} else if (value->isa<FP32Imm>()) {
|
||||
} else if (value->isa<FloatImm>()) {
|
||||
tensor_proto->set_data_type(onnx::TensorProto_DataType_FLOAT);
|
||||
tensor_proto->add_float_data(GetValue<float>(value));
|
||||
} else if (value->isa<FP64Imm>()) {
|
||||
tensor_proto->set_data_type(onnx::TensorProto_DataType_DOUBLE);
|
||||
tensor_proto->add_double_data(GetValue<double>(value));
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Unsupported scalar type: " << value->type_name();
|
||||
}
|
||||
}
|
||||
|
||||
void IrExportBuilder::SetSeqElemToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto,
|
||||
std::string *const seq_string) {
|
||||
string value_name = "value" + std::to_string(GetTupleIndex());
|
||||
if (seq_string != nullptr) {
|
||||
*seq_string += value_name + ",";
|
||||
}
|
||||
onnx::TensorProto *tensor_proto = attr_proto->add_tensors();
|
||||
SetScalarToProto(value, tensor_proto, value_name);
|
||||
}
|
||||
|
||||
void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value, onnx::AttributeProto *const attr_proto,
|
||||
std::string *const seq_string) {
|
||||
void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value,
|
||||
onnx::AttributeProto *const attr_proto) {
|
||||
if (value == nullptr || attr_proto == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "ValueSequeuePtr or AttributeProto is null!";
|
||||
}
|
||||
if (value->isa<ValueTuple>() && seq_string != nullptr) {
|
||||
*seq_string += "Tuple[";
|
||||
attr_proto->set_ref_attr_name("scalar");
|
||||
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
|
||||
onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
|
||||
if (value->isa<ValueTuple>()) {
|
||||
const ValueTuplePtr &tuple_value = value->cast<ValueTuplePtr>();
|
||||
if (tuple_value->value().size() == 0) {
|
||||
MS_LOG(DEBUG) << "SetSequenceToAttributeProto tuple size is 0";
|
||||
return;
|
||||
}
|
||||
auto type_id = tuple_value->value()[0]->type()->type_id();
|
||||
tensor_proto->set_data_type(GetOnnxDataType(type_id));
|
||||
for (const auto &item : tuple_value->value()) {
|
||||
if (item->isa<ValueTuple>()) {
|
||||
SetSequenceToAttributeProto(item->cast<ValueTuplePtr>(), attr_proto, seq_string);
|
||||
} else {
|
||||
SetSeqElemToAttributeProto(item, attr_proto, seq_string);
|
||||
}
|
||||
SetScalarToProto(item, tensor_proto);
|
||||
}
|
||||
*seq_string += "],";
|
||||
} else if (value->isa<ValueList>() && seq_string != nullptr) {
|
||||
*seq_string += "List[";
|
||||
} else if (value->isa<ValueList>()) {
|
||||
const ValueListPtr &list_value = value->cast<ValueListPtr>();
|
||||
if (list_value->value().size() == 0) {
|
||||
MS_LOG(DEBUG) << "SetSequenceToAttributeProto list size is 0.";
|
||||
MS_LOG(DEBUG) << "SetSequenceToAttributeProto list size is 0";
|
||||
return;
|
||||
}
|
||||
auto type_id = list_value->value()[0]->type()->type_id();
|
||||
tensor_proto->set_data_type(GetOnnxDataType(type_id));
|
||||
for (const auto &item : list_value->value()) {
|
||||
if (item->isa<ValueList>()) {
|
||||
SetSequenceToAttributeProto(item->cast<ValueListPtr>(), attr_proto, seq_string);
|
||||
} else {
|
||||
SetSeqElemToAttributeProto(item, attr_proto, seq_string);
|
||||
}
|
||||
SetScalarToProto(item, tensor_proto);
|
||||
}
|
||||
*seq_string += "],";
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue