ReduceAll and ReduceAny

This commit is contained in:
zhaoyingzhuo 2022-11-18 09:45:05 +08:00
parent c0bc9d60a1
commit 71d9399c26
1 changed files with 73 additions and 0 deletions

View File

@ -1139,6 +1139,8 @@ class OnnxExporter {
std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportPrimReduce(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportPrimReduceAnyOrAll(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportPrimTranspose(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportPrimStridedSlice(const FuncGraphPtr &func_graph, const CNodePtr &node,
@ -1626,6 +1628,75 @@ void OnnxExporter::ExportPrimReduce(const FuncGraphPtr &, const CNodePtr &node,
AddReduceOp(name, input_data, node_name, axes, keep_dims, graph_proto);
}
void OnnxExporter::ExportPrimReduceAnyOrAll(const FuncGraphPtr &, const CNodePtr &node,
std::map<AnfNodePtr, std::string> *node_map_ptr,
onnx::GraphProto *const graph_proto) {
auto input_data_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
auto input_axis = node->input(kTwoNum);
auto keep_dims = GetOpAttribute<bool>(node, "keep_dims");
auto reduce_name = RegisterNodeWithUniqueName(node, node_map_ptr);
std::string target_node_name = "";
if (node->IsApply(prim::kPrimReduceAny)) {
target_node_name = "ReduceSum";
} else if (node->IsApply(prim::kPrimReduceAll)) {
target_node_name = "ReduceMin";
} else {
MS_LOG(EXCEPTION) << "Unsupported reduce op: " << node->ToString();
}
std::string cast_name = GenerateUniqueName(); // Insert cast op
onnx::NodeProto *cast_proto = graph_proto->add_node();
cast_proto->add_input(input_data_name);
cast_proto->add_output(cast_name);
cast_proto->set_op_type(prim::kPrimCast->name());
onnx::AttributeProto *attr_proto = cast_proto->add_attribute();
attr_proto->set_name("to");
attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
attr_proto->set_i(GetOnnxDataType(TypeId::kNumberTypeFloat32));
std::vector<int64_t> axes;
if (input_axis->isa<ValueNode>()) {
auto axis_value = dyn_cast<ValueNode>(input_axis)->value();
if (axis_value->isa<Int32Imm>()) {
auto int_ptr = dyn_cast<Int32Imm>(axis_value);
axes.push_back(int_ptr->value());
} else if (axis_value->isa<Int64Imm>()) {
auto int_ptr = dyn_cast<Int64Imm>(axis_value);
axes.push_back(int_ptr->value());
} else if (axis_value->isa<ValueTuple>()) {
auto tuple_ptr = dyn_cast<ValueTuple>(axis_value);
axes = GetValue<std::vector<int64_t>>(tuple_ptr);
if (axes.empty()) {
const auto &x_shape = dyn_cast<abstract::Shape>(node->input(kOneNum)->Shape())->shape();
for (size_t i = 0; i < x_shape.size(); ++i) {
axes.push_back(static_cast<int64_t>(i));
}
}
} else {
MS_LOG(EXCEPTION) << "Cannot convert value " << axis_value->ToString() << " of type "
<< axis_value->type()->ToString() << " for \"axes\" attribute of " << target_node_name;
}
} else {
MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to attributes for " << target_node_name;
}
std::string greater_name = GenerateUniqueName();
onnx::TensorProto *zero_initializer_proto = graph_proto->add_initializer();
auto zero_input_name = greater_name + "_zero";
zero_initializer_proto->set_name(zero_input_name);
zero_initializer_proto->set_data_type(GetOnnxDataType(kNumberTypeFloat32));
zero_initializer_proto->add_float_data(0);
AddReduceOp(target_node_name, cast_name, greater_name, axes, keep_dims, graph_proto);
onnx::NodeProto *greater_node_proto = graph_proto->add_node(); // Insert greater op
greater_node_proto->add_input(greater_name);
greater_node_proto->add_input(zero_input_name);
greater_node_proto->add_output(reduce_name);
greater_node_proto->set_op_type(prim::kPrimGreater->name());
}
void OnnxExporter::ExportPrimTranspose(const FuncGraphPtr &, const CNodePtr &node,
std::map<AnfNodePtr, std::string> *node_map_ptr,
onnx::GraphProto *const graph_proto) {
@ -3334,6 +3405,8 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n
{prim::kPrimReduceMean, &OnnxExporter::ExportPrimReduce},
{prim::kPrimReduceSum, &OnnxExporter::ExportPrimReduce},
{prim::kPrimReduceMax, &OnnxExporter::ExportPrimReduce},
{prim::kPrimReduceAny, &OnnxExporter::ExportPrimReduceAnyOrAll},
{prim::kPrimReduceAll, &OnnxExporter::ExportPrimReduceAnyOrAll},
{prim::kPrimTranspose, &OnnxExporter::ExportPrimTranspose},
{prim::kPrimStridedSlice, &OnnxExporter::ExportPrimStridedSlice},
{prim::kPrimResizeNearestNeighbor, &OnnxExporter::ExportPrimResizeNearestNeighbor},