!11190 Modify register of TBE ops.

From: @liu_xiao_93
Reviewed-by: @liangchenghui,@chujinjin
Signed-off-by: @liangchenghui
This commit is contained in:
mindspore-ci-bot 2021-01-13 10:16:44 +08:00 committed by Gitee
commit 593a80fae5
12 changed files with 5 additions and 43 deletions

View File

@ -30,8 +30,8 @@ lrn_op_info = TBERegOp("LRN") \
.attr("norm_region", "optional", "str", "all", "ACROSS_CHANNELS") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_NCHW, DataType.F16_NCHW) \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.get_op_info()

View File

@ -31,8 +31,8 @@ lrn_grad_op_info = TBERegOp("LRNGrad") \
.input(1, "x", False, "required", "all") \
.input(2, "y", False, "required", "all") \
.output(0, "z", False, "required", "all") \
.dtype_format(DataType.F16_NCHW, DataType.F16_NCHW, DataType.F16_NCHW, DataType.F16_NCHW) \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.get_op_info()

View File

@ -50,27 +50,17 @@ parallel_concat_op_info = TBERegOp("ParallelConcat") \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.BOOL_NHWC, DataType.BOOL_NHWC) \
.dtype_format(DataType.BOOL_NCHW, DataType.BOOL_NCHW) \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I8_NHWC, DataType.I8_NHWC) \
.dtype_format(DataType.I8_NCHW, DataType.I8_NCHW) \
.dtype_format(DataType.U8_NHWC, DataType.U8_NHWC) \
.dtype_format(DataType.U8_NCHW, DataType.U8_NCHW) \
.dtype_format(DataType.I16_NHWC, DataType.I16_NHWC) \
.dtype_format(DataType.I16_NCHW, DataType.I16_NCHW) \
.dtype_format(DataType.U16_NHWC, DataType.U16_NHWC) \
.dtype_format(DataType.U16_NCHW, DataType.U16_NCHW) \
.dtype_format(DataType.I32_NHWC, DataType.I32_NHWC) \
.dtype_format(DataType.I32_NCHW, DataType.I32_NCHW) \
.dtype_format(DataType.U32_NHWC, DataType.U32_NHWC) \
.dtype_format(DataType.U32_NCHW, DataType.U32_NCHW) \
.dtype_format(DataType.I64_NHWC, DataType.I64_NHWC) \
.dtype_format(DataType.I64_NCHW, DataType.I64_NCHW) \
.dtype_format(DataType.U64_NHWC, DataType.U64_NHWC) \
.dtype_format(DataType.U64_NCHW, DataType.U64_NCHW) \
.dtype_format(DataType.F16_NHWC, DataType.F16_NHWC) \
.dtype_format(DataType.F16_NCHW, DataType.F16_NCHW) \
.dtype_format(DataType.F32_NHWC, DataType.F32_NHWC) \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW) \
.get_op_info()

View File

@ -26,10 +26,8 @@ prelu_op_info = TBERegOp("PReLU") \
.input(0, "x", False, "required", "all") \
.input(1, "weight", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_NCHW, DataType.F16_Default, DataType.F16_NCHW) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_NCHW, DataType.F32_Default, DataType.F32_NCHW) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.get_op_info()

View File

@ -28,8 +28,6 @@ prelu_grad_op_info = TBERegOp("PReLUGrad") \
.input(2, "weights", False, "required", "all") \
.output(0, "dx", False, "required", "all") \
.output(0, "da", False, "required", "all") \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_Default,
DataType.F32_NCHW, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,

View File

@ -32,8 +32,6 @@ sparse_apply_adagrad_d_op_info = TBERegOp("SparseApplyAdagrad") \
.input(3, "indices", False, "required", "all") \
.output(0, "var", False, "required", "all") \
.output(1, "accum", False, "required", "all") \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.I32_NCHW,
DataType.F32_NCHW, DataType.F32_NCHW) \
.dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.I32_NHWC,
DataType.F32_NHWC, DataType.F32_NHWC) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.I32_Default,

View File

@ -33,8 +33,6 @@ sparse_apply_adagrad_v2_d_op_info = TBERegOp("SparseApplyAdagradV2") \
.input(3, "indices", False, "required", "all") \
.output(0, "var", False, "required", "all") \
.output(1, "accum", False, "required", "all") \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.I32_NCHW,
DataType.F32_NCHW, DataType.F32_NCHW) \
.dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.I32_NHWC,
DataType.F32_NHWC, DataType.F32_NHWC) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.I32_Default,

View File

@ -36,14 +36,10 @@ sparse_apply_ftrl_d_op_info = TBERegOp("SparseApplyFtrl") \
.output(0, "var", False, "required", "all") \
.output(1, "accum", False, "required", "all") \
.output(2, "linear", False, "required", "all") \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW,
DataType.I32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \
.dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC,
DataType.I32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.I32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW,
DataType.I64_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \
.dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC,
DataType.I64_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,

View File

@ -37,8 +37,6 @@ sparse_apply_ftrl_d_op_info = TBERegOp("SparseApplyFtrl") \
.output(0, "var", False, "required", "all") \
.output(1, "accum", False, "required", "all") \
.output(2, "linear", False, "required", "all") \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW,
DataType.I32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \
.dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC,
DataType.I32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,

View File

@ -37,8 +37,6 @@ sparse_apply_ftrl_v2_d_op_info = TBERegOp("SparseApplyFtrlV2") \
.output(0, "var", False, "required", "all") \
.output(1, "accum", False, "required", "all") \
.output(2, "linear", False, "required", "all") \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW,
DataType.I32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \
.dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC,
DataType.I32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,

View File

@ -33,9 +33,6 @@ sparse_apply_proximal_adagrad_d_op_info = TBERegOp("SparseApplyProximalAdagrad")
.input(6, "indices", False, "required", "all") \
.output(0, "var", False, "required", "all") \
.output(1, "accum", False, "required", "all") \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW,
DataType.F32_NCHW, DataType.F32_NCHW, DataType.I32_NCHW, DataType.F32_NCHW,
DataType.F32_NCHW) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.I32_5HD, DataType.F32_5HD,
DataType.F32_5HD) \
@ -48,9 +45,6 @@ sparse_apply_proximal_adagrad_d_op_info = TBERegOp("SparseApplyProximalAdagrad")
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
DataType.F32_FracZ, DataType.F32_FracZ, DataType.I32_FracZ, DataType.F32_FracZ,
DataType.F32_FracZ) \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW,
DataType.F32_NCHW, DataType.F32_NCHW, DataType.I64_NCHW, DataType.F32_NCHW,
DataType.F32_NCHW) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.I64_5HD, DataType.F32_5HD,
DataType.F32_5HD) \

View File

@ -34,9 +34,6 @@ sparse_apply_proximal_adagrad_d_ds_op_info = TBERegOp("SparseApplyProximalAdagra
.input(6, "indices", False, "required", "all") \
.output(0, "var", False, "required", "all") \
.output(1, "accum", False, "required", "all") \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW,
DataType.F32_NCHW, DataType.F32_NCHW, DataType.I32_NCHW, DataType.F32_NCHW,
DataType.F32_NCHW) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.I32_5HD, DataType.F32_5HD,
DataType.F32_5HD) \
@ -49,9 +46,6 @@ sparse_apply_proximal_adagrad_d_ds_op_info = TBERegOp("SparseApplyProximalAdagra
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
DataType.F32_FracZ, DataType.F32_FracZ, DataType.I32_FracZ, DataType.F32_FracZ,
DataType.F32_FracZ) \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW,
DataType.F32_NCHW, DataType.F32_NCHW, DataType.I64_NCHW, DataType.F32_NCHW,
DataType.F32_NCHW) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.I64_5HD, DataType.F32_5HD,
DataType.F32_5HD) \