forked from mindspore-Ecosystem/mindspore
!17524 Add onnx export
From: @liuyang_655 Reviewed-by: @kingxian,@zh_qh Signed-off-by: @kingxian,@zh_qh
This commit is contained in:
commit
28848a97b9
|
@ -262,10 +262,16 @@ OPERATOR_ONNX_CONVERT_DEFINE(
|
|||
|
||||
OPERATOR_ONNX_CONVERT_DEFINE(Gather, Gather, OpNameInfo())
|
||||
OPERATOR_ONNX_CONVERT_DEFINE(MakeTuple, SequenceConstruct, OpNameInfo())
|
||||
OPERATOR_ONNX_CONVERT_DEFINE(Concat, Concat, OpNameInfo())
|
||||
OPERATOR_ONNX_CONVERT_DEFINE(RealDiv, Div, OpNameInfo())
|
||||
OPERATOR_ONNX_CONVERT_DEFINE(ReduceSum, ReduceSum, OpNameInfo())
|
||||
OPERATOR_ONNX_CONVERT_DEFINE(Sub, Sub, OpNameInfo())
|
||||
OPERATOR_ONNX_CONVERT_DEFINE(Maximum, Max, OpNameInfo())
|
||||
OPERATOR_ONNX_CONVERT_DEFINE(Transpose, Transpose, OpNameInfo())
|
||||
OPERATOR_ONNX_CONVERT_DEFINE(StridedSlice, Slice, OpNameInfo())
|
||||
OPERATOR_ONNX_CONVERT_DEFINE(Exp, Exp, OpNameInfo())
|
||||
OPERATOR_ONNX_CONVERT_DEFINE(ResizeNearestNeighbor, Resize, OpNameInfo())
|
||||
OPERATOR_ONNX_CONVERT_DEFINE(Softplus, Softplus, OpNameInfo())
|
||||
OPERATOR_ONNX_CONVERT_DEFINE(Tanh, Tanh, OpNameInfo())
|
||||
|
||||
#define OP_CONVERT_FUNCTION_NAME(name) GetOpOnnxConvertInfo_##name
|
||||
|
||||
|
@ -289,10 +295,14 @@ void RegisterOpConverters(const std::function<void(OpNameInfo &&)> &fn) {
|
|||
fn(OP_CONVERT_FUNCTION_NAME(MatMul)());
|
||||
|
||||
fn(OP_CONVERT_FUNCTION_NAME(MakeTuple)());
|
||||
fn(OP_CONVERT_FUNCTION_NAME(Concat)());
|
||||
fn(OP_CONVERT_FUNCTION_NAME(RealDiv)());
|
||||
fn(OP_CONVERT_FUNCTION_NAME(BiasAdd)());
|
||||
fn(OP_CONVERT_FUNCTION_NAME(Sub)());
|
||||
fn(OP_CONVERT_FUNCTION_NAME(Maximum)());
|
||||
fn(OP_CONVERT_FUNCTION_NAME(Exp)());
|
||||
fn(OP_CONVERT_FUNCTION_NAME(ResizeNearestNeighbor)());
|
||||
fn(OP_CONVERT_FUNCTION_NAME(Softplus)());
|
||||
fn(OP_CONVERT_FUNCTION_NAME(Tanh)());
|
||||
}
|
||||
|
||||
class OpConvertRegistry {
|
||||
|
@ -351,6 +361,14 @@ class OnnxExporter {
|
|||
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
|
||||
void ExportPrimReduce(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
|
||||
void ExportPrimTranspose(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
|
||||
void ExportPrimStridedSlice(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
|
||||
void ExportPrimResizeNearestNeighbor(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
|
||||
void ExportPrimConcat(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
|
||||
void ExportPrimCast(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
|
||||
onnx::GraphProto *graph_proto);
|
||||
void ExportPrimPReLU(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
|
||||
|
@ -592,12 +610,25 @@ void OnnxExporter::ExportNodes(const FuncGraphPtr &func_graph, std::map<AnfNodeP
|
|||
|
||||
std::unordered_map<AnfNodePtr, OpMergedInfo> op_merged_infos;
|
||||
MatchAndMark(func_graph, nodes, &op_merged_infos);
|
||||
|
||||
int count = -1;
|
||||
for (const AnfNodePtr &node : nodes) {
|
||||
count++;
|
||||
if (!node->isa<CNode>()) {
|
||||
continue;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (cnode->IsApply(prim::kPrimMakeTuple)) {
|
||||
int i = count + 1;
|
||||
while (!nodes[i]->isa<CNode>()) {
|
||||
i++;
|
||||
}
|
||||
auto nextCNode = nodes[i]->cast<CNodePtr>();
|
||||
if (nextCNode->IsApply(prim::kPrimUpdateState) &&
|
||||
IsPrimitiveCNode(nextCNode->input(2), std::make_shared<Primitive>("MakeTuple"))) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
auto iter = op_merged_infos.find(cnode);
|
||||
// the node is not referenced by any other nodes, skip it
|
||||
if (iter == op_merged_infos.end()) {
|
||||
|
@ -700,6 +731,224 @@ void OnnxExporter::ExportPrimReduce(const FuncGraphPtr &, const CNodePtr &node,
|
|||
}
|
||||
}
|
||||
|
||||
void OnnxExporter::ExportPrimTranspose(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, size_t> *node_map_ptr,
|
||||
onnx::GraphProto *const graph_proto) {
|
||||
auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
|
||||
auto input_perm = node->input(2);
|
||||
|
||||
auto node_idx = AllocateNodeIndex();
|
||||
(*node_map_ptr)[node] = node_idx;
|
||||
onnx::NodeProto *node_proto = graph_proto->add_node();
|
||||
auto name = prim::kPrimTranspose->name();
|
||||
node_proto->set_op_type(name);
|
||||
node_proto->add_output(std::to_string(node_idx));
|
||||
node_proto->add_input(input_data);
|
||||
|
||||
if (input_perm->isa<ValueNode>()) {
|
||||
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
|
||||
attr_proto->set_name("perm");
|
||||
attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS);
|
||||
auto perm_value = dyn_cast<ValueNode>(input_perm)->value();
|
||||
auto int_ptr = dyn_cast<Int32Imm>(perm_value);
|
||||
if (int_ptr == nullptr) {
|
||||
auto tuple_ptr = dyn_cast<ValueTuple>(perm_value);
|
||||
MS_EXCEPTION_IF_NULL(tuple_ptr);
|
||||
for (size_t i = 0; i < tuple_ptr->size(); ++i) {
|
||||
attr_proto->add_ints(GetValue<int64_t>((*tuple_ptr)[i]));
|
||||
}
|
||||
} else {
|
||||
attr_proto->add_ints(int_ptr->value());
|
||||
}
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "The input input_perm of Transpose is not a ValueNode! "
|
||||
<< "Need to insert op convert variable from tuple to attributes for " << name;
|
||||
}
|
||||
}
|
||||
|
||||
void OnnxExporter::ExportPrimStridedSlice(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, size_t> *node_map_ptr,
|
||||
onnx::GraphProto *const graph_proto) {
|
||||
auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
|
||||
auto begin = node->input(2);
|
||||
auto name = prim::kPrimStridedSlice->name();
|
||||
std::string name_begin;
|
||||
if (begin->isa<ValueNode>()) {
|
||||
auto const_node_idx = AllocateNodeIndex();
|
||||
(*node_map_ptr)[begin] = const_node_idx;
|
||||
onnx::NodeProto *node_proto = graph_proto->add_node();
|
||||
name_begin = std::to_string(const_node_idx);
|
||||
node_proto->add_output(name_begin);
|
||||
|
||||
node_proto->set_op_type("Constant");
|
||||
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
|
||||
attr_proto->set_name("starts");
|
||||
|
||||
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
|
||||
ConvertTupleToTensor(dyn_cast<ValueNode>(begin)->value(), attr_proto->mutable_t());
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "The input begin of StridedSlice is not a ValueNode! "
|
||||
<< "Need to insert op convert variable from tuple to tensor for " << name;
|
||||
}
|
||||
|
||||
auto end = node->input(3);
|
||||
std::string name_end;
|
||||
if (end->isa<ValueNode>()) {
|
||||
auto const_node_idx = AllocateNodeIndex();
|
||||
(*node_map_ptr)[end] = const_node_idx;
|
||||
onnx::NodeProto *node_proto = graph_proto->add_node();
|
||||
name_end = std::to_string(const_node_idx);
|
||||
node_proto->add_output(name_end);
|
||||
|
||||
node_proto->set_op_type("Constant");
|
||||
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
|
||||
attr_proto->set_name("ends");
|
||||
|
||||
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
|
||||
ConvertTupleToTensor(dyn_cast<ValueNode>(end)->value(), attr_proto->mutable_t());
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "The input end of StridedSlice is not a ValueNode! "
|
||||
<< "Need to insert op convert variable from tuple to tensor for " << name;
|
||||
}
|
||||
|
||||
auto x_shape = dyn_cast<abstract::Shape>(node->input(1)->Shape());
|
||||
int size = x_shape->shape().size();
|
||||
std::vector<int32_t> axes_value;
|
||||
ValuePtr axes_value_ptr = nullptr;
|
||||
for (int i = 0; i < size; ++i) {
|
||||
axes_value.push_back(i);
|
||||
}
|
||||
axes_value_ptr = MakeValue<std::vector<int32_t>>(axes_value);
|
||||
auto axes = NewValueNode(axes_value_ptr)->cast<AnfNodePtr>();
|
||||
std::string name_axes;
|
||||
auto const_node_idx_axes = AllocateNodeIndex();
|
||||
(*node_map_ptr)[axes] = const_node_idx_axes;
|
||||
onnx::NodeProto *node_proto_axes = graph_proto->add_node();
|
||||
name_axes = std::to_string(const_node_idx_axes);
|
||||
node_proto_axes->add_output(name_axes);
|
||||
node_proto_axes->set_op_type("Constant");
|
||||
onnx::AttributeProto *attr_proto_axes = node_proto_axes->add_attribute();
|
||||
attr_proto_axes->set_name("axes");
|
||||
attr_proto_axes->set_type(onnx::AttributeProto_AttributeType_TENSOR);
|
||||
ConvertTupleToTensor(dyn_cast<ValueNode>(axes)->value(), attr_proto_axes->mutable_t());
|
||||
|
||||
auto strides = node->input(4);
|
||||
std::string name_strides;
|
||||
if (strides->isa<ValueNode>()) {
|
||||
auto const_node_idx = AllocateNodeIndex();
|
||||
(*node_map_ptr)[strides] = const_node_idx;
|
||||
onnx::NodeProto *node_proto = graph_proto->add_node();
|
||||
name_strides = std::to_string(const_node_idx);
|
||||
node_proto->add_output(name_strides);
|
||||
|
||||
node_proto->set_op_type("Constant");
|
||||
onnx::AttributeProto *attr_proto_steps = node_proto->add_attribute();
|
||||
attr_proto_steps->set_name("steps");
|
||||
attr_proto_steps->set_type(onnx::AttributeProto_AttributeType_TENSOR);
|
||||
ConvertTupleToTensor(dyn_cast<ValueNode>(strides)->value(), attr_proto_steps->mutable_t());
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "The input strides of StridedSlice is not a ValueNode! "
|
||||
<< "Need to insert op convert variable from tuple to tensor for " << name;
|
||||
}
|
||||
|
||||
auto node_idx = AllocateNodeIndex();
|
||||
(*node_map_ptr)[node] = node_idx;
|
||||
onnx::NodeProto *node_proto = graph_proto->add_node();
|
||||
node_proto->set_op_type("Slice");
|
||||
node_proto->add_output(std::to_string(node_idx));
|
||||
node_proto->add_input(input_data);
|
||||
node_proto->add_input(name_begin);
|
||||
node_proto->add_input(name_end);
|
||||
node_proto->add_input(name_axes);
|
||||
node_proto->add_input(name_strides);
|
||||
}
|
||||
|
||||
void OnnxExporter::ExportPrimResizeNearestNeighbor(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, size_t> *node_map_ptr,
|
||||
onnx::GraphProto *const graph_proto) {
|
||||
auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
|
||||
auto x_shape = dyn_cast<abstract::Shape>(node->input(1)->Shape());
|
||||
|
||||
AnfNodePtr op = node->input(0);
|
||||
auto op_value = dyn_cast<ValueNode>(op);
|
||||
auto prim = dyn_cast<Primitive>(op_value->value());
|
||||
std::vector<int64_t> resize_size;
|
||||
|
||||
auto tuple_ptr = dyn_cast<ValueTuple>(prim->GetAttr("size"));
|
||||
|
||||
for (size_t i = 0; i < x_shape->shape().size() - 2; i++) {
|
||||
resize_size.push_back(x_shape->shape()[i]);
|
||||
}
|
||||
for (size_t i = 0; i < tuple_ptr->size(); i++) {
|
||||
ValuePtr elem = (*tuple_ptr)[i];
|
||||
resize_size.push_back(dyn_cast<Int64Imm>(elem)->value());
|
||||
}
|
||||
auto resize_size_ptr = MakeValue<std::vector<int64_t>>(resize_size);
|
||||
auto size = NewValueNode(resize_size_ptr)->cast<AnfNodePtr>();
|
||||
std::string name_size;
|
||||
|
||||
auto const_node_idx = AllocateNodeIndex();
|
||||
(*node_map_ptr)[size] = const_node_idx;
|
||||
onnx::NodeProto *node_proto_size = graph_proto->add_node();
|
||||
name_size = std::to_string(const_node_idx);
|
||||
node_proto_size->add_output(name_size);
|
||||
node_proto_size->set_op_type("Constant");
|
||||
onnx::AttributeProto *attr_proto = node_proto_size->add_attribute();
|
||||
attr_proto->set_name("sizes");
|
||||
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
|
||||
ConvertTupleToTensor(resize_size_ptr, attr_proto->mutable_t());
|
||||
|
||||
auto node_idx = AllocateNodeIndex();
|
||||
|
||||
onnx::TensorProto *roi_initializer_proto = graph_proto->add_initializer();
|
||||
auto roi_name = std::to_string(node_idx) + "roi_initializer";
|
||||
roi_initializer_proto->set_name(roi_name);
|
||||
roi_initializer_proto->set_data_type(GetOnnxDataType(kNumberTypeFloat32));
|
||||
roi_initializer_proto->add_dims(0);
|
||||
|
||||
onnx::TensorProto *scales_initializer_proto = graph_proto->add_initializer();
|
||||
auto scales_name = std::to_string(node_idx) + "scales_initializer";
|
||||
scales_initializer_proto->set_name(scales_name);
|
||||
scales_initializer_proto->set_data_type(GetOnnxDataType(kNumberTypeFloat32));
|
||||
scales_initializer_proto->add_dims(0);
|
||||
|
||||
(*node_map_ptr)[node] = node_idx;
|
||||
onnx::NodeProto *node_proto = graph_proto->add_node();
|
||||
|
||||
node_proto->set_op_type("Resize");
|
||||
node_proto->add_output(std::to_string(node_idx));
|
||||
node_proto->add_input(input_data);
|
||||
node_proto->add_input(roi_name);
|
||||
node_proto->add_input(scales_name);
|
||||
node_proto->add_input(name_size);
|
||||
}
|
||||
|
||||
void OnnxExporter::ExportPrimConcat(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
|
||||
auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
|
||||
auto node_idx = AllocateNodeIndex();
|
||||
(*node_map_ptr)[node] = node_idx;
|
||||
onnx::NodeProto *node_proto = graph_proto->add_node();
|
||||
|
||||
AnfNodePtr op = node->input(0);
|
||||
auto op_value = dyn_cast<ValueNode>(op);
|
||||
auto prim = dyn_cast<Primitive>(op_value->value());
|
||||
auto input_node = node->input(1)->cast<CNodePtr>();
|
||||
|
||||
if (input_node->IsApply(prim::kPrimMakeTuple)) {
|
||||
node_proto->set_op_type("ConcatFromSequence");
|
||||
} else {
|
||||
node_proto->set_op_type("Concat");
|
||||
}
|
||||
|
||||
// set attr axis
|
||||
onnx::AttributeProto *onnx_attr_proto = node_proto->add_attribute();
|
||||
onnx_attr_proto->set_name("axis");
|
||||
SetAttrValueToProto<Int64Imm>(prim->GetAttr("axis"), onnx::AttributeProto_AttributeType_INT, onnx_attr_proto, prim);
|
||||
node_proto->add_output(std::to_string(node_idx));
|
||||
node_proto->add_input(input_data);
|
||||
}
|
||||
|
||||
void OnnxExporter::ExportPrimCast(const FuncGraphPtr &, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
|
||||
auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
|
||||
|
@ -959,6 +1208,19 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n
|
|||
return ExportPrimReduce(func_graph, node, node_map_ptr, graph_proto);
|
||||
}
|
||||
|
||||
if (node->IsApply(prim::kPrimTranspose)) {
|
||||
return ExportPrimTranspose(func_graph, node, node_map_ptr, graph_proto);
|
||||
}
|
||||
if (node->IsApply(prim::kPrimStridedSlice)) {
|
||||
return ExportPrimStridedSlice(func_graph, node, node_map_ptr, graph_proto);
|
||||
}
|
||||
if (node->IsApply(prim::kPrimResizeNearestNeighbor)) {
|
||||
return ExportPrimResizeNearestNeighbor(func_graph, node, node_map_ptr, graph_proto);
|
||||
}
|
||||
if (node->IsApply(prim::kPrimConcat)) {
|
||||
return ExportPrimConcat(func_graph, node, node_map_ptr, graph_proto);
|
||||
}
|
||||
|
||||
// MindSpore Cast(x, T) --> ONNX Cast[to=T](x)
|
||||
if (node->IsApply(prim::kPrimCast)) {
|
||||
return ExportPrimCast(func_graph, node, node_map_ptr, graph_proto);
|
||||
|
@ -1016,9 +1278,7 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n
|
|||
MS_LOG(EXCEPTION) << "Need to support node op type " << op_value->value()->type_name();
|
||||
}
|
||||
|
||||
if (!IsPrimitiveEquals(prim, prim::kPrimMakeTuple)) {
|
||||
(*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim, op_inputs, graph_proto);
|
||||
}
|
||||
(*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim, op_inputs, graph_proto);
|
||||
}
|
||||
|
||||
size_t OnnxExporter::ExportPrimitive(const FuncGraphPtr &, std::map<AnfNodePtr, size_t> *node_map_ptr,
|
||||
|
|
|
@ -214,6 +214,7 @@ inline const PrimitivePtr kPrimReverseSequence = std::make_shared<Primitive>("Re
|
|||
inline const PrimitivePtr kPrimRank = std::make_shared<Primitive>("Rank");
|
||||
inline const PrimitivePtr kPrimResizeBilinear = std::make_shared<Primitive>("ResizeBilinear");
|
||||
inline const PrimitivePtr kPrimResizeGrad = std::make_shared<Primitive>("ResizeGrad");
|
||||
inline const PrimitivePtr kPrimResizeNearestNeighbor = std::make_shared<Primitive>("ResizeNearestNeighbor");
|
||||
inline const PrimitivePtr kPrimSort = std::make_shared<Primitive>("Sort");
|
||||
inline const PrimitivePtr kPrimMaskedSelect = std::make_shared<Primitive>("MaskedSelect");
|
||||
|
||||
|
|
Loading…
Reference in New Issue