!17524 Add onnx export

From: @liuyang_655
Reviewed-by: @kingxian,@zh_qh
Signed-off-by: @kingxian,@zh_qh
This commit is contained in:
mindspore-ci-bot 2021-06-02 19:52:53 +08:00 committed by Gitee
commit 28848a97b9
2 changed files with 267 additions and 6 deletions

View File

@ -262,10 +262,16 @@ OPERATOR_ONNX_CONVERT_DEFINE(
OPERATOR_ONNX_CONVERT_DEFINE(Gather, Gather, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(MakeTuple, SequenceConstruct, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Concat, Concat, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(RealDiv, Div, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(ReduceSum, ReduceSum, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Sub, Sub, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Maximum, Max, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Transpose, Transpose, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(StridedSlice, Slice, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Exp, Exp, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(ResizeNearestNeighbor, Resize, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Softplus, Softplus, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Tanh, Tanh, OpNameInfo())
#define OP_CONVERT_FUNCTION_NAME(name) GetOpOnnxConvertInfo_##name
@ -289,10 +295,14 @@ void RegisterOpConverters(const std::function<void(OpNameInfo &&)> &fn) {
fn(OP_CONVERT_FUNCTION_NAME(MatMul)());
fn(OP_CONVERT_FUNCTION_NAME(MakeTuple)());
fn(OP_CONVERT_FUNCTION_NAME(Concat)());
fn(OP_CONVERT_FUNCTION_NAME(RealDiv)());
fn(OP_CONVERT_FUNCTION_NAME(BiasAdd)());
fn(OP_CONVERT_FUNCTION_NAME(Sub)());
fn(OP_CONVERT_FUNCTION_NAME(Maximum)());
fn(OP_CONVERT_FUNCTION_NAME(Exp)());
fn(OP_CONVERT_FUNCTION_NAME(ResizeNearestNeighbor)());
fn(OP_CONVERT_FUNCTION_NAME(Softplus)());
fn(OP_CONVERT_FUNCTION_NAME(Tanh)());
}
class OpConvertRegistry {
@ -351,6 +361,14 @@ class OnnxExporter {
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportPrimReduce(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportPrimTranspose(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportPrimStridedSlice(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportPrimResizeNearestNeighbor(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportPrimConcat(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportPrimCast(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto *graph_proto);
void ExportPrimPReLU(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
@ -592,12 +610,25 @@ void OnnxExporter::ExportNodes(const FuncGraphPtr &func_graph, std::map<AnfNodeP
std::unordered_map<AnfNodePtr, OpMergedInfo> op_merged_infos;
MatchAndMark(func_graph, nodes, &op_merged_infos);
int count = -1;
for (const AnfNodePtr &node : nodes) {
count++;
if (!node->isa<CNode>()) {
continue;
}
auto cnode = node->cast<CNodePtr>();
if (cnode->IsApply(prim::kPrimMakeTuple)) {
int i = count + 1;
while (!nodes[i]->isa<CNode>()) {
i++;
}
auto nextCNode = nodes[i]->cast<CNodePtr>();
if (nextCNode->IsApply(prim::kPrimUpdateState) &&
IsPrimitiveCNode(nextCNode->input(2), std::make_shared<Primitive>("MakeTuple"))) {
continue;
}
}
auto iter = op_merged_infos.find(cnode);
// the node is not referenced by any other nodes, skip it
if (iter == op_merged_infos.end()) {
@ -700,6 +731,224 @@ void OnnxExporter::ExportPrimReduce(const FuncGraphPtr &, const CNodePtr &node,
}
}
void OnnxExporter::ExportPrimTranspose(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto *const graph_proto) {
auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
auto input_perm = node->input(2);
auto node_idx = AllocateNodeIndex();
(*node_map_ptr)[node] = node_idx;
onnx::NodeProto *node_proto = graph_proto->add_node();
auto name = prim::kPrimTranspose->name();
node_proto->set_op_type(name);
node_proto->add_output(std::to_string(node_idx));
node_proto->add_input(input_data);
if (input_perm->isa<ValueNode>()) {
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
attr_proto->set_name("perm");
attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS);
auto perm_value = dyn_cast<ValueNode>(input_perm)->value();
auto int_ptr = dyn_cast<Int32Imm>(perm_value);
if (int_ptr == nullptr) {
auto tuple_ptr = dyn_cast<ValueTuple>(perm_value);
MS_EXCEPTION_IF_NULL(tuple_ptr);
for (size_t i = 0; i < tuple_ptr->size(); ++i) {
attr_proto->add_ints(GetValue<int64_t>((*tuple_ptr)[i]));
}
} else {
attr_proto->add_ints(int_ptr->value());
}
} else {
MS_LOG(EXCEPTION) << "The input input_perm of Transpose is not a ValueNode! "
<< "Need to insert op convert variable from tuple to attributes for " << name;
}
}
void OnnxExporter::ExportPrimStridedSlice(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto *const graph_proto) {
auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
auto begin = node->input(2);
auto name = prim::kPrimStridedSlice->name();
std::string name_begin;
if (begin->isa<ValueNode>()) {
auto const_node_idx = AllocateNodeIndex();
(*node_map_ptr)[begin] = const_node_idx;
onnx::NodeProto *node_proto = graph_proto->add_node();
name_begin = std::to_string(const_node_idx);
node_proto->add_output(name_begin);
node_proto->set_op_type("Constant");
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
attr_proto->set_name("starts");
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
ConvertTupleToTensor(dyn_cast<ValueNode>(begin)->value(), attr_proto->mutable_t());
} else {
MS_LOG(EXCEPTION) << "The input begin of StridedSlice is not a ValueNode! "
<< "Need to insert op convert variable from tuple to tensor for " << name;
}
auto end = node->input(3);
std::string name_end;
if (end->isa<ValueNode>()) {
auto const_node_idx = AllocateNodeIndex();
(*node_map_ptr)[end] = const_node_idx;
onnx::NodeProto *node_proto = graph_proto->add_node();
name_end = std::to_string(const_node_idx);
node_proto->add_output(name_end);
node_proto->set_op_type("Constant");
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
attr_proto->set_name("ends");
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
ConvertTupleToTensor(dyn_cast<ValueNode>(end)->value(), attr_proto->mutable_t());
} else {
MS_LOG(EXCEPTION) << "The input end of StridedSlice is not a ValueNode! "
<< "Need to insert op convert variable from tuple to tensor for " << name;
}
auto x_shape = dyn_cast<abstract::Shape>(node->input(1)->Shape());
int size = x_shape->shape().size();
std::vector<int32_t> axes_value;
ValuePtr axes_value_ptr = nullptr;
for (int i = 0; i < size; ++i) {
axes_value.push_back(i);
}
axes_value_ptr = MakeValue<std::vector<int32_t>>(axes_value);
auto axes = NewValueNode(axes_value_ptr)->cast<AnfNodePtr>();
std::string name_axes;
auto const_node_idx_axes = AllocateNodeIndex();
(*node_map_ptr)[axes] = const_node_idx_axes;
onnx::NodeProto *node_proto_axes = graph_proto->add_node();
name_axes = std::to_string(const_node_idx_axes);
node_proto_axes->add_output(name_axes);
node_proto_axes->set_op_type("Constant");
onnx::AttributeProto *attr_proto_axes = node_proto_axes->add_attribute();
attr_proto_axes->set_name("axes");
attr_proto_axes->set_type(onnx::AttributeProto_AttributeType_TENSOR);
ConvertTupleToTensor(dyn_cast<ValueNode>(axes)->value(), attr_proto_axes->mutable_t());
auto strides = node->input(4);
std::string name_strides;
if (strides->isa<ValueNode>()) {
auto const_node_idx = AllocateNodeIndex();
(*node_map_ptr)[strides] = const_node_idx;
onnx::NodeProto *node_proto = graph_proto->add_node();
name_strides = std::to_string(const_node_idx);
node_proto->add_output(name_strides);
node_proto->set_op_type("Constant");
onnx::AttributeProto *attr_proto_steps = node_proto->add_attribute();
attr_proto_steps->set_name("steps");
attr_proto_steps->set_type(onnx::AttributeProto_AttributeType_TENSOR);
ConvertTupleToTensor(dyn_cast<ValueNode>(strides)->value(), attr_proto_steps->mutable_t());
} else {
MS_LOG(EXCEPTION) << "The input strides of StridedSlice is not a ValueNode! "
<< "Need to insert op convert variable from tuple to tensor for " << name;
}
auto node_idx = AllocateNodeIndex();
(*node_map_ptr)[node] = node_idx;
onnx::NodeProto *node_proto = graph_proto->add_node();
node_proto->set_op_type("Slice");
node_proto->add_output(std::to_string(node_idx));
node_proto->add_input(input_data);
node_proto->add_input(name_begin);
node_proto->add_input(name_end);
node_proto->add_input(name_axes);
node_proto->add_input(name_strides);
}
void OnnxExporter::ExportPrimResizeNearestNeighbor(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto *const graph_proto) {
auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
auto x_shape = dyn_cast<abstract::Shape>(node->input(1)->Shape());
AnfNodePtr op = node->input(0);
auto op_value = dyn_cast<ValueNode>(op);
auto prim = dyn_cast<Primitive>(op_value->value());
std::vector<int64_t> resize_size;
auto tuple_ptr = dyn_cast<ValueTuple>(prim->GetAttr("size"));
for (size_t i = 0; i < x_shape->shape().size() - 2; i++) {
resize_size.push_back(x_shape->shape()[i]);
}
for (size_t i = 0; i < tuple_ptr->size(); i++) {
ValuePtr elem = (*tuple_ptr)[i];
resize_size.push_back(dyn_cast<Int64Imm>(elem)->value());
}
auto resize_size_ptr = MakeValue<std::vector<int64_t>>(resize_size);
auto size = NewValueNode(resize_size_ptr)->cast<AnfNodePtr>();
std::string name_size;
auto const_node_idx = AllocateNodeIndex();
(*node_map_ptr)[size] = const_node_idx;
onnx::NodeProto *node_proto_size = graph_proto->add_node();
name_size = std::to_string(const_node_idx);
node_proto_size->add_output(name_size);
node_proto_size->set_op_type("Constant");
onnx::AttributeProto *attr_proto = node_proto_size->add_attribute();
attr_proto->set_name("sizes");
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
ConvertTupleToTensor(resize_size_ptr, attr_proto->mutable_t());
auto node_idx = AllocateNodeIndex();
onnx::TensorProto *roi_initializer_proto = graph_proto->add_initializer();
auto roi_name = std::to_string(node_idx) + "roi_initializer";
roi_initializer_proto->set_name(roi_name);
roi_initializer_proto->set_data_type(GetOnnxDataType(kNumberTypeFloat32));
roi_initializer_proto->add_dims(0);
onnx::TensorProto *scales_initializer_proto = graph_proto->add_initializer();
auto scales_name = std::to_string(node_idx) + "scales_initializer";
scales_initializer_proto->set_name(scales_name);
scales_initializer_proto->set_data_type(GetOnnxDataType(kNumberTypeFloat32));
scales_initializer_proto->add_dims(0);
(*node_map_ptr)[node] = node_idx;
onnx::NodeProto *node_proto = graph_proto->add_node();
node_proto->set_op_type("Resize");
node_proto->add_output(std::to_string(node_idx));
node_proto->add_input(input_data);
node_proto->add_input(roi_name);
node_proto->add_input(scales_name);
node_proto->add_input(name_size);
}
void OnnxExporter::ExportPrimConcat(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
auto node_idx = AllocateNodeIndex();
(*node_map_ptr)[node] = node_idx;
onnx::NodeProto *node_proto = graph_proto->add_node();
AnfNodePtr op = node->input(0);
auto op_value = dyn_cast<ValueNode>(op);
auto prim = dyn_cast<Primitive>(op_value->value());
auto input_node = node->input(1)->cast<CNodePtr>();
if (input_node->IsApply(prim::kPrimMakeTuple)) {
node_proto->set_op_type("ConcatFromSequence");
} else {
node_proto->set_op_type("Concat");
}
// set attr axis
onnx::AttributeProto *onnx_attr_proto = node_proto->add_attribute();
onnx_attr_proto->set_name("axis");
SetAttrValueToProto<Int64Imm>(prim->GetAttr("axis"), onnx::AttributeProto_AttributeType_INT, onnx_attr_proto, prim);
node_proto->add_output(std::to_string(node_idx));
node_proto->add_input(input_data);
}
void OnnxExporter::ExportPrimCast(const FuncGraphPtr &, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
@ -959,6 +1208,19 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n
return ExportPrimReduce(func_graph, node, node_map_ptr, graph_proto);
}
if (node->IsApply(prim::kPrimTranspose)) {
return ExportPrimTranspose(func_graph, node, node_map_ptr, graph_proto);
}
if (node->IsApply(prim::kPrimStridedSlice)) {
return ExportPrimStridedSlice(func_graph, node, node_map_ptr, graph_proto);
}
if (node->IsApply(prim::kPrimResizeNearestNeighbor)) {
return ExportPrimResizeNearestNeighbor(func_graph, node, node_map_ptr, graph_proto);
}
if (node->IsApply(prim::kPrimConcat)) {
return ExportPrimConcat(func_graph, node, node_map_ptr, graph_proto);
}
// MindSpore Cast(x, T) --> ONNX Cast[to=T](x)
if (node->IsApply(prim::kPrimCast)) {
return ExportPrimCast(func_graph, node, node_map_ptr, graph_proto);
@ -1016,9 +1278,7 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n
MS_LOG(EXCEPTION) << "Need to support node op type " << op_value->value()->type_name();
}
if (!IsPrimitiveEquals(prim, prim::kPrimMakeTuple)) {
(*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim, op_inputs, graph_proto);
}
(*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim, op_inputs, graph_proto);
}
size_t OnnxExporter::ExportPrimitive(const FuncGraphPtr &, std::map<AnfNodePtr, size_t> *node_map_ptr,

View File

@ -214,6 +214,7 @@ inline const PrimitivePtr kPrimReverseSequence = std::make_shared<Primitive>("Re
inline const PrimitivePtr kPrimRank = std::make_shared<Primitive>("Rank");
inline const PrimitivePtr kPrimResizeBilinear = std::make_shared<Primitive>("ResizeBilinear");
inline const PrimitivePtr kPrimResizeGrad = std::make_shared<Primitive>("ResizeGrad");
inline const PrimitivePtr kPrimResizeNearestNeighbor = std::make_shared<Primitive>("ResizeNearestNeighbor");
inline const PrimitivePtr kPrimSort = std::make_shared<Primitive>("Sort");
inline const PrimitivePtr kPrimMaskedSelect = std::make_shared<Primitive>("MaskedSelect");