ReduceAll and ReduceAny
This commit is contained in:
parent
c0bc9d60a1
commit
71d9399c26
|
@ -1139,6 +1139,8 @@ class OnnxExporter {
|
|||
std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
|
||||
void ExportPrimReduce(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
|
||||
void ExportPrimReduceAnyOrAll(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
|
||||
void ExportPrimTranspose(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
|
||||
void ExportPrimStridedSlice(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||
|
@ -1626,6 +1628,75 @@ void OnnxExporter::ExportPrimReduce(const FuncGraphPtr &, const CNodePtr &node,
|
|||
AddReduceOp(name, input_data, node_name, axes, keep_dims, graph_proto);
|
||||
}
|
||||
|
||||
void OnnxExporter::ExportPrimReduceAnyOrAll(const FuncGraphPtr &, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, std::string> *node_map_ptr,
|
||||
onnx::GraphProto *const graph_proto) {
|
||||
auto input_data_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
|
||||
auto input_axis = node->input(kTwoNum);
|
||||
auto keep_dims = GetOpAttribute<bool>(node, "keep_dims");
|
||||
auto reduce_name = RegisterNodeWithUniqueName(node, node_map_ptr);
|
||||
|
||||
std::string target_node_name = "";
|
||||
if (node->IsApply(prim::kPrimReduceAny)) {
|
||||
target_node_name = "ReduceSum";
|
||||
} else if (node->IsApply(prim::kPrimReduceAll)) {
|
||||
target_node_name = "ReduceMin";
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Unsupported reduce op: " << node->ToString();
|
||||
}
|
||||
|
||||
std::string cast_name = GenerateUniqueName(); // Insert cast op
|
||||
onnx::NodeProto *cast_proto = graph_proto->add_node();
|
||||
cast_proto->add_input(input_data_name);
|
||||
cast_proto->add_output(cast_name);
|
||||
cast_proto->set_op_type(prim::kPrimCast->name());
|
||||
onnx::AttributeProto *attr_proto = cast_proto->add_attribute();
|
||||
attr_proto->set_name("to");
|
||||
attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
|
||||
attr_proto->set_i(GetOnnxDataType(TypeId::kNumberTypeFloat32));
|
||||
|
||||
std::vector<int64_t> axes;
|
||||
if (input_axis->isa<ValueNode>()) {
|
||||
auto axis_value = dyn_cast<ValueNode>(input_axis)->value();
|
||||
if (axis_value->isa<Int32Imm>()) {
|
||||
auto int_ptr = dyn_cast<Int32Imm>(axis_value);
|
||||
axes.push_back(int_ptr->value());
|
||||
} else if (axis_value->isa<Int64Imm>()) {
|
||||
auto int_ptr = dyn_cast<Int64Imm>(axis_value);
|
||||
axes.push_back(int_ptr->value());
|
||||
} else if (axis_value->isa<ValueTuple>()) {
|
||||
auto tuple_ptr = dyn_cast<ValueTuple>(axis_value);
|
||||
axes = GetValue<std::vector<int64_t>>(tuple_ptr);
|
||||
if (axes.empty()) {
|
||||
const auto &x_shape = dyn_cast<abstract::Shape>(node->input(kOneNum)->Shape())->shape();
|
||||
for (size_t i = 0; i < x_shape.size(); ++i) {
|
||||
axes.push_back(static_cast<int64_t>(i));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Cannot convert value " << axis_value->ToString() << " of type "
|
||||
<< axis_value->type()->ToString() << " for \"axes\" attribute of " << target_node_name;
|
||||
}
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to attributes for " << target_node_name;
|
||||
}
|
||||
|
||||
std::string greater_name = GenerateUniqueName();
|
||||
onnx::TensorProto *zero_initializer_proto = graph_proto->add_initializer();
|
||||
auto zero_input_name = greater_name + "_zero";
|
||||
zero_initializer_proto->set_name(zero_input_name);
|
||||
zero_initializer_proto->set_data_type(GetOnnxDataType(kNumberTypeFloat32));
|
||||
zero_initializer_proto->add_float_data(0);
|
||||
|
||||
AddReduceOp(target_node_name, cast_name, greater_name, axes, keep_dims, graph_proto);
|
||||
|
||||
onnx::NodeProto *greater_node_proto = graph_proto->add_node(); // Insert greater op
|
||||
greater_node_proto->add_input(greater_name);
|
||||
greater_node_proto->add_input(zero_input_name);
|
||||
greater_node_proto->add_output(reduce_name);
|
||||
greater_node_proto->set_op_type(prim::kPrimGreater->name());
|
||||
}
|
||||
|
||||
void OnnxExporter::ExportPrimTranspose(const FuncGraphPtr &, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, std::string> *node_map_ptr,
|
||||
onnx::GraphProto *const graph_proto) {
|
||||
|
@ -3334,6 +3405,8 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n
|
|||
{prim::kPrimReduceMean, &OnnxExporter::ExportPrimReduce},
|
||||
{prim::kPrimReduceSum, &OnnxExporter::ExportPrimReduce},
|
||||
{prim::kPrimReduceMax, &OnnxExporter::ExportPrimReduce},
|
||||
{prim::kPrimReduceAny, &OnnxExporter::ExportPrimReduceAnyOrAll},
|
||||
{prim::kPrimReduceAll, &OnnxExporter::ExportPrimReduceAnyOrAll},
|
||||
{prim::kPrimTranspose, &OnnxExporter::ExportPrimTranspose},
|
||||
{prim::kPrimStridedSlice, &OnnxExporter::ExportPrimStridedSlice},
|
||||
{prim::kPrimResizeNearestNeighbor, &OnnxExporter::ExportPrimResizeNearestNeighbor},
|
||||
|
|
Loading…
Reference in New Issue