diff --git a/tests/st/ops/custom_ops_tbe/cus_conv2d_impl.py b/tests/st/ops/custom_ops_tbe/cus_conv2d_impl.py index 54f6954a18e..04ac7c2ff70 100644 --- a/tests/st/ops/custom_ops_tbe/cus_conv2d_impl.py +++ b/tests/st/ops/custom_ops_tbe/cus_conv2d_impl.py @@ -13,95 +13,28 @@ # limitations under the License. # ============================================================================ from tests.st.ops.custom_ops_tbe.conv2d import conv2d -from mindspore.ops.op_info_register import op_info_register +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType -@op_info_register("""{ - "op_name": "Cus_Conv2D", - "imply_type": "TBE", - "fusion_type": "CONVLUTION", - "async_flag": false, - "binfile_name": "conv2d.so", - "compute_cost": 10, - "kernel_name": "Cus_Conv2D", - "partial_flag": true, - "attr": [ - { - "name": "stride", - "param_type": "required", - "type": "listInt", - "value": "all" - }, - { - "name": "pad_list", - "param_type": "required", - "type": "listInt", - "value": "all" - }, - { - "name": "dilation", - "param_type": "required", - "type": "listInt", - "value": "all" - } - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16" - ], - "format": [ - "NC1HWC0" - ], - "name": "x", - "need_compile": false, - "param_type": "required", - "shape": "all" - }, - { - "index": 1, - "dtype": [ - "float16" - ], - "format": [ - "FracZ" - ], - "name": "filter", - "need_compile": false, - "param_type": "required", - "shape": "all" - }, - { - "index": 2, - "dtype": [ - "float16" - ], - "format": [ - "DefaultFormat" - ], - "name": "bias", - "need_compile": false, - "param_type": "optional", - "shape": "all" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16" - ], - "format": [ - "NC1HWC0" - ], - "name": "y", - "need_compile": true, - "param_type": "required", - "shape": "all" - } - ] -}""") +cus_conv2D_op_info = TBERegOp("Cus_Conv2D") \ + .fusion_type("CONVLUTION") \ + .async_flag(False) \ + .binfile_name("conv2d.so") \ + .compute_cost(10) \ + .kernel_name("Cus_Conv2D") \ + .partial_flag(True) \ + .attr("stride", "required", "listInt", "all") \ + .attr("pad_list", "required", "listInt", "all") \ + .attr("dilation", "required", "listInt", "all") \ + .input(0, "x", False, "required", "all") \ + .input(1, "filter", False, "required", "all") \ + .input(2, "bias", False, "optional", "all") \ + .output(0, "y", True, "required", "all") \ + .dtype_format(DataType.F16_5HD, DataType.F16_FracZ, DataType.F32_Default, DataType.F16_5HD) \ + .get_op_info() + + +@op_info_register(cus_conv2D_op_info) def Cus_Conv2D(inputs, weights, bias, outputs, strides, pads, dilations, kernel_name="conv2d"): conv2d(inputs, weights, bias, outputs, strides, pads, dilations, - kernel_name) \ No newline at end of file + kernel_name) diff --git a/tests/st/ops/custom_ops_tbe/square_impl.py b/tests/st/ops/custom_ops_tbe/square_impl.py index e5992eff1c3..f3a1e0751de 100644 --- a/tests/st/ops/custom_ops_tbe/square_impl.py +++ b/tests/st/ops/custom_ops_tbe/square_impl.py @@ -18,11 +18,12 @@ from topi import generic import te.lang.cce from topi.cce import util from te.platform.fusion_manager import fusion_manager -from mindspore.ops.op_info_register import op_info_register +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType # shape size limit for aicore is 2**31 SHAPE_SIZE_LIMIT = 200000000 + @fusion_manager.register("square") def square_compute(input_x, output_y, kernel_name="square"): """ @@ -46,49 +47,21 @@ def square_compute(input_x, output_y, kernel_name="square"): res = te.lang.cce.vmul(input_x, input_x) return res -@op_info_register("""{ - "op_name": "CusSquare", - "imply_type": "TBE", - "fusion_type": "OPAQUE", - "async_flag": false, - "binfile_name": "square.so", - "compute_cost": 10, - "kernel_name": "CusSquare", - "partial_flag": true, - "attr": [ - - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float32" - ], - "format": [ - "DefaultFormat" - ], - "name": "x", - "need_compile": false, - "param_type": "required", - "shape": "all" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float32" - ], - "format": [ - "DefaultFormat" - ], - "name": "y", - "need_compile": false, - "param_type": "required", - "shape": "all" - } - ] -}""") + +cus_conv2D_op_info = TBERegOp("CusSquare") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("square.so") \ + .compute_cost(10) \ + .kernel_name("CusSquare") \ + .partial_flag(True) \ + .input(0, "x", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F32_Default, DataType.F32_Default) \ + .get_op_info() + + +@op_info_register(cus_conv2D_op_info) def CusSquare(input_x, output_y, kernel_name="square"): """ algorithm: square