forked from mindspore-Ecosystem/mindspore
!45681 可以导出onnx的cos和atan2
Merge pull request !45681 from wangtongyu6/export_onnx_cos_atan2
This commit is contained in:
commit
951f3d48e5
|
@ -545,6 +545,21 @@ void ClipPointsComponent(const std::string &points, const std::string &clipped,
|
|||
AddClipOp(res_to_clip_name, clipped, 0.0f, max, type, graph_proto);
|
||||
}
|
||||
|
||||
// check AnfNode data type is float or not.
|
||||
bool IsFloatDataType(const AnfNodePtr &node) {
|
||||
auto dtype = node->Type();
|
||||
auto elem_type = dyn_cast<TensorType>(dtype)->element()->type_id();
|
||||
switch (elem_type) {
|
||||
case (kNumberTypeFloat):
|
||||
case (kNumberTypeFloat16):
|
||||
case (kNumberTypeFloat32):
|
||||
case (kNumberTypeFloat64):
|
||||
return True;
|
||||
default:
|
||||
return False;
|
||||
}
|
||||
}
|
||||
|
||||
namespace while_loop_export {
|
||||
namespace {
|
||||
const char CONTROL_PATTERN[] = "\u21B5";
|
||||
|
@ -1025,6 +1040,8 @@ OPERATOR_ONNX_CONVERT_DEFINE(Less, Less, OpNameInfo())
|
|||
OPERATOR_ONNX_CONVERT_DEFINE(TensorScatterUpdate, ScatterND,
|
||||
OpNameInfo().CastInput(1, onnx::TensorProto_DataType_INT32,
|
||||
onnx::TensorProto_DataType_INT64))
|
||||
OPERATOR_ONNX_CONVERT_DEFINE(Cos, Cos, OpNameInfo())
|
||||
OPERATOR_ONNX_CONVERT_DEFINE(Atan2, Atan2, OpNameInfo())
|
||||
|
||||
#define OP_CONVERT_FUNCTION_NAME(name) GetOpOnnxConvertInfo_##name
|
||||
|
||||
|
@ -1074,7 +1091,10 @@ void RegisterOpConverters(const std::function<void(OpNameInfo &&)> &fn) {
|
|||
fn(OP_CONVERT_FUNCTION_NAME(LogicalAnd)());
|
||||
fn(OP_CONVERT_FUNCTION_NAME(ReverseSequence)());
|
||||
fn(OP_CONVERT_FUNCTION_NAME(TensorScatterUpdate)());
|
||||
|
||||
fn(OP_CONVERT_FUNCTION_NAME(Sin)());
|
||||
fn(OP_CONVERT_FUNCTION_NAME(Cos)());
|
||||
fn(OP_CONVERT_FUNCTION_NAME(Atan2)());
|
||||
}
|
||||
|
||||
class OpConvertRegistry {
|
||||
|
@ -1219,6 +1239,8 @@ class OnnxExporter {
|
|||
std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
|
||||
void ExportPrimStack(const FuncGraphPtr &, const CNodePtr &node, std::map<AnfNodePtr, std::string> *node_map_ptr,
|
||||
onnx::GraphProto *graph_proto);
|
||||
void ExportPrimAtan2(const FuncGraphPtr &, const CNodePtr &node, std::map<AnfNodePtr, std::string> *node_map_ptr,
|
||||
onnx::GraphProto *graph_proto);
|
||||
void ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
|
||||
void ExportMergeGemm(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||
|
@ -3325,6 +3347,45 @@ void OnnxExporter::ExportPrimStack(const FuncGraphPtr &, const CNodePtr &node,
|
|||
new_axis_proto->set_i(true);
|
||||
}
|
||||
|
||||
void OnnxExporter::ExportPrimAtan2(const FuncGraphPtr &, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto) {
|
||||
auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
|
||||
auto input_node1_anf = node->input(kOneNum);
|
||||
auto input_node2_anf = node->input(kTwoNum);
|
||||
auto input_node1 = GetNodeInputName(input_node1_anf, node_map_ptr, graph_proto);
|
||||
auto input_node2 = GetNodeInputName(input_node2_anf, node_map_ptr, graph_proto);
|
||||
auto atan_node = "Atan2_" + node_name + "_atan";
|
||||
auto div_node = "Atan2_" + node_name + "_div";
|
||||
auto less_node = "Atan2_" + node_name + "_less";
|
||||
auto zero_value = "Atan2_" + node_name + "_zero";
|
||||
auto neg_pi_value = "Atan2_" + node_name + "_pi";
|
||||
auto minimal_value = "Atan2_" + node_name + "_minimal_val";
|
||||
auto sign_node = "Atan2_" + node_name + "_sign";
|
||||
auto mul_node = "Atan2_" + node_name + "_mul";
|
||||
auto less_where_node1 = "Atan2_" + node_name + "_less_then_else1";
|
||||
auto add_node = "Atan2_" + node_name + "_add1";
|
||||
if (!(IsFloatDataType(input_node1_anf) && IsFloatDataType(input_node2_anf))) {
|
||||
auto input_node1_cast = node_name + "_div_cast_fp32_1";
|
||||
auto input_node2_cast = node_name + "_div_cast_fp32_2";
|
||||
AddCastOp(input_node1, input_node1_cast, onnx::TensorProto_DataType_FLOAT, graph_proto);
|
||||
AddCastOp(input_node2, input_node2_cast, onnx::TensorProto_DataType_FLOAT, graph_proto);
|
||||
input_node1 = input_node1_cast;
|
||||
input_node2 = input_node2_cast;
|
||||
}
|
||||
AddFloatScalarInitializer(minimal_value, 1e-10, onnx::TensorProto_DataType_FLOAT,
|
||||
graph_proto); // minimal_value, avoid division by zero
|
||||
AddOp("Add", {input_node2, minimal_value}, {add_node}, graph_proto);
|
||||
AddOp("Div", {input_node1, add_node}, {div_node}, graph_proto);
|
||||
AddOp("Atan", {div_node}, {atan_node}, graph_proto);
|
||||
AddFloatScalarInitializer(zero_value, 0, onnx::TensorProto_DataType_FLOAT, graph_proto);
|
||||
AddOp("Less", {input_node2, zero_value}, {less_node}, graph_proto);
|
||||
AddFloatScalarInitializer(neg_pi_value, -acos(-1), onnx::TensorProto_DataType_FLOAT, graph_proto); // -PI
|
||||
AddOp("Sign", {atan_node}, {sign_node}, graph_proto);
|
||||
AddOp("Mul", {neg_pi_value, sign_node}, {mul_node}, graph_proto);
|
||||
AddOp("Where", {less_node, mul_node, zero_value}, {less_where_node1}, graph_proto);
|
||||
AddOp("Add", {less_where_node1, atan_node}, {node_name}, graph_proto);
|
||||
}
|
||||
|
||||
void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *const graph_proto) {
|
||||
using ExportFunc = std::function<void(OnnxExporter *, const FuncGraphPtr &, const CNodePtr &,
|
||||
|
@ -3371,6 +3432,7 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n
|
|||
{prim::kPrimReverseV2, &OnnxExporter::ExportPrimReverseV2},
|
||||
{prim::kPrimTensorCopySlices, &OnnxExporter::ExportPrimTensorCopySlices},
|
||||
{prim::kPrimStack, &OnnxExporter::ExportPrimStack},
|
||||
{prim::kPrimAtan2, &OnnxExporter::ExportPrimAtan2},
|
||||
};
|
||||
|
||||
auto iter = std::find_if(export_table.begin(), export_table.end(),
|
||||
|
|
Loading…
Reference in New Issue