forked from mindspore-Ecosystem/mindspore
!29623 ONNX converter improvements: part three
Merge pull request !29623 from amalyshev/pr-onnx-converter-part-three
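Part three adds ONNX export support for Conv2DTranspose (including merging ConvTranspose + BiasAdd into a single ONNX ConvTranspose node), replaces the declarative Squeeze converter with a dedicated exporter, and reworks training-mode BatchNorm and LayerNorm export on top of a shared MeanVarianceNormalization decomposition built from primitive ops (MeanVarianceNormalization itself is not supported by CUDAExecutionProvider).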
commit d6d7a84f5f
@@ -46,6 +46,7 @@ enum OpMergeMode {
   OP_MERGE_BATCH_NORM = 4,           // indicate `MindSpore BatchNorm(x)[0]` --> `ONNX Batch Normalization`
   OP_MERGE_MAXPOOL_WITH_ARGMAX = 5,  // indicate `MindSpore MaxPoolWithArgmax(x)[0]` --> `ONNX MaxPool`
   OP_MERGE_LAYER_NORM = 6,           // indicate `MindSpore LayerNorm(x)[0]` --> `ONNX MeanVarianceNormalization`
+  OP_MERGE_CONV2D_TRANSPOSE = 7,     // indicate `MindSpore ConvTranspose + BiasAdd` --> `ONNX ConvTranspose`
 };

 struct OpMergedInfo {
@@ -408,6 +409,59 @@ void AddReduceOp(const std::string &op_type, const std::string &input, const std
   }
 }

+void AddMeanVarianceNormalizationOp(const std::string &input, const std::string &gamma, const std::string &beta,
+                                    const std::string &output, const std::vector<int64_t> &axes, float epsilon,
+                                    const std::vector<int64_t> &input_shape, onnx::TensorProto_DataType input_type,
+                                    onnx::GraphProto *graph_proto) {
+  auto input_name = output + "_input";
+  AddCastOp(input, input_name, onnx::TensorProto_DataType_FLOAT, graph_proto);
+  auto gamma_name = output + "_gamma";
+  AddCastOp(gamma, gamma_name, onnx::TensorProto_DataType_FLOAT, graph_proto);
+  auto beta_name = output + "_beta";
+  AddCastOp(beta, beta_name, onnx::TensorProto_DataType_FLOAT, graph_proto);
+
+  // MeanVarianceNormalization is replaced with equivalent ops because it is not supported by CUDAExecutionProvider
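+  // Built from the primitive ops below: with N = the number of elements spanned by `axes`,
+  //   y = gamma * (x - Mean(x, axes)) / Sqrt(SumSquare(x - Mean(x, axes), axes) / N + epsilon) + beta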
+  auto meanvariancenormal_node_name = output + "_normalized";
+
+  auto mean_name = output + "_mean";
+  AddReduceOp("ReduceMean", input_name, mean_name, axes, true, graph_proto);
+  auto centered_name = output + "_centered";
+  AddOp("Sub", {input_name, mean_name}, {centered_name}, graph_proto);
+
+  auto sqsum_name = output + "_sqsum";
+  AddReduceOp("ReduceSumSquare", centered_name, sqsum_name, axes, true, graph_proto);
+  float reduce_size = std::accumulate(axes.begin(), axes.end(), 1.0f,
+                                      [&input_shape](auto acc, auto axis) { return acc * input_shape[axis]; });
+  auto reduce_size_name = output + "_reduce_size";
+  AddFloatScalarInitializer(reduce_size_name, reduce_size, onnx::TensorProto_DataType_FLOAT, graph_proto);
+  auto variance_name = output + "_variance";
+  AddOp("Div", {sqsum_name, reduce_size_name}, {variance_name}, graph_proto);
+
+  auto epsilon_name = output + "_epsilon";
+  AddFloatScalarInitializer(epsilon_name, epsilon, onnx::TensorProto_DataType_FLOAT, graph_proto);
+  auto variance_with_epsilon_name = output + "_variance_with_epsilon";
+  AddOp("Add", {variance_name, epsilon_name}, {variance_with_epsilon_name}, graph_proto);
+  auto std_name = output + "_std";
+  AddOp("Sqrt", {variance_with_epsilon_name}, {std_name}, graph_proto);
+
+  AddOp("Div", {centered_name, std_name}, {meanvariancenormal_node_name}, graph_proto);
+
+  // Add mul and add node
+  auto mul_node_name = output + "_rescaled";
+  AddOp("Mul", {meanvariancenormal_node_name, gamma_name}, {mul_node_name}, graph_proto);
+
+  // add beta
+  auto add_node_name = output;
+  if (input_type == onnx::TensorProto_DataType_FLOAT16) {
+    add_node_name += "_shifted";
+  }
+  AddOp("Add", {mul_node_name, beta_name}, {add_node_name}, graph_proto);
+
+  if (input_type == onnx::TensorProto_DataType_FLOAT16) {
+    AddCastOp(add_node_name, output, onnx::TensorProto_DataType_FLOAT16, graph_proto);
+  }
+}
+
 void AddConcatOp(const std::vector<std::string> &inputs, const std::string &output, int axis,
                  onnx::GraphProto *graph_proto) {
   onnx::NodeProto *concat_proto = graph_proto->add_node();
@@ -570,9 +624,6 @@ OPERATOR_ONNX_CONVERT_DEFINE(ReLU, Relu, OpNameInfo())
 OPERATOR_ONNX_CONVERT_DEFINE(Sigmoid, Sigmoid, OpNameInfo())

 OPERATOR_ONNX_CONVERT_DEFINE(Flatten, Flatten, OpNameInfo())
-OPERATOR_ONNX_CONVERT_DEFINE(Squeeze, Squeeze,
-                             OpNameInfo().Attr("axis", "axes", onnx::AttributeProto_AttributeType_INTS,
-                                               SetAttrTupleValueToProto<0>))

 OPERATOR_ONNX_CONVERT_DEFINE(
   Conv2D, Conv,
@@ -714,7 +765,6 @@ void RegisterOpConverters(const std::function<void(OpNameInfo &&)> &fn) {
   fn(OP_CONVERT_FUNCTION_NAME(MaxPoolWithArgmax)());
   fn(OP_CONVERT_FUNCTION_NAME(AvgPool)());

-  fn(OP_CONVERT_FUNCTION_NAME(Squeeze)());
   fn(OP_CONVERT_FUNCTION_NAME(BatchNorm)());
   fn(OP_CONVERT_FUNCTION_NAME(MatMul)());
   fn(OP_CONVERT_FUNCTION_NAME(MakeTuple)());
@@ -854,8 +904,14 @@ class OnnxExporter {
                         std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
   void ExportPrimOneHot(const FuncGraphPtr &func_graph, const CNodePtr &node,
                         std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
+  void PrimConv2DTransposeExportHelper(const CNodePtr &conv_node, const CNodePtr &bias_add_node,
+                                       std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto);
+  void ExportPrimConv2DTranspose(const FuncGraphPtr &func_graph, const CNodePtr &node,
+                                 std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
   void ExportPrimGreaterEqual(const FuncGraphPtr &func_graph, const CNodePtr &node,
                               std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
+  void ExportPrimSqueeze(const FuncGraphPtr &func_graph, const CNodePtr &node,
+                         std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
   void ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
                        onnx::GraphProto *graph_proto);
   void ExportMergeGemm(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
@@ -866,6 +922,8 @@ class OnnxExporter {
                             std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
   void ExportMergeLayerNorm(const FuncGraphPtr &func_graph, const CNodePtr &node,
                             std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
+  void ExportMergeConv2DTranspose(const FuncGraphPtr &func_graph, const CNodePtr &node,
+                                  std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);

   void ExportOutput(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
                     onnx::GraphProto *graph_proto);
@@ -1047,6 +1105,7 @@ void OnnxExporter::MatchAndMarkCNode(const FuncGraphPtr &func_graph, const CNode

   const std::vector<MergeRule> first_input_merge_rules = {
     {prim::kPrimBiasAdd, prim::kPrimConv2D, OP_MERGE_CONV},
+    {prim::kPrimBiasAdd, prim::kPrimConv2DTranspose, OP_MERGE_CONV2D_TRANSPOSE},
     {prim::kPrimBiasAdd, prim::kPrimConv3D, OP_MERGE_CONV},
     {prim::kPrimBiasAdd, prim::kPrimConv3DTranspose, OP_MERGE_CONV},
     {prim::kPrimBiasAdd, prim::kPrimMatMul, OP_MERGE_GEMM},
@@ -1143,6 +1202,9 @@ void OnnxExporter::ExportNodes(const FuncGraphPtr &func_graph, std::map<AnfNodeP
       case OP_MERGE_LAYER_NORM:
         ExportMergeLayerNorm(func_graph, cnode, node_map_ptr, graph_proto);
         break;
+      case OP_MERGE_CONV2D_TRANSPOSE:
+        ExportMergeConv2DTranspose(func_graph, cnode, node_map_ptr, graph_proto);
+        break;
       default:
         ExportCNode(func_graph, cnode, node_map_ptr, graph_proto);
         break;
@@ -2362,6 +2424,73 @@ void OnnxExporter::ExportPrimOneHot(const FuncGraphPtr &, const CNodePtr &node,
   one_hot_axis_attr_proto->set_i(axis);
 }

+/*
+  Based on nn.Conv2dTranspose
+  Warning: `output_shape` is an input in MindSpore and an attribute in ONNX. Hence
+  it is not possible to change the output shape at runtime
+*/
+void OnnxExporter::PrimConv2DTransposeExportHelper(const CNodePtr &conv_node, const CNodePtr &bias_add_node,
+                                                   std::map<AnfNodePtr, size_t> *node_map_ptr,
+                                                   onnx::GraphProto *const graph_proto) {
+  auto node_idx = AllocateNodeIndex();
+
+  std::vector<AnfNodePtr> inputs{conv_node->input(kOneNum), conv_node->input(kTwoNum)};
+  if (bias_add_node != nullptr) {
+    inputs.push_back(bias_add_node->input(kTwoNum));
+    (*node_map_ptr)[bias_add_node] = node_idx;
+  } else {
+    (*node_map_ptr)[conv_node] = node_idx;
+  }
+
+  onnx::NodeProto *node_proto = graph_proto->add_node();
+  node_proto->set_op_type("ConvTranspose");
+  for (const auto &input : inputs) {
+    node_proto->add_input(GetNodeInputName(input, node_map_ptr, graph_proto));
+  }
+  node_proto->add_output(std::to_string(node_idx));
+
+  auto prim = GetPrimitive(conv_node);
+  auto attrs_convert_info =
+    OpNameInfo()
+      .Attr("dilation", "dilations", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<kTwoNum>)
+      .Attr("group", "group", onnx::AttributeProto_AttributeType_INT, SetAttrValueToProto<Int64Imm>)
+      .Attr("kernel_size", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<0>)
+      .Attr("pad_mode", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetConvTransposePadding)
+      .Attr("stride", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<kTwoNum>);
+  for (const auto &attr_info : attrs_convert_info.op_attrs()) {
+    onnx::AttributeProto *attr_proto = node_proto->add_attribute();
+    attr_proto->set_name(attr_info.onnx_attr_name());
+    auto ms_attr = GetOpAttributePtr<Value>(conv_node, attr_info.attr_name());
+    MS_EXCEPTION_IF_NULL(ms_attr);
+    attr_info.fn_gen_attr()(ms_attr, attr_info.onnx_attr_type(), attr_proto, prim);
+  }
+
+  // Set output shape
+
+  auto input_shape_node = GetRealInput(conv_node->input(kThreeNum));
+  if (!input_shape_node->isa<ValueNode>()) {
+    MS_LOG(EXCEPTION) << "For ONNX export, the third argument must be a constant "
+                         "(Python tuple). Instead got "
+                      << input_shape_node->ToString();
+  }
+  auto input_shape_value_ptr = input_shape_node->cast<ValueNodePtr>()->value();
+  if (!input_shape_value_ptr->isa<ValueTuple>()) {
+    MS_LOG(EXCEPTION) << "Expected ValueTuple, got " << input_shape_value_ptr->ToString() << " of type "
+                      << input_shape_value_ptr->type()->ToString();
+  }
+
+  onnx::AttributeProto *output_shape_attr_proto = node_proto->add_attribute();
+  output_shape_attr_proto->set_name("output_shape");
+  SetAttrTupleValueToProto<0>(input_shape_value_ptr, onnx::AttributeProto_AttributeType_INTS, output_shape_attr_proto,
+                              prim);
+}
+
+void OnnxExporter::ExportPrimConv2DTranspose(const FuncGraphPtr &, const CNodePtr &node,
+                                             std::map<AnfNodePtr, size_t> *node_map_ptr,
+                                             onnx::GraphProto *graph_proto) {
+  PrimConv2DTransposeExportHelper(node, nullptr, node_map_ptr, graph_proto);
+}
+
 void OnnxExporter::ExportPrimGreaterEqual(const FuncGraphPtr &, const CNodePtr &node,
                                           std::map<AnfNodePtr, size_t> *node_map_ptr,
                                           onnx::GraphProto *const graph_proto) {
@@ -2377,6 +2506,30 @@ void OnnxExporter::ExportPrimGreaterEqual(const FuncGraphPtr &, const CNodePtr &
   AddOp("Not", {less_name}, {node_name}, graph_proto);
 }

+void OnnxExporter::ExportPrimSqueeze(const FuncGraphPtr &, const CNodePtr &node,
+                                     std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
+  auto node_idx = AllocateNodeIndex();
+  (*node_map_ptr)[node] = node_idx;
+
+  auto input_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
+
+  onnx::NodeProto *node_proto = graph_proto->add_node();
+  node_proto->set_op_type("Squeeze");
+  node_proto->add_input(input_name);
+  node_proto->add_output(std::to_string(node_idx));
+
+  auto axes = GetOpAttributePtr<ValueSequence>(node, "axis");
+  auto axes_value = GetValue<std::vector<int64_t>>(axes);
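+  // An empty `axis` means "squeeze all dimensions of size 1"; ONNX Squeeze defaults to the same
+  // behavior when the `axes` attribute is omitted, so the attribute is only set when `axis` is non-empty.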
+  if (!axes_value.empty()) {
+    onnx::AttributeProto *axes_proto = node_proto->add_attribute();
+    axes_proto->set_name("axes");
+    axes_proto->set_type(onnx::AttributeProto_AttributeType_INTS);
+    for (auto axis : axes_value) {
+      axes_proto->add_ints(axis);
+    }
+  }
+}
+
 void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node,
                                std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
   using ExportFunc = std::function<void(OnnxExporter *, const FuncGraphPtr &, const CNodePtr &,
@@ -2407,7 +2560,9 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n
     {prim::kPrimOnesLike, &OnnxExporter::ExportPrimOnesLike},
     {prim::kPrimArgMaxWithValue, &OnnxExporter::ExportPrimArgMaxWithValue},
     {prim::kPrimOneHot, &OnnxExporter::ExportPrimOneHot},
+    {prim::kPrimConv2DTranspose, &OnnxExporter::ExportPrimConv2DTranspose},
     {prim::kPrimGreaterEqual, &OnnxExporter::ExportPrimGreaterEqual},
+    {prim::kPrimSqueeze, &OnnxExporter::ExportPrimSqueeze},
     {prim::kPrimExpandDims, &OnnxExporter::ExportPrimExpandDims},
     {prim::kPrimBatchMatMul, &OnnxExporter::ExportPrimBatchMatMul},
     {prim::kPrimGeLU, &OnnxExporter::ExportPrimGeLU},
@@ -2613,12 +2768,44 @@ void OnnxExporter::ExportMergeBatchNorm(const FuncGraphPtr &func_graph, const CN
                                         onnx::GraphProto *const graph_proto) {
   auto batch_norm_node = dyn_cast<CNode>(node->input(kOneNum));

-  PrimitivePtr prim_batch_norm = dyn_cast<Primitive>((dyn_cast<ValueNode>(batch_norm_node->input(kZeroNum)))->value());
-  std::vector<AnfNodePtr> inputs;
-  for (size_t i = 1; i < batch_norm_node->inputs().size(); i++) {
-    inputs.push_back(batch_norm_node->input(i));
-  }
-  (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_batch_norm, inputs, graph_proto);
+  auto is_training = GetOpAttribute<bool>(batch_norm_node, "is_training");
+  if (is_training) {
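+    // In training mode the statistics are computed over the batch and spatial axes (all but
+    // channel), so BatchNorm is expressed via the MeanVarianceNormalization decomposition above,
+    // with the per-channel scale and bias reshaped to {1, C, 1, ...} for broadcasting.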
+    auto input_x_name = GetNodeInputName(batch_norm_node->input(kOneNum), node_map_ptr, graph_proto);
+    auto scale_input_name = GetNodeInputName(batch_norm_node->input(kTwoNum), node_map_ptr, graph_proto);
+    auto bias_input_name = GetNodeInputName(batch_norm_node->input(kThreeNum), node_map_ptr, graph_proto);
+
+    auto onnx_type = GetOutputType(batch_norm_node->input(kOneNum));
+
+    auto output_index = AllocateNodeIndex();
+    auto output_name = std::to_string(output_index);
+    (*node_map_ptr)[node] = output_index;
+
+    auto input_shape_ptr = batch_norm_node->input(kOneNum)->Shape();
+    auto input_shape = input_shape_ptr->cast<abstract::ShapePtr>()->shape();
+
+    std::vector<int64_t> normalize_axes = {0};
+    for (size_t i = kTwoNum; i < input_shape.size(); ++i) {
+      normalize_axes.push_back(static_cast<int64_t>(i));
+    }
+
+    std::vector<int64_t> scale_bias_shape(input_shape.size(), 1);
+    scale_bias_shape[1] = -1;
+    auto reshaped_scale_name = output_name + "_reshaped_scale";
+    AddReshapeOp(scale_input_name, reshaped_scale_name, scale_bias_shape, graph_proto);
+    auto reshaped_bias_name = output_name + "_reshaped_bias";
+    AddReshapeOp(bias_input_name, reshaped_bias_name, scale_bias_shape, graph_proto);
+    auto epsilon = GetOpAttribute<float>(batch_norm_node, "epsilon");
+
+    AddMeanVarianceNormalizationOp(input_x_name, reshaped_scale_name, reshaped_bias_name, output_name, normalize_axes,
+                                   epsilon, input_shape, onnx_type, graph_proto);
+  } else {
+    PrimitivePtr prim_batch_norm = GetPrimitive(batch_norm_node);
+    std::vector<AnfNodePtr> inputs;
+    for (size_t i = 1; i < batch_norm_node->inputs().size(); i++) {
+      inputs.push_back(batch_norm_node->input(i));
+    }
+    (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_batch_norm, inputs, graph_proto);
+  }
 }

 void OnnxExporter::ExportMergeMaxPoolWithArgmax(const FuncGraphPtr &func_graph, const CNodePtr &node,
@@ -2644,109 +2831,28 @@ void OnnxExporter::ExportMergeLayerNorm(const FuncGraphPtr &, const CNodePtr &no
   auto layernorm_input_gamma = GetNodeInputName(LayerNormNode->input(kTwoNum), node_map_ptr, graph_proto);
   auto layernorm_input_beta = GetNodeInputName(LayerNormNode->input(kThreeNum), node_map_ptr, graph_proto);

-  auto layernorm_input_x_node = LayerNormNode->input(kOneNum);
-  auto dtype = layernorm_input_x_node->Type();
-  auto elem_type = dyn_cast<TensorType>(dtype)->element()->type_id();
-  size_t pre_cast_node_idx = 0;
-
-  // if type is float16, add cast node cast type from float16 to float32
-  if (elem_type == kNumberTypeFloat16) {
-    pre_cast_node_idx = AllocateNodeIndex();
-    AddCastOp(layernorm_input_x, std::to_string(pre_cast_node_idx), onnx::TensorProto_DataType_FLOAT, graph_proto);
+  auto begin_norm_axis = GetOpAttribute<int64_t>(LayerNormNode, "begin_norm_axis");
+  auto begin_params_axis = GetOpAttribute<int64_t>(LayerNormNode, "begin_params_axis");
+  if (begin_norm_axis != -1 || begin_params_axis != -1) {
+    MS_LOG(EXCEPTION) << "begin_norm_axis != -1 or begin_params_axis != -1 is not implemented";
   }

-  // reshape before MeanVarianceNormalization
-  auto input_shape = dyn_cast<abstract::Shape>(LayerNormNode->input(kOneNum)->Shape());
-  std::vector<int64_t> new_input_shape;
-  int64_t n_shape = 1;
-  int64_t c_shape = 1;
-  int64_t h_shape = 1;
-  size_t input_shape_size = input_shape->shape().size();
-  for (size_t i = 0; i < input_shape_size - 1; i++) {
-    c_shape = c_shape * input_shape->shape()[i];
-  }
-  new_input_shape.push_back(n_shape);
-  new_input_shape.push_back(c_shape);
-  new_input_shape.push_back(h_shape);
-  new_input_shape.push_back(input_shape->shape()[input_shape_size - kOneNum]);
+  auto onnx_type = GetOutputType(LayerNormNode->input(kOneNum));
+  auto input_shape = dyn_cast<abstract::Shape>(LayerNormNode->input(kOneNum)->Shape())->shape();
+  auto node_idx = AllocateNodeIndex();
+  (*node_map_ptr)[node] = node_idx;
+  auto epsilon = GetOpAttribute<float>(LayerNormNode, "epsilon");
+  std::vector<int64_t> reduce_axes = {static_cast<int64_t>(input_shape.size()) - 1};
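+  // With begin_norm_axis == -1, normalization statistics are computed over the last dimension only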

-  // Add shape node for reshape(before MeanVarianceNormalization)
-  auto new_shape_value = MakeValue<std::vector<int64_t>>(new_input_shape);
-  auto shape_node = NewValueNode(new_shape_value)->cast<AnfNodePtr>();
-  auto shape_node_idx = AllocateNodeIndex();
-
-  onnx::NodeProto *shape_node_proto = graph_proto->add_node();
-  shape_node_proto->add_output(std::to_string(shape_node_idx));
-  shape_node_proto->set_op_type("Constant");
-  onnx::AttributeProto *shape_attr_proto = shape_node_proto->add_attribute();
-  shape_attr_proto->set_name("value");
-  shape_attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
-  ConvertTupleToTensor(dyn_cast<ValueNode>(shape_node)->value(), shape_attr_proto->mutable_t());
-
-  // Add reshape node before MeanVarianceNormalization
-  auto pre_reshape_node_idx = AllocateNodeIndex();
-  onnx::NodeProto *pre_reshape_node_proto = graph_proto->add_node();
-  pre_reshape_node_proto->set_op_type("Reshape");
-  pre_reshape_node_proto->add_output(std::to_string(pre_reshape_node_idx));
-  if (elem_type == kNumberTypeFloat16) {
-    pre_reshape_node_proto->add_input(std::to_string(pre_cast_node_idx));
-  } else {
-    pre_reshape_node_proto->add_input(layernorm_input_x);
-  }
-  pre_reshape_node_proto->add_input(std::to_string(shape_node_idx));
-
-  // MeanVarianceNormalization
-  auto meanvariancenormal_node_idx = AllocateNodeIndex();
-  onnx::NodeProto *meanvariancenormal_node_proto = graph_proto->add_node();
-  meanvariancenormal_node_proto->set_op_type("MeanVarianceNormalization");
-  meanvariancenormal_node_proto->add_output(std::to_string(meanvariancenormal_node_idx));
-  meanvariancenormal_node_proto->add_input(std::to_string(pre_reshape_node_idx));
-
-  // if cast type from float16 to float32, add cast node cast type from float32 to float16
-  size_t aft_cast_node_idx = 0;
-  if (elem_type == kNumberTypeFloat16) {
-    aft_cast_node_idx = AllocateNodeIndex();
-    AddCastOp(std::to_string(meanvariancenormal_node_idx), std::to_string(aft_cast_node_idx),
-              onnx::TensorProto_DataType_FLOAT16, graph_proto);
-  }
-
-  // Add mul and add node
-  auto mul_node_idx = AllocateNodeIndex();
-  onnx::NodeProto *mul_node_proto = graph_proto->add_node();
-  mul_node_proto->set_op_type("Mul");
-  if (elem_type == kNumberTypeFloat16) {
-    mul_node_proto->add_input(std::to_string(aft_cast_node_idx));
-  } else {
-    mul_node_proto->add_input(std::to_string(meanvariancenormal_node_idx));
-  }
-  mul_node_proto->add_input(layernorm_input_gamma);
-  mul_node_proto->add_output(std::to_string(mul_node_idx));
-
-  // add beta
-  auto add_node_idx = AllocateNodeIndex();
-  AddOp("Add", {std::to_string(mul_node_idx), layernorm_input_beta}, {std::to_string(add_node_idx)}, graph_proto);
-
-  // reshape after MeanVarianceNormalization
-  // Add shape node for reshape(after MeanVarianceNormalization)
-  auto output_shape_value = MakeValue<std::vector<int64_t>>(input_shape->shape());
-  auto output_shape_node = NewValueNode(output_shape_value)->cast<AnfNodePtr>();
-  auto output_shape_node_idx = AllocateNodeIndex();
-
-  onnx::NodeProto *output_shape_node_proto = graph_proto->add_node();
-  output_shape_node_proto->add_output(std::to_string(output_shape_node_idx));
-  output_shape_node_proto->set_op_type("Constant");
-  onnx::AttributeProto *output_shape_attr_proto = output_shape_node_proto->add_attribute();
-  output_shape_attr_proto->set_name("value");
-  output_shape_attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
-  ConvertTupleToTensor(dyn_cast<ValueNode>(output_shape_node)->value(), output_shape_attr_proto->mutable_t());
-  // Add reshape node after MeanVarianceNormalization
-  auto aft_reshape_node_idx = AllocateNodeIndex();
-  (*node_map_ptr)[node] = aft_reshape_node_idx;
-  onnx::NodeProto *aft_reshape_node_proto = graph_proto->add_node();
-  aft_reshape_node_proto->set_op_type("Reshape");
-  aft_reshape_node_proto->add_output(std::to_string(aft_reshape_node_idx));
-  aft_reshape_node_proto->add_input(std::to_string(add_node_idx));
-  aft_reshape_node_proto->add_input(std::to_string(output_shape_node_idx));
+  AddMeanVarianceNormalizationOp(layernorm_input_x, layernorm_input_gamma, layernorm_input_beta,
+                                 std::to_string(node_idx), reduce_axes, epsilon, input_shape, onnx_type, graph_proto);
+}
+
+void OnnxExporter::ExportMergeConv2DTranspose(const FuncGraphPtr &func_graph, const CNodePtr &node,
+                                              std::map<AnfNodePtr, size_t> *node_map_ptr,
+                                              onnx::GraphProto *const graph_proto) {
+  auto conv_node = dyn_cast<CNode>(node->input(kOneNum));
+  PrimConv2DTransposeExportHelper(conv_node, node, node_map_ptr, graph_proto);
 }

 /*