!45681 可以导出onnx的cos和atan2

Merge pull request !45681 from wangtongyu6/export_onnx_cos_atan2
This commit is contained in:
i-robot 2022-11-18 02:44:05 +00:00 committed by Gitee
commit 951f3d48e5
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 62 additions and 0 deletions

View File

@ -545,6 +545,21 @@ void ClipPointsComponent(const std::string &points, const std::string &clipped,
AddClipOp(res_to_clip_name, clipped, 0.0f, max, type, graph_proto);
}
// check AnfNode data type is float or not.
bool IsFloatDataType(const AnfNodePtr &node) {
auto dtype = node->Type();
auto elem_type = dyn_cast<TensorType>(dtype)->element()->type_id();
switch (elem_type) {
case (kNumberTypeFloat):
case (kNumberTypeFloat16):
case (kNumberTypeFloat32):
case (kNumberTypeFloat64):
return True;
default:
return False;
}
}
namespace while_loop_export {
namespace {
const char CONTROL_PATTERN[] = "\u21B5";
@ -1025,6 +1040,8 @@ OPERATOR_ONNX_CONVERT_DEFINE(Less, Less, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(TensorScatterUpdate, ScatterND,
OpNameInfo().CastInput(1, onnx::TensorProto_DataType_INT32,
onnx::TensorProto_DataType_INT64))
OPERATOR_ONNX_CONVERT_DEFINE(Cos, Cos, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Atan2, Atan2, OpNameInfo())
#define OP_CONVERT_FUNCTION_NAME(name) GetOpOnnxConvertInfo_##name
@ -1074,7 +1091,10 @@ void RegisterOpConverters(const std::function<void(OpNameInfo &&)> &fn) {
fn(OP_CONVERT_FUNCTION_NAME(LogicalAnd)());
fn(OP_CONVERT_FUNCTION_NAME(ReverseSequence)());
fn(OP_CONVERT_FUNCTION_NAME(TensorScatterUpdate)());
fn(OP_CONVERT_FUNCTION_NAME(Sin)());
fn(OP_CONVERT_FUNCTION_NAME(Cos)());
fn(OP_CONVERT_FUNCTION_NAME(Atan2)());
}
class OpConvertRegistry {
@ -1219,6 +1239,8 @@ class OnnxExporter {
std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportPrimStack(const FuncGraphPtr &, const CNodePtr &node, std::map<AnfNodePtr, std::string> *node_map_ptr,
onnx::GraphProto *graph_proto);
void ExportPrimAtan2(const FuncGraphPtr &, const CNodePtr &node, std::map<AnfNodePtr, std::string> *node_map_ptr,
onnx::GraphProto *graph_proto);
void ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportMergeGemm(const FuncGraphPtr &func_graph, const CNodePtr &node,
@ -3325,6 +3347,45 @@ void OnnxExporter::ExportPrimStack(const FuncGraphPtr &, const CNodePtr &node,
new_axis_proto->set_i(true);
}
void OnnxExporter::ExportPrimAtan2(const FuncGraphPtr &, const CNodePtr &node,
std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto) {
auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
auto input_node1_anf = node->input(kOneNum);
auto input_node2_anf = node->input(kTwoNum);
auto input_node1 = GetNodeInputName(input_node1_anf, node_map_ptr, graph_proto);
auto input_node2 = GetNodeInputName(input_node2_anf, node_map_ptr, graph_proto);
auto atan_node = "Atan2_" + node_name + "_atan";
auto div_node = "Atan2_" + node_name + "_div";
auto less_node = "Atan2_" + node_name + "_less";
auto zero_value = "Atan2_" + node_name + "_zero";
auto neg_pi_value = "Atan2_" + node_name + "_pi";
auto minimal_value = "Atan2_" + node_name + "_minimal_val";
auto sign_node = "Atan2_" + node_name + "_sign";
auto mul_node = "Atan2_" + node_name + "_mul";
auto less_where_node1 = "Atan2_" + node_name + "_less_then_else1";
auto add_node = "Atan2_" + node_name + "_add1";
if (!(IsFloatDataType(input_node1_anf) && IsFloatDataType(input_node2_anf))) {
auto input_node1_cast = node_name + "_div_cast_fp32_1";
auto input_node2_cast = node_name + "_div_cast_fp32_2";
AddCastOp(input_node1, input_node1_cast, onnx::TensorProto_DataType_FLOAT, graph_proto);
AddCastOp(input_node2, input_node2_cast, onnx::TensorProto_DataType_FLOAT, graph_proto);
input_node1 = input_node1_cast;
input_node2 = input_node2_cast;
}
AddFloatScalarInitializer(minimal_value, 1e-10, onnx::TensorProto_DataType_FLOAT,
graph_proto); // minimal_value, avoid division by zero
AddOp("Add", {input_node2, minimal_value}, {add_node}, graph_proto);
AddOp("Div", {input_node1, add_node}, {div_node}, graph_proto);
AddOp("Atan", {div_node}, {atan_node}, graph_proto);
AddFloatScalarInitializer(zero_value, 0, onnx::TensorProto_DataType_FLOAT, graph_proto);
AddOp("Less", {input_node2, zero_value}, {less_node}, graph_proto);
AddFloatScalarInitializer(neg_pi_value, -acos(-1), onnx::TensorProto_DataType_FLOAT, graph_proto); // -PI
AddOp("Sign", {atan_node}, {sign_node}, graph_proto);
AddOp("Mul", {neg_pi_value, sign_node}, {mul_node}, graph_proto);
AddOp("Where", {less_node, mul_node, zero_value}, {less_where_node1}, graph_proto);
AddOp("Add", {less_where_node1, atan_node}, {node_name}, graph_proto);
}
void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *const graph_proto) {
using ExportFunc = std::function<void(OnnxExporter *, const FuncGraphPtr &, const CNodePtr &,
@ -3371,6 +3432,7 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n
{prim::kPrimReverseV2, &OnnxExporter::ExportPrimReverseV2},
{prim::kPrimTensorCopySlices, &OnnxExporter::ExportPrimTensorCopySlices},
{prim::kPrimStack, &OnnxExporter::ExportPrimStack},
{prim::kPrimAtan2, &OnnxExporter::ExportPrimAtan2},
};
auto iter = std::find_if(export_table.begin(), export_table.end(),