forked from mindspore-Ecosystem/mindspore
Modify custom op register
This commit is contained in:
parent
268d358a1d
commit
f5ee197b6c
|
@ -13,95 +13,28 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
from tests.st.ops.custom_ops_tbe.conv2d import conv2d
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "Cus_Conv2D",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "CONVLUTION",
|
||||
"async_flag": false,
|
||||
"binfile_name": "conv2d.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "Cus_Conv2D",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "stride",
|
||||
"param_type": "required",
|
||||
"type": "listInt",
|
||||
"value": "all"
|
||||
},
|
||||
{
|
||||
"name": "pad_list",
|
||||
"param_type": "required",
|
||||
"type": "listInt",
|
||||
"value": "all"
|
||||
},
|
||||
{
|
||||
"name": "dilation",
|
||||
"param_type": "required",
|
||||
"type": "listInt",
|
||||
"value": "all"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16"
|
||||
],
|
||||
"format": [
|
||||
"FracZ"
|
||||
],
|
||||
"name": "filter",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"dtype": [
|
||||
"float16"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name": "bias",
|
||||
"need_compile": false,
|
||||
"param_type": "optional",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": true,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
cus_conv2D_op_info = TBERegOp("Cus_Conv2D") \
|
||||
.fusion_type("CONVLUTION") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("conv2d.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("Cus_Conv2D") \
|
||||
.partial_flag(True) \
|
||||
.attr("stride", "required", "listInt", "all") \
|
||||
.attr("pad_list", "required", "listInt", "all") \
|
||||
.attr("dilation", "required", "listInt", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.input(1, "filter", False, "required", "all") \
|
||||
.input(2, "bias", False, "optional", "all") \
|
||||
.output(0, "y", True, "required", "all") \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_FracZ, DataType.F32_Default, DataType.F16_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(cus_conv2D_op_info)
|
||||
def Cus_Conv2D(inputs, weights, bias, outputs, strides, pads, dilations,
|
||||
kernel_name="conv2d"):
|
||||
conv2d(inputs, weights, bias, outputs, strides, pads, dilations,
|
||||
kernel_name)
|
||||
kernel_name)
|
||||
|
|
|
@ -18,11 +18,12 @@ from topi import generic
|
|||
import te.lang.cce
|
||||
from topi.cce import util
|
||||
from te.platform.fusion_manager import fusion_manager
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
# shape size limit for aicore is 2**31
|
||||
SHAPE_SIZE_LIMIT = 200000000
|
||||
|
||||
|
||||
@fusion_manager.register("square")
|
||||
def square_compute(input_x, output_y, kernel_name="square"):
|
||||
"""
|
||||
|
@ -46,49 +47,21 @@ def square_compute(input_x, output_y, kernel_name="square"):
|
|||
res = te.lang.cce.vmul(input_x, input_x)
|
||||
return res
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "CusSquare",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "square.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "CusSquare",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float32"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float32"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
|
||||
cus_conv2D_op_info = TBERegOp("CusSquare") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("square.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("CusSquare") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(cus_conv2D_op_info)
|
||||
def CusSquare(input_x, output_y, kernel_name="square"):
|
||||
"""
|
||||
algorithm: square
|
||||
|
|
Loading…
Reference in New Issue