!222 Modify custom op register

Merge pull request !222 from zjun/Modify_custom_op_register
This commit is contained in:
mindspore-ci-bot 2020-04-10 17:46:25 +08:00 committed by Gitee
commit 540a91728b
2 changed files with 38 additions and 132 deletions

View File

@ -13,95 +13,28 @@
# limitations under the License.
# ============================================================================
from tests.st.ops.custom_ops_tbe.conv2d import conv2d
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
@op_info_register("""{
"op_name": "Cus_Conv2D",
"imply_type": "TBE",
"fusion_type": "CONVLUTION",
"async_flag": false,
"binfile_name": "conv2d.so",
"compute_cost": 10,
"kernel_name": "Cus_Conv2D",
"partial_flag": true,
"attr": [
{
"name": "stride",
"param_type": "required",
"type": "listInt",
"value": "all"
},
{
"name": "pad_list",
"param_type": "required",
"type": "listInt",
"value": "all"
},
{
"name": "dilation",
"param_type": "required",
"type": "listInt",
"value": "all"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float16"
],
"format": [
"NC1HWC0"
],
"name": "x",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float16"
],
"format": [
"FracZ"
],
"name": "filter",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 2,
"dtype": [
"float16"
],
"format": [
"DefaultFormat"
],
"name": "bias",
"need_compile": false,
"param_type": "optional",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16"
],
"format": [
"NC1HWC0"
],
"name": "y",
"need_compile": true,
"param_type": "required",
"shape": "all"
}
]
}""")
cus_conv2D_op_info = TBERegOp("Cus_Conv2D") \
.fusion_type("CONVLUTION") \
.async_flag(False) \
.binfile_name("conv2d.so") \
.compute_cost(10) \
.kernel_name("Cus_Conv2D") \
.partial_flag(True) \
.attr("stride", "required", "listInt", "all") \
.attr("pad_list", "required", "listInt", "all") \
.attr("dilation", "required", "listInt", "all") \
.input(0, "x", False, "required", "all") \
.input(1, "filter", False, "required", "all") \
.input(2, "bias", False, "optional", "all") \
.output(0, "y", True, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F16_FracZ, DataType.F32_Default, DataType.F16_5HD) \
.get_op_info()
@op_info_register(cus_conv2D_op_info)
def Cus_Conv2D(inputs, weights, bias, outputs, strides, pads, dilations,
kernel_name="conv2d"):
conv2d(inputs, weights, bias, outputs, strides, pads, dilations,
kernel_name)
kernel_name)

View File

@ -18,11 +18,12 @@ from topi import generic
import te.lang.cce
from topi.cce import util
from te.platform.fusion_manager import fusion_manager
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
# shape size limit for aicore is 2**31
SHAPE_SIZE_LIMIT = 200000000
@fusion_manager.register("square")
def square_compute(input_x, output_y, kernel_name="square"):
"""
@ -46,49 +47,21 @@ def square_compute(input_x, output_y, kernel_name="square"):
res = te.lang.cce.vmul(input_x, input_x)
return res
@op_info_register("""{
"op_name": "CusSquare",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "square.so",
"compute_cost": 10,
"kernel_name": "CusSquare",
"partial_flag": true,
"attr": [
],
"inputs": [
{
"index": 0,
"dtype": [
"float32"
],
"format": [
"DefaultFormat"
],
"name": "x",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float32"
],
"format": [
"DefaultFormat"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")
cus_conv2D_op_info = TBERegOp("CusSquare") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("square.so") \
.compute_cost(10) \
.kernel_name("CusSquare") \
.partial_flag(True) \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register(cus_conv2D_op_info)
def CusSquare(input_x, output_y, kernel_name="square"):
"""
algorithm: square