export onnx cos,atan2
Signed-off-by: TonyWang222 <wangtongyu6@huawei.com>
commit 6637c75196
parent 8454e41253
@@ -545,6 +545,21 @@ void ClipPointsComponent(const std::string &points, const std::string &clipped,
   AddClipOp(res_to_clip_name, clipped, 0.0f, max, type, graph_proto);
 }
 
+// Check whether the data type of the AnfNode is float.
+bool IsFloatDataType(const AnfNodePtr &node) {
+  auto dtype = node->Type();
+  auto elem_type = dyn_cast<TensorType>(dtype)->element()->type_id();
+  switch (elem_type) {
+    case (kNumberTypeFloat):
+    case (kNumberTypeFloat16):
+    case (kNumberTypeFloat32):
+    case (kNumberTypeFloat64):
+      return true;
+    default:
+      return false;
+  }
+}
+
 namespace while_loop_export {
 namespace {
 const char CONTROL_PATTERN[] = "\u21B5";
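The new IsFloatDataType helper exists so that ExportPrimAtan2, added further down in this commit, can skip inserting Cast nodes when both inputs are already floating point. Below is a minimal, self-contained sketch of that gating pattern; the enum and names are placeholders for illustration and are not part of MindSpore's API.

// Hypothetical, standalone illustration of the cast-gating pattern that
// IsFloatDataType enables (placeholder enum instead of MindSpore's TypeId).
#include <cstdio>

enum class ElemType { Int32, Float16, Float32, Float64 };

bool IsFloatElemType(ElemType t) {
  switch (t) {
    case ElemType::Float16:
    case ElemType::Float32:
    case ElemType::Float64:
      return true;
    default:
      return false;
  }
}

int main() {
  ElemType lhs = ElemType::Int32, rhs = ElemType::Float32;
  // Mirrors the check in ExportPrimAtan2: only cast when an input is not float.
  if (!(IsFloatElemType(lhs) && IsFloatElemType(rhs))) {
    std::printf("insert Cast-to-FLOAT nodes before the Atan2 decomposition\n");
  } else {
    std::printf("inputs already float; no Cast needed\n");
  }
  return 0;
}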
@@ -1025,6 +1040,8 @@ OPERATOR_ONNX_CONVERT_DEFINE(Less, Less, OpNameInfo())
 OPERATOR_ONNX_CONVERT_DEFINE(TensorScatterUpdate, ScatterND,
                              OpNameInfo().CastInput(1, onnx::TensorProto_DataType_INT32,
                                                     onnx::TensorProto_DataType_INT64))
+OPERATOR_ONNX_CONVERT_DEFINE(Cos, Cos, OpNameInfo())
+OPERATOR_ONNX_CONVERT_DEFINE(Atan2, Atan2, OpNameInfo())
 
 #define OP_CONVERT_FUNCTION_NAME(name) GetOpOnnxConvertInfo_##name
 
@@ -1074,7 +1091,10 @@ void RegisterOpConverters(const std::function<void(OpNameInfo &&)> &fn) {
   fn(OP_CONVERT_FUNCTION_NAME(LogicalAnd)());
   fn(OP_CONVERT_FUNCTION_NAME(ReverseSequence)());
   fn(OP_CONVERT_FUNCTION_NAME(TensorScatterUpdate)());
+
   fn(OP_CONVERT_FUNCTION_NAME(Sin)());
+  fn(OP_CONVERT_FUNCTION_NAME(Cos)());
+  fn(OP_CONVERT_FUNCTION_NAME(Atan2)());
 }
 
 class OpConvertRegistry {
@@ -1217,6 +1237,8 @@ class OnnxExporter {
                           std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
   void ExportPrimStack(const FuncGraphPtr &, const CNodePtr &node, std::map<AnfNodePtr, std::string> *node_map_ptr,
                        onnx::GraphProto *graph_proto);
+  void ExportPrimAtan2(const FuncGraphPtr &, const CNodePtr &node, std::map<AnfNodePtr, std::string> *node_map_ptr,
+                       onnx::GraphProto *graph_proto);
   void ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node,
                        std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
   void ExportMergeGemm(const FuncGraphPtr &func_graph, const CNodePtr &node,
@@ -3214,6 +3236,45 @@ void OnnxExporter::ExportPrimStack(const FuncGraphPtr &, const CNodePtr &node,
   new_axis_proto->set_i(true);
 }
 
+void OnnxExporter::ExportPrimAtan2(const FuncGraphPtr &, const CNodePtr &node,
+                                   std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto) {
+  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
+  auto input_node1_anf = node->input(kOneNum);
+  auto input_node2_anf = node->input(kTwoNum);
+  auto input_node1 = GetNodeInputName(input_node1_anf, node_map_ptr, graph_proto);
+  auto input_node2 = GetNodeInputName(input_node2_anf, node_map_ptr, graph_proto);
+  auto atan_node = "Atan2_" + node_name + "_atan";
+  auto div_node = "Atan2_" + node_name + "_div";
+  auto less_node = "Atan2_" + node_name + "_less";
+  auto zero_value = "Atan2_" + node_name + "_zero";
+  auto neg_pi_value = "Atan2_" + node_name + "_pi";
+  auto minimal_value = "Atan2_" + node_name + "_minimal_val";
+  auto sign_node = "Atan2_" + node_name + "_sign";
+  auto mul_node = "Atan2_" + node_name + "_mul";
+  auto less_where_node1 = "Atan2_" + node_name + "_less_then_else1";
+  auto add_node = "Atan2_" + node_name + "_add1";
+  if (!(IsFloatDataType(input_node1_anf) && IsFloatDataType(input_node2_anf))) {
+    auto input_node1_cast = node_name + "_div_cast_fp32_1";
+    auto input_node2_cast = node_name + "_div_cast_fp32_2";
+    AddCastOp(input_node1, input_node1_cast, onnx::TensorProto_DataType_FLOAT, graph_proto);
+    AddCastOp(input_node2, input_node2_cast, onnx::TensorProto_DataType_FLOAT, graph_proto);
+    input_node1 = input_node1_cast;
+    input_node2 = input_node2_cast;
+  }
+  AddFloatScalarInitializer(minimal_value, 1e-10, onnx::TensorProto_DataType_FLOAT,
+                            graph_proto);  // minimal_value, avoid division by zero
+  AddOp("Add", {input_node2, minimal_value}, {add_node}, graph_proto);
+  AddOp("Div", {input_node1, add_node}, {div_node}, graph_proto);
+  AddOp("Atan", {div_node}, {atan_node}, graph_proto);
+  AddFloatScalarInitializer(zero_value, 0, onnx::TensorProto_DataType_FLOAT, graph_proto);
+  AddOp("Less", {input_node2, zero_value}, {less_node}, graph_proto);
+  AddFloatScalarInitializer(neg_pi_value, -acos(-1), onnx::TensorProto_DataType_FLOAT, graph_proto);  // -PI
+  AddOp("Sign", {atan_node}, {sign_node}, graph_proto);
+  AddOp("Mul", {neg_pi_value, sign_node}, {mul_node}, graph_proto);
+  AddOp("Where", {less_node, mul_node, zero_value}, {less_where_node1}, graph_proto);
+  AddOp("Add", {less_where_node1, atan_node}, {node_name}, graph_proto);
+}
+
 void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node,
                                std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *const graph_proto) {
   using ExportFunc = std::function<void(OnnxExporter *, const FuncGraphPtr &, const CNodePtr &,
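The function above lowers Atan2(y, x) to primitive ONNX ops as atan(y / (x + eps)), plus a correction of -pi * sign(atan(y / (x + eps))) whenever x < 0, which is what the Less/Sign/Mul/Where/Add chain encodes. The scalar C++ sketch below mirrors that op sequence so it can be compared against std::atan2; it is illustration only, not exporter code, and all names in it are made up. The 1e-10 epsilon matches the minimal_value initializer in the diff.

// Standalone sanity check of the Atan2 decomposition emitted above.
// Hypothetical names; not part of the MindSpore exporter.
#include <cmath>
#include <cstdio>

double Atan2ViaOnnxOps(double y, double x) {
  const double eps = 1e-10;                        // Add(x, minimal_value): avoid division by zero
  const double pi = std::acos(-1.0);               // the diff materializes -PI via -acos(-1)
  double atan_val = std::atan(y / (x + eps));      // Div + Atan
  double sign = (atan_val > 0) - (atan_val < 0);   // Sign
  double correction = (x < 0) ? -pi * sign : 0.0;  // Less + Mul + Where
  return atan_val + correction;                    // final Add
}

int main() {
  const double cases[][2] = {{1, 1}, {1, -1}, {-1, -1}, {-1, 1}, {0.5, -2}};
  for (const auto &c : cases) {
    std::printf("y=%g x=%g decomposed=%.6f std::atan2=%.6f\n", c[0], c[1],
                Atan2ViaOnnxOps(c[0], c[1]), std::atan2(c[0], c[1]));
  }
  return 0;
}

Running this prints matching values for the sampled quadrants, which is a quick way to confirm that the Where-based branch applies the pi offset only when x < 0 and with the right sign.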
@@ -3259,6 +3320,7 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node,
     {prim::kPrimReverseV2, &OnnxExporter::ExportPrimReverseV2},
     {prim::kPrimTensorCopySlices, &OnnxExporter::ExportPrimTensorCopySlices},
     {prim::kPrimStack, &OnnxExporter::ExportPrimStack},
+    {prim::kPrimAtan2, &OnnxExporter::ExportPrimAtan2},
   };
 
   auto iter = std::find_if(export_table.begin(), export_table.end(),
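The final hunk registers the new handler in ExportCNode's export_table, which is searched with std::find_if and then invoked through the stored member pointer. The compressed, self-contained model below sketches that lookup; the string keys and simplified ExportFunc signature stand in for the real PrimitivePtr keys and argument list and are assumptions for illustration only.

// Simplified, hypothetical model of the export_table dispatch that the new
// kPrimAtan2 entry plugs into; types and names here are illustrative only.
#include <algorithm>
#include <cstdio>
#include <string>
#include <utility>
#include <vector>

class Exporter {
 public:
  void ExportStack() { std::printf("export Stack\n"); }
  void ExportAtan2() { std::printf("export Atan2 via the decomposition above\n"); }
};

using ExportFunc = void (Exporter::*)();

int main() {
  // Analogue of the export_table of {primitive, &OnnxExporter::ExportPrimX} pairs.
  std::vector<std::pair<std::string, ExportFunc>> export_table = {
      {"Stack", &Exporter::ExportStack},
      {"Atan2", &Exporter::ExportAtan2},  // the entry this commit adds
  };

  const std::string prim = "Atan2";
  auto iter = std::find_if(export_table.begin(), export_table.end(),
                           [&prim](const auto &entry) { return entry.first == prim; });
  if (iter != export_table.end()) {
    Exporter exporter;
    (exporter.*(iter->second))();  // dispatch through the member-function pointer
  }
  return 0;
}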