!29623 ONNX converter improvements: part three

Merge pull request !29623 from amalyshev/pr-onnx-converter-part-three
commit d6d7a84f5f
i-robot 2022-02-21 09:25:51 +00:00, committed by Gitee
1 changed file with 214 additions and 108 deletions


@@ -46,6 +46,7 @@ enum OpMergeMode {
OP_MERGE_BATCH_NORM = 4, // indicate `MindSpore BatchNorm(x)[0]` --> `ONNX Batch Normalization`
OP_MERGE_MAXPOOL_WITH_ARGMAX = 5, // indicate `MindSpore MaxPoolWithArgmax(x)[0]` --> `ONNX MaxPool`
OP_MERGE_LAYER_NORM = 6, // indicate `MindSpore LayerNorm(x)[0]` --> `ONNX MeanVarianceNormalization`
OP_MERGE_CONV2D_TRANSPOSE = 7, // indicate `MindSpore ConvTranspose + BiasAdd` --> `ONNX ConvTranspose`
};
struct OpMergedInfo {
@@ -408,6 +409,59 @@ void AddReduceOp(const std::string &op_type, const std::string &input, const std
}
}
void AddMeanVarianceNormalizationOp(const std::string &input, const std::string &gamma, const std::string &beta,
const std::string &output, const std::vector<int64_t> &axes, float epsilon,
const std::vector<int64_t> &input_shape, onnx::TensorProto_DataType input_type,
onnx::GraphProto *graph_proto) {
auto input_name = output + "_input";
AddCastOp(input, input_name, onnx::TensorProto_DataType_FLOAT, graph_proto);
auto gamma_name = output + "_gamma";
AddCastOp(gamma, gamma_name, onnx::TensorProto_DataType_FLOAT, graph_proto);
auto beta_name = output + "_beta";
AddCastOp(beta, beta_name, onnx::TensorProto_DataType_FLOAT, graph_proto);
// MeanVarianceNormalization is replaced with equivalent ops because it is not supported by CUDAExecutionProvider
auto meanvariancenormal_node_name = output + "_normalized";
auto mean_name = output + "_mean";
AddReduceOp("ReduceMean", input_name, mean_name, axes, true, graph_proto);
auto centered_name = output + "_centered";
AddOp("Sub", {input_name, mean_name}, {centered_name}, graph_proto);
auto sqsum_name = output + "_sqsum";
AddReduceOp("ReduceSumSquare", centered_name, sqsum_name, axes, true, graph_proto);
float reduce_size = std::accumulate(axes.begin(), axes.end(), 1.0f,
[&input_shape](auto acc, auto axis) { return acc * input_shape[axis]; });
auto reduce_size_name = output + "_reduce_size";
AddFloatScalarInitializer(reduce_size_name, reduce_size, onnx::TensorProto_DataType_FLOAT, graph_proto);
auto variance_name = output + "_variance";
AddOp("Div", {sqsum_name, reduce_size_name}, {variance_name}, graph_proto);
auto epsilon_name = output + "_epsilon";
AddFloatScalarInitializer(epsilon_name, epsilon, onnx::TensorProto_DataType_FLOAT, graph_proto);
auto variance_with_epsilon_name = output + "_variance_with_epsilon";
AddOp("Add", {variance_name, epsilon_name}, {variance_with_epsilon_name}, graph_proto);
auto std_name = output + "_std";
AddOp("Sqrt", {variance_with_epsilon_name}, {std_name}, graph_proto);
AddOp("Div", {centered_name, std_name}, {meanvariancenormal_node_name}, graph_proto);
// Add mul and add node
auto mul_node_name = output + "_rescaled";
AddOp("Mul", {meanvariancenormal_node_name, gamma_name}, {mul_node_name}, graph_proto);
// add beta
auto add_node_name = output;
if (input_type == onnx::TensorProto_DataType_FLOAT16) {
add_node_name += "_shifted";
}
AddOp("Add", {mul_node_name, beta_name}, {add_node_name}, graph_proto);
if (input_type == onnx::TensorProto_DataType_FLOAT16) {
AddCastOp(add_node_name, output, onnx::TensorProto_DataType_FLOAT16, graph_proto);
}
}
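Note: the chain of ops above reproduces `y = gamma * (x - mean) / sqrt(var + epsilon) + beta`, where `var` is the biased variance `sqsum / reduce_size`. A minimal scalar sketch of the same arithmetic over one reduction slice (hypothetical helper, not exporter code):

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// One element of the normalization the emitted subgraph computes.
float NormalizeOne(const std::vector<float> &x, std::size_t i, float gamma, float beta,
                   float epsilon) {
  float mean = 0.0f;
  for (float v : x) mean += v;
  mean /= x.size();  // ReduceMean
  float sqsum = 0.0f;
  for (float v : x) sqsum += (v - mean) * (v - mean);  // Sub + ReduceSumSquare
  float variance = sqsum / x.size();                   // Div by reduce_size
  float std_dev = std::sqrt(variance + epsilon);       // Add epsilon, Sqrt
  return gamma * (x[i] - mean) / std_dev + beta;       // Div, Mul, Add
}
```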
void AddConcatOp(const std::vector<std::string> &inputs, const std::string &output, int axis,
onnx::GraphProto *graph_proto) {
onnx::NodeProto *concat_proto = graph_proto->add_node();
@@ -570,9 +624,6 @@ OPERATOR_ONNX_CONVERT_DEFINE(ReLU, Relu, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Sigmoid, Sigmoid, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Flatten, Flatten, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(
Conv2D, Conv,
@@ -714,7 +765,6 @@ void RegisterOpConverters(const std::function<void(OpNameInfo &&)> &fn) {
fn(OP_CONVERT_FUNCTION_NAME(MaxPoolWithArgmax)());
fn(OP_CONVERT_FUNCTION_NAME(AvgPool)());
fn(OP_CONVERT_FUNCTION_NAME(BatchNorm)());
fn(OP_CONVERT_FUNCTION_NAME(MatMul)());
fn(OP_CONVERT_FUNCTION_NAME(MakeTuple)());
@@ -854,8 +904,14 @@ class OnnxExporter {
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportPrimOneHot(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
void PrimConv2DTransposeExportHelper(const CNodePtr &conv_node, const CNodePtr &bias_add_node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto);
void ExportPrimConv2DTranspose(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportPrimGreaterEqual(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportPrimSqueeze(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto *graph_proto);
void ExportMergeGemm(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
@@ -866,6 +922,8 @@ class OnnxExporter {
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportMergeLayerNorm(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportMergeConv2DTranspose(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportOutput(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto *graph_proto);
@@ -1047,6 +1105,7 @@ void OnnxExporter::MatchAndMarkCNode(const FuncGraphPtr &func_graph, const CNode
const std::vector<MergeRule> first_input_merge_rules = {
{prim::kPrimBiasAdd, prim::kPrimConv2D, OP_MERGE_CONV},
{prim::kPrimBiasAdd, prim::kPrimConv2DTranspose, OP_MERGE_CONV2D_TRANSPOSE},
{prim::kPrimBiasAdd, prim::kPrimConv3D, OP_MERGE_CONV},
{prim::kPrimBiasAdd, prim::kPrimConv3DTranspose, OP_MERGE_CONV},
{prim::kPrimBiasAdd, prim::kPrimMatMul, OP_MERGE_GEMM},
@@ -1143,6 +1202,9 @@ void OnnxExporter::ExportNodes(const FuncGraphPtr &func_graph, std::map<AnfNodeP
case OP_MERGE_LAYER_NORM:
ExportMergeLayerNorm(func_graph, cnode, node_map_ptr, graph_proto);
break;
case OP_MERGE_CONV2D_TRANSPOSE:
ExportMergeConv2DTranspose(func_graph, cnode, node_map_ptr, graph_proto);
break;
default:
ExportCNode(func_graph, cnode, node_map_ptr, graph_proto);
break;
@@ -2362,6 +2424,73 @@ void OnnxExporter::ExportPrimOneHot(const FuncGraphPtr &, const CNodePtr &node,
one_hot_axis_attr_proto->set_i(axis);
}
/*
Based on nn.Conv2dTranspose
Warning: `output_shape` is an input in MS and an attribute in ONNX. Hence
it is not possible to change the output shape at runtime
*/
void OnnxExporter::PrimConv2DTransposeExportHelper(const CNodePtr &conv_node, const CNodePtr &bias_add_node,
std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto *const graph_proto) {
auto node_idx = AllocateNodeIndex();
std::vector<AnfNodePtr> inputs{conv_node->input(kOneNum), conv_node->input(kTwoNum)};
if (bias_add_node != nullptr) {
inputs.push_back(bias_add_node->input(kTwoNum));
(*node_map_ptr)[bias_add_node] = node_idx;
} else {
(*node_map_ptr)[conv_node] = node_idx;
}
onnx::NodeProto *node_proto = graph_proto->add_node();
node_proto->set_op_type("ConvTranspose");
for (const auto &input : inputs) {
node_proto->add_input(GetNodeInputName(input, node_map_ptr, graph_proto));
}
node_proto->add_output(std::to_string(node_idx));
auto prim = GetPrimitive(conv_node);
auto attrs_convert_info =
OpNameInfo()
.Attr("dilation", "dilations", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<kTwoNum>)
.Attr("group", "group", onnx::AttributeProto_AttributeType_INT, SetAttrValueToProto<Int64Imm>)
.Attr("kernel_size", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<0>)
.Attr("pad_mode", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetConvTransposePadding)
.Attr("stride", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<kTwoNum>);
for (const auto &attr_info : attrs_convert_info.op_attrs()) {
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
attr_proto->set_name(attr_info.onnx_attr_name());
auto ms_attr = GetOpAttributePtr<Value>(conv_node, attr_info.attr_name());
MS_EXCEPTION_IF_NULL(ms_attr);
attr_info.fn_gen_attr()(ms_attr, attr_info.onnx_attr_type(), attr_proto, prim);
}
// Set output shape
auto input_shape_node = GetRealInput(conv_node->input(kThreeNum));
if (!input_shape_node->isa<ValueNode>()) {
MS_LOG(EXCEPTION) << "For ONNX export third argument must be constant "
"(Python tuple). Instead got "
<< input_shape_node->ToString();
}
auto input_shape_value_ptr = input_shape_node->cast<ValueNodePtr>()->value();
if (!input_shape_value_ptr->isa<ValueTuple>()) {
MS_LOG(EXCEPTION) << "Expected ValueTuple, got " << input_shape_value_ptr->ToString() << " of type "
<< input_shape_value_ptr->type()->ToString();
}
onnx::AttributeProto *output_shape_attr_proto = node_proto->add_attribute();
output_shape_attr_proto->set_name("output_shape");
SetAttrTupleValueToProto<0>(input_shape_value_ptr, onnx::AttributeProto_AttributeType_INTS, output_shape_attr_proto,
prim);
}
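Note: since `output_shape` becomes an INTS attribute rather than a graph input, the spatial output size is frozen at export time. A sketch of the node this helper ends up emitting for a hypothetical case (input/output names, dimensions, and the protobuf include path are all assumptions):

```cpp
#include <cstdint>
#include "proto/onnx.pb.h"  // assumed location of the generated ONNX protobufs

void AddExampleConvTranspose(onnx::GraphProto *graph_proto) {
  onnx::NodeProto *node_proto = graph_proto->add_node();
  node_proto->set_op_type("ConvTranspose");
  node_proto->add_input("x");       // conv_node input 1 (feature map)
  node_proto->add_input("weight");  // conv_node input 2 (kernel)
  node_proto->add_output("y");
  onnx::AttributeProto *attr_proto = node_proto->add_attribute();
  attr_proto->set_name("output_shape");
  attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS);
  for (int64_t dim : {1, 16, 64, 64}) {  // NCHW shape frozen at export time
    attr_proto->add_ints(dim);
  }
}
```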
void OnnxExporter::ExportPrimConv2DTranspose(const FuncGraphPtr &, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto *graph_proto) {
PrimConv2DTransposeExportHelper(node, nullptr, node_map_ptr, graph_proto);
}
void OnnxExporter::ExportPrimGreaterEqual(const FuncGraphPtr &, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto *const graph_proto) {
@@ -2377,6 +2506,30 @@ void OnnxExporter::ExportPrimGreaterEqual(const FuncGraphPtr &, const CNodePtr &
AddOp("Not", {less_name}, {node_name}, graph_proto);
}
void OnnxExporter::ExportPrimSqueeze(const FuncGraphPtr &, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
auto node_idx = AllocateNodeIndex();
(*node_map_ptr)[node] = node_idx;
auto input_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
onnx::NodeProto *node_proto = graph_proto->add_node();
node_proto->set_op_type("Squeeze");
node_proto->add_input(input_name);
node_proto->add_output(std::to_string(node_idx));
auto axes = GetOpAttributePtr<ValueSequence>(node, "axis");
auto axes_value = GetValue<std::vector<int64_t>>(axes);
if (!axes_value.empty()) {
onnx::AttributeProto *axes_proto = node_proto->add_attribute();
axes_proto->set_name("axes");
axes_proto->set_type(onnx::AttributeProto_AttributeType_INTS);
for (auto axis : axes_value) {
axes_proto->add_ints(axis);
}
}
}
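Note: emitting `axes` as an attribute matches ONNX Squeeze up to opset 12 (opset 13 moved `axes` to an input). For reference, a behavioural sketch of the semantics the node relies on (hypothetical helper, not exporter code):

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Dims listed in `axes` are removed; with no axes attribute, every dim of
// size 1 is removed. E.g. [1, 3, 1, 5] with axes {0, 2} (or empty) -> [3, 5].
std::vector<int64_t> SqueezedShape(const std::vector<int64_t> &shape,
                                   const std::vector<int64_t> &axes) {
  std::vector<int64_t> result;
  for (int64_t i = 0; i < static_cast<int64_t>(shape.size()); ++i) {
    bool listed = std::find(axes.begin(), axes.end(), i) != axes.end();
    bool drop = axes.empty() ? shape[i] == 1 : listed;
    if (!drop) {
      result.push_back(shape[i]);
    }
  }
  return result;
}
```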
void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
using ExportFunc = std::function<void(OnnxExporter *, const FuncGraphPtr &, const CNodePtr &,
@@ -2407,7 +2560,9 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n
{prim::kPrimOnesLike, &OnnxExporter::ExportPrimOnesLike},
{prim::kPrimArgMaxWithValue, &OnnxExporter::ExportPrimArgMaxWithValue},
{prim::kPrimOneHot, &OnnxExporter::ExportPrimOneHot},
{prim::kPrimConv2DTranspose, &OnnxExporter::ExportPrimConv2DTranspose},
{prim::kPrimGreaterEqual, &OnnxExporter::ExportPrimGreaterEqual},
{prim::kPrimSqueeze, &OnnxExporter::ExportPrimSqueeze},
{prim::kPrimExpandDims, &OnnxExporter::ExportPrimExpandDims},
{prim::kPrimBatchMatMul, &OnnxExporter::ExportPrimBatchMatMul},
{prim::kPrimGeLU, &OnnxExporter::ExportPrimGeLU},
@@ -2613,12 +2768,44 @@ void OnnxExporter::ExportMergeBatchNorm(const FuncGraphPtr &func_graph, const CN
onnx::GraphProto *const graph_proto) {
auto batch_norm_node = dyn_cast<CNode>(node->input(kOneNum));
auto is_training = GetOpAttribute<bool>(batch_norm_node, "is_training");
if (is_training) {
auto input_x_name = GetNodeInputName(batch_norm_node->input(kOneNum), node_map_ptr, graph_proto);
auto scale_input_name = GetNodeInputName(batch_norm_node->input(kTwoNum), node_map_ptr, graph_proto);
auto bias_input_name = GetNodeInputName(batch_norm_node->input(kThreeNum), node_map_ptr, graph_proto);
auto onnx_type = GetOutputType(batch_norm_node->input(kOneNum));
auto output_index = AllocateNodeIndex();
auto output_name = std::to_string(output_index);
(*node_map_ptr)[node] = output_index;
auto input_shape_ptr = batch_norm_node->input(kOneNum)->Shape();
auto input_shape = input_shape_ptr->cast<abstract::ShapePtr>()->shape();
std::vector<int64_t> normalize_axes = {0};
for (size_t i = kTwoNum; i < input_shape.size(); ++i) {
normalize_axes.push_back(static_cast<int64_t>(i));
}
std::vector<int64_t> scale_bias_shape(input_shape.size(), 1);
scale_bias_shape[1] = -1;
auto reshaped_scale_name = output_name + "_reshaped_scale";
AddReshapeOp(scale_input_name, reshaped_scale_name, scale_bias_shape, graph_proto);
auto reshaped_bias_name = output_name + "_reshaped_bias";
AddReshapeOp(bias_input_name, reshaped_bias_name, scale_bias_shape, graph_proto);
auto epsilon = GetOpAttribute<float>(batch_norm_node, "epsilon");
AddMeanVarianceNormalizationOp(input_x_name, reshaped_scale_name, reshaped_bias_name, output_name, normalize_axes,
epsilon, input_shape, onnx_type, graph_proto);
} else {
PrimitivePtr prim_batch_norm = GetPrimitive(batch_norm_node);
std::vector<AnfNodePtr> inputs;
for (size_t i = 1; i < batch_norm_node->inputs().size(); i++) {
inputs.push_back(batch_norm_node->input(i));
}
(*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_batch_norm, inputs, graph_proto);
}
}
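Note: in training mode the batch statistics are computed explicitly through the MeanVarianceNormalization subgraph instead of ONNX BatchNormalization. The axis/shape bookkeeping above determines what gets normalized; a sketch of it in isolation (hypothetical helper, NCHW layout assumed):

```cpp
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Statistics are taken over every axis except the channel axis 1, and
// scale/bias are reshaped to broadcast over channels. For rank 4 (NCHW):
// normalize_axes = {0, 2, 3}, scale_bias_shape = {1, -1, 1, 1}.
std::pair<std::vector<int64_t>, std::vector<int64_t>> BatchNormAxesAndShape(std::size_t rank) {
  std::vector<int64_t> normalize_axes = {0};
  for (std::size_t i = 2; i < rank; ++i) {
    normalize_axes.push_back(static_cast<int64_t>(i));
  }
  std::vector<int64_t> scale_bias_shape(rank, 1);
  scale_bias_shape[1] = -1;
  return {normalize_axes, scale_bias_shape};
}
```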
void OnnxExporter::ExportMergeMaxPoolWithArgmax(const FuncGraphPtr &func_graph, const CNodePtr &node,
@@ -2644,109 +2831,28 @@ void OnnxExporter::ExportMergeLayerNorm(const FuncGraphPtr &, const CNodePtr &no
auto layernorm_input_gamma = GetNodeInputName(LayerNormNode->input(kTwoNum), node_map_ptr, graph_proto);
auto layernorm_input_beta = GetNodeInputName(LayerNormNode->input(kThreeNum), node_map_ptr, graph_proto);
auto layernorm_input_x = GetNodeInputName(LayerNormNode->input(kOneNum), node_map_ptr, graph_proto);
auto begin_norm_axis = GetOpAttribute<int64_t>(LayerNormNode, "begin_norm_axis");
auto begin_params_axis = GetOpAttribute<int64_t>(LayerNormNode, "begin_params_axis");
if (begin_norm_axis != -1 || begin_params_axis != -1) {
MS_LOG(EXCEPTION) << "begin_norm_axis != -1 and begin_params_axis != -1 are not implemented";
}
auto onnx_type = GetOutputType(LayerNormNode->input(kOneNum));
auto input_shape = dyn_cast<abstract::Shape>(LayerNormNode->input(kOneNum)->Shape())->shape();
auto node_idx = AllocateNodeIndex();
(*node_map_ptr)[node] = node_idx;
auto epsilon = GetOpAttribute<float>(LayerNormNode, "epsilon");
std::vector<int64_t> reduce_axes = {static_cast<int64_t>(input_shape.size()) - 1};
AddMeanVarianceNormalizationOp(layernorm_input_x, layernorm_input_gamma, layernorm_input_beta,
std::to_string(node_idx), reduce_axes, epsilon, input_shape, onnx_type, graph_proto);
}
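Note: the begin_norm_axis/begin_params_axis check above restricts the export to LayerNorm over the last dimension, which keeps the reduction list trivial. A worked sketch (shapes hypothetical):

```cpp
#include <cstdint>
#include <vector>

// With begin_norm_axis == -1 the only reduction axis is the last one.
// For x of shape [2, 128, 768]: reduce_axes = {2}, and gamma/beta of shape
// [768] broadcast over the leading dimensions inside the MVN subgraph.
std::vector<int64_t> LayerNormReduceAxes(const std::vector<int64_t> &input_shape) {
  return {static_cast<int64_t>(input_shape.size()) - 1};
}
```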
void OnnxExporter::ExportMergeConv2DTranspose(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto *const graph_proto) {
auto conv_node = dyn_cast<CNode>(node->input(kOneNum));
PrimConv2DTransposeExportHelper(conv_node, node, node_map_ptr, graph_proto);
}
/*