forked from mindspore-Ecosystem/mindspore
!11190 Modify register of TBE ops.
From: @liu_xiao_93 Reviewed-by: @liangchenghui,@chujinjin Signed-off-by: @liangchenghui
This commit is contained in:
commit
593a80fae5
|
@ -30,8 +30,8 @@ lrn_op_info = TBERegOp("LRN") \
|
|||
.attr("norm_region", "optional", "str", "all", "ACROSS_CHANNELS") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_NCHW, DataType.F16_NCHW) \
|
||||
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
|
|
|
@ -31,8 +31,8 @@ lrn_grad_op_info = TBERegOp("LRNGrad") \
|
|||
.input(1, "x", False, "required", "all") \
|
||||
.input(2, "y", False, "required", "all") \
|
||||
.output(0, "z", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_NCHW, DataType.F16_NCHW, DataType.F16_NCHW, DataType.F16_NCHW) \
|
||||
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
|
|
|
@ -50,27 +50,17 @@ parallel_concat_op_info = TBERegOp("ParallelConcat") \
|
|||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.dtype_format(DataType.BOOL_NHWC, DataType.BOOL_NHWC) \
|
||||
.dtype_format(DataType.BOOL_NCHW, DataType.BOOL_NCHW) \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.I8_NHWC, DataType.I8_NHWC) \
|
||||
.dtype_format(DataType.I8_NCHW, DataType.I8_NCHW) \
|
||||
.dtype_format(DataType.U8_NHWC, DataType.U8_NHWC) \
|
||||
.dtype_format(DataType.U8_NCHW, DataType.U8_NCHW) \
|
||||
.dtype_format(DataType.I16_NHWC, DataType.I16_NHWC) \
|
||||
.dtype_format(DataType.I16_NCHW, DataType.I16_NCHW) \
|
||||
.dtype_format(DataType.U16_NHWC, DataType.U16_NHWC) \
|
||||
.dtype_format(DataType.U16_NCHW, DataType.U16_NCHW) \
|
||||
.dtype_format(DataType.I32_NHWC, DataType.I32_NHWC) \
|
||||
.dtype_format(DataType.I32_NCHW, DataType.I32_NCHW) \
|
||||
.dtype_format(DataType.U32_NHWC, DataType.U32_NHWC) \
|
||||
.dtype_format(DataType.U32_NCHW, DataType.U32_NCHW) \
|
||||
.dtype_format(DataType.I64_NHWC, DataType.I64_NHWC) \
|
||||
.dtype_format(DataType.I64_NCHW, DataType.I64_NCHW) \
|
||||
.dtype_format(DataType.U64_NHWC, DataType.U64_NHWC) \
|
||||
.dtype_format(DataType.U64_NCHW, DataType.U64_NCHW) \
|
||||
.dtype_format(DataType.F16_NHWC, DataType.F16_NHWC) \
|
||||
.dtype_format(DataType.F16_NCHW, DataType.F16_NCHW) \
|
||||
.dtype_format(DataType.F32_NHWC, DataType.F32_NHWC) \
|
||||
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
|
|
|
@ -26,10 +26,8 @@ prelu_op_info = TBERegOp("PReLU") \
|
|||
.input(0, "x", False, "required", "all") \
|
||||
.input(1, "weight", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_NCHW, DataType.F16_Default, DataType.F16_NCHW) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_NCHW, DataType.F32_Default, DataType.F32_NCHW) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
|
|
@ -28,8 +28,6 @@ prelu_grad_op_info = TBERegOp("PReLUGrad") \
|
|||
.input(2, "weights", False, "required", "all") \
|
||||
.output(0, "dx", False, "required", "all") \
|
||||
.output(0, "da", False, "required", "all") \
|
||||
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_Default,
|
||||
DataType.F32_NCHW, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||
DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
|
|
|
@ -32,8 +32,6 @@ sparse_apply_adagrad_d_op_info = TBERegOp("SparseApplyAdagrad") \
|
|||
.input(3, "indices", False, "required", "all") \
|
||||
.output(0, "var", False, "required", "all") \
|
||||
.output(1, "accum", False, "required", "all") \
|
||||
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.I32_NCHW,
|
||||
DataType.F32_NCHW, DataType.F32_NCHW) \
|
||||
.dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.I32_NHWC,
|
||||
DataType.F32_NHWC, DataType.F32_NHWC) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.I32_Default,
|
||||
|
|
|
@ -33,8 +33,6 @@ sparse_apply_adagrad_v2_d_op_info = TBERegOp("SparseApplyAdagradV2") \
|
|||
.input(3, "indices", False, "required", "all") \
|
||||
.output(0, "var", False, "required", "all") \
|
||||
.output(1, "accum", False, "required", "all") \
|
||||
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.I32_NCHW,
|
||||
DataType.F32_NCHW, DataType.F32_NCHW) \
|
||||
.dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.I32_NHWC,
|
||||
DataType.F32_NHWC, DataType.F32_NHWC) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.I32_Default,
|
||||
|
|
|
@ -36,14 +36,10 @@ sparse_apply_ftrl_d_op_info = TBERegOp("SparseApplyFtrl") \
|
|||
.output(0, "var", False, "required", "all") \
|
||||
.output(1, "accum", False, "required", "all") \
|
||||
.output(2, "linear", False, "required", "all") \
|
||||
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW,
|
||||
DataType.I32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \
|
||||
.dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC,
|
||||
DataType.I32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.I32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW,
|
||||
DataType.I64_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \
|
||||
.dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC,
|
||||
DataType.I64_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
|
|
|
@ -37,8 +37,6 @@ sparse_apply_ftrl_d_op_info = TBERegOp("SparseApplyFtrl") \
|
|||
.output(0, "var", False, "required", "all") \
|
||||
.output(1, "accum", False, "required", "all") \
|
||||
.output(2, "linear", False, "required", "all") \
|
||||
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW,
|
||||
DataType.I32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \
|
||||
.dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC,
|
||||
DataType.I32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
|
|
|
@ -37,8 +37,6 @@ sparse_apply_ftrl_v2_d_op_info = TBERegOp("SparseApplyFtrlV2") \
|
|||
.output(0, "var", False, "required", "all") \
|
||||
.output(1, "accum", False, "required", "all") \
|
||||
.output(2, "linear", False, "required", "all") \
|
||||
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW,
|
||||
DataType.I32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \
|
||||
.dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC,
|
||||
DataType.I32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
|
|
|
@ -33,9 +33,6 @@ sparse_apply_proximal_adagrad_d_op_info = TBERegOp("SparseApplyProximalAdagrad")
|
|||
.input(6, "indices", False, "required", "all") \
|
||||
.output(0, "var", False, "required", "all") \
|
||||
.output(1, "accum", False, "required", "all") \
|
||||
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW,
|
||||
DataType.F32_NCHW, DataType.F32_NCHW, DataType.I32_NCHW, DataType.F32_NCHW,
|
||||
DataType.F32_NCHW) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||
DataType.F32_5HD, DataType.F32_5HD, DataType.I32_5HD, DataType.F32_5HD,
|
||||
DataType.F32_5HD) \
|
||||
|
@ -48,9 +45,6 @@ sparse_apply_proximal_adagrad_d_op_info = TBERegOp("SparseApplyProximalAdagrad")
|
|||
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
|
||||
DataType.F32_FracZ, DataType.F32_FracZ, DataType.I32_FracZ, DataType.F32_FracZ,
|
||||
DataType.F32_FracZ) \
|
||||
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW,
|
||||
DataType.F32_NCHW, DataType.F32_NCHW, DataType.I64_NCHW, DataType.F32_NCHW,
|
||||
DataType.F32_NCHW) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||
DataType.F32_5HD, DataType.F32_5HD, DataType.I64_5HD, DataType.F32_5HD,
|
||||
DataType.F32_5HD) \
|
||||
|
|
|
@ -34,9 +34,6 @@ sparse_apply_proximal_adagrad_d_ds_op_info = TBERegOp("SparseApplyProximalAdagra
|
|||
.input(6, "indices", False, "required", "all") \
|
||||
.output(0, "var", False, "required", "all") \
|
||||
.output(1, "accum", False, "required", "all") \
|
||||
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW,
|
||||
DataType.F32_NCHW, DataType.F32_NCHW, DataType.I32_NCHW, DataType.F32_NCHW,
|
||||
DataType.F32_NCHW) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||
DataType.F32_5HD, DataType.F32_5HD, DataType.I32_5HD, DataType.F32_5HD,
|
||||
DataType.F32_5HD) \
|
||||
|
@ -49,9 +46,6 @@ sparse_apply_proximal_adagrad_d_ds_op_info = TBERegOp("SparseApplyProximalAdagra
|
|||
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
|
||||
DataType.F32_FracZ, DataType.F32_FracZ, DataType.I32_FracZ, DataType.F32_FracZ,
|
||||
DataType.F32_FracZ) \
|
||||
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW,
|
||||
DataType.F32_NCHW, DataType.F32_NCHW, DataType.I64_NCHW, DataType.F32_NCHW,
|
||||
DataType.F32_NCHW) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||
DataType.F32_5HD, DataType.F32_5HD, DataType.I64_5HD, DataType.F32_5HD,
|
||||
DataType.F32_5HD) \
|
||||
|
|
Loading…
Reference in New Issue