forked from OSSInnovation/mindspore
Add all other tbe op info register
This commit is contained in:
parent
c06d2c6c2a
commit
02aca06451
|
@ -14,71 +14,28 @@
|
|||
# ============================================================================
|
||||
|
||||
"""Add op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
add_op_info = TBERegOp("Add") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("add.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("add") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "x1", False, "required", "all") \
|
||||
.input(1, "x2", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "Add",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "ELEMWISE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "add.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "add",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float16", "float16", "float16", "float", "float", "float",
|
||||
"float", "int32", "int32", "int32", "int32"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "x1",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16", "float16", "float16", "float16", "float", "float", "float", "float", "int32",
|
||||
"int32", "int32", "int32"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "x2",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float16", "float16", "float16", "float", "float", "float", "float", "int32",
|
||||
"int32", "int32", "int32"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(add_op_info)
|
||||
def _add_tbe():
|
||||
"""Add TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,61 +14,33 @@
|
|||
# ============================================================================
|
||||
|
||||
"""AddN op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
add_n_op_info = TBERegOp("AddN") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("add_n.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("add_n") \
|
||||
.partial_flag(True) \
|
||||
.attr("n", "required", "int", "all") \
|
||||
.input(0, "x", False, "dynamic", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ) \
|
||||
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ) \
|
||||
.dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I32_5HD, DataType.I32_5HD) \
|
||||
.dtype_format(DataType.I32_FracZ, DataType.I32_FracZ) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "AddN",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "ELEMWISE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "add_n.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "add_n",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "n",
|
||||
"param_type": "required",
|
||||
"type": "int",
|
||||
"value": "all"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float16",
|
||||
"float","float","float","float","int32","int32","int32"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","NC1HWC0","FracZ","FRACTAL_NZ",
|
||||
"DefaultFormat","NC1HWC0","FracZ","FRACTAL_NZ","DefaultFormat","NC1HWC0","FracZ"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "dynamic",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float16",
|
||||
"float","float","float","float","int32","int32","int32"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","NC1HWC0","FracZ","FRACTAL_NZ",
|
||||
"DefaultFormat","NC1HWC0","FracZ","FRACTAL_NZ","DefaultFormat","NC1HWC0","FracZ"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(add_n_op_info)
|
||||
def _add_n_tbe():
|
||||
"""AddN TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,214 +14,66 @@
|
|||
# ============================================================================
|
||||
|
||||
"""ApplyAdam op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
apply_adam_op_info = TBERegOp("Adam") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("apply_adam.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("apply_adam") \
|
||||
.partial_flag(True) \
|
||||
.attr("use_locking", "optional", "bool", "true,false", "false") \
|
||||
.attr("use_nesterov", "optional", "bool", "true,false", "false") \
|
||||
.input(0, "var", False, "required", "all") \
|
||||
.input(1, "m", False, "required", "all") \
|
||||
.input(2, "v", False, "required", "all") \
|
||||
.input(3, "beta1_power", False, "required", "all") \
|
||||
.input(4, "beta2_power", False, "required", "all") \
|
||||
.input(5, "lr", False, "required", "all") \
|
||||
.input(6, "beta1", False, "required", "all") \
|
||||
.input(7, "beta2", False, "required", "all") \
|
||||
.input(8, "epsilon", False, "required", "all") \
|
||||
.input(9, "grad", False, "required", "all") \
|
||||
.output(0, "var", False, "required", "all") \
|
||||
.output(1, "m", False, "required", "all") \
|
||||
.output(2, "v", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0,
|
||||
DataType.F16_C1HWNCoC0) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
|
||||
DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ,
|
||||
DataType.F16_FracZ) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0,
|
||||
DataType.F32_C1HWNCoC0) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||
DataType.F32_5HD) \
|
||||
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
|
||||
DataType.F32_FracZ) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "Adam",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "apply_adam.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "apply_adam",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "use_locking",
|
||||
"param_type": "optional",
|
||||
"type": "bool",
|
||||
"value": "true,false",
|
||||
"default_value":"false"
|
||||
},
|
||||
{
|
||||
"name": "use_nesterov",
|
||||
"param_type": "optional",
|
||||
"type": "bool",
|
||||
"value": "true,false",
|
||||
"default_value":"false"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float16","float","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ"
|
||||
],
|
||||
"name": "var",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float16","float","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ"
|
||||
],
|
||||
"name": "m",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float16","float","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ"
|
||||
],
|
||||
"name": "v",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 3,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float16","float","float","float", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat",
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "beta1_power",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 4,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float16","float","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat",
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "beta2_power",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 5,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float16","float","float","float", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat",
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "lr",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 6,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float16","float","float","float", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat",
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "beta1",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 7,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float16","float","float","float", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat",
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "beta2",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 8,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float16","float","float","float", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat",
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "epsilon",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 9,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float16","float","float","float", "float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ"
|
||||
],
|
||||
"name": "grad",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float16","float","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ"
|
||||
],
|
||||
"name": "var",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float16","float","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ"
|
||||
],
|
||||
"name": "m",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float16","float","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ"
|
||||
],
|
||||
"name": "v",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(apply_adam_op_info)
|
||||
def _apply_adam_tbe():
|
||||
"""ApplyAdam TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,112 +14,42 @@
|
|||
# ============================================================================
|
||||
|
||||
"""ApplyMomentum op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
apply_momentum_op_info = TBERegOp("ApplyMomentum") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("apply_momentum.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("apply_momentum") \
|
||||
.partial_flag(True) \
|
||||
.attr("use_nesterov", "optional", "bool", "true,false", "false") \
|
||||
.input(0, "var", False, "required", "all") \
|
||||
.input(1, "accum", False, "required", "all") \
|
||||
.input(2, "lr", False, "required", "all") \
|
||||
.input(3, "grad", False, "required", "all") \
|
||||
.input(4, "momentum", False, "required", "all") \
|
||||
.output(0, "var", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default, DataType.F16_5HD,
|
||||
DataType.F16_Default, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_C1HWNCoC0,
|
||||
DataType.F16_Default, DataType.F16_C1HWNCoC0) \
|
||||
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_FracZ,
|
||||
DataType.F16_Default, DataType.F16_FracZ) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_5HD,
|
||||
DataType.F32_Default, DataType.F32_5HD) \
|
||||
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_C1HWNCoC0,
|
||||
DataType.F32_Default, DataType.F32_C1HWNCoC0) \
|
||||
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_FracZ,
|
||||
DataType.F32_Default, DataType.F32_FracZ) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "ApplyMomentum",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "apply_momentum.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "apply_momentum",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "use_nesterov",
|
||||
"param_type": "optional",
|
||||
"type": "bool",
|
||||
"value": "true,false",
|
||||
"default_value":"false"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float16","float","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "DefaultFormat", "FracZ", "C1HWNCoC0"
|
||||
],
|
||||
"name": "var",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float16","float","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "DefaultFormat", "FracZ", "C1HWNCoC0"
|
||||
],
|
||||
"name": "accum",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float16","float","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat",
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "lr",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 3,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float16","float","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "DefaultFormat", "FracZ", "C1HWNCoC0"
|
||||
],
|
||||
"name": "grad",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 4,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float16","float","float","float", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat",
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "momentum",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float16","float","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "DefaultFormat", "FracZ", "C1HWNCoC0"
|
||||
],
|
||||
"name": "var",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(apply_momentum_op_info)
|
||||
def _apply_momentum_tbe():
|
||||
"""ApplyMomentum TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,70 +14,25 @@
|
|||
# ============================================================================
|
||||
|
||||
"""ArgMaxWithValue op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
arg_max_with_value_op_info = TBERegOp("ArgMaxWithValue") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("arg_max_with_value.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("arg_max_with_value") \
|
||||
.partial_flag(True) \
|
||||
.attr("axis", "required", "int", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "indice", False, "required", "all") \
|
||||
.output(1, "values", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "ArgMaxWithValue",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "ELEMWISE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "arg_max_with_value.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "arg_max_with_value",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "axis",
|
||||
"param_type": "required",
|
||||
"type": "int",
|
||||
"value": "all"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"int32", "int32"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "indice",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "values",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(arg_max_with_value_op_info)
|
||||
def _arg_max_with_value_tbe():
|
||||
"""ArgMaxWithValue TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,70 +14,25 @@
|
|||
# ============================================================================
|
||||
|
||||
"""ArgMinWithValue op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
arg_min_with_value_op_info = TBERegOp("ArgMaxWithValue") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("arg_min_with_value.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("arg_min_with_value") \
|
||||
.partial_flag(True) \
|
||||
.attr("axis", "required", "int", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "indice", False, "required", "all") \
|
||||
.output(1, "values", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "ArgMinWithValue",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "ELEMWISE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "arg_min_with_value.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "arg_min_with_value",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "axis",
|
||||
"param_type": "required",
|
||||
"type": "int",
|
||||
"value": "all"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"int32", "int32"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "indice",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "values",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(arg_min_with_value_op_info)
|
||||
def _arg_min_with_value_tbe():
|
||||
"""ArgMinWithValue TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,93 +14,43 @@
|
|||
# ============================================================================
|
||||
|
||||
"""Assign op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
assign_op_info = TBERegOp("Assign") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("assign.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("assign") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "resource", False, "required", "all") \
|
||||
.input(1, "value", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD) \
|
||||
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.U8_5HD) \
|
||||
.dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.I16_Default) \
|
||||
.dtype_format(DataType.I16_5HD, DataType.I16_5HD, DataType.I16_5HD) \
|
||||
.dtype_format(DataType.U16_Default, DataType.U16_Default, DataType.U16_Default) \
|
||||
.dtype_format(DataType.U16_5HD, DataType.U16_5HD, DataType.U16_5HD) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
|
||||
.dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default) \
|
||||
.dtype_format(DataType.U32_5HD, DataType.U32_5HD, DataType.U32_5HD) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.I64_5HD, DataType.I64_5HD, DataType.I64_5HD) \
|
||||
.dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.U64_Default) \
|
||||
.dtype_format(DataType.U64_5HD, DataType.U64_5HD, DataType.U64_5HD) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "Assign",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "assign.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "assign",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float16", "float16", "float16", "float", "float", "float", "float",
|
||||
"int32", "int32", "int32", "int32", "uint32", "uint32", "uint32", "uint32", "int8",
|
||||
"int8", "int8", "int8", "uint8", "uint8", "uint8", "uint8", "int16", "int16", "int16",
|
||||
"int16", "uint16", "uint16", "uint16", "uint16", "int64", "int64", "int64", "int64",
|
||||
"uint64", "uint64", "uint64", "uint64", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat",
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat",
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat",
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "FRACTAL_NZ"
|
||||
],
|
||||
"name": "resource",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16", "float16", "float16", "float16", "float", "float", "float", "float", "int32", "int32",
|
||||
"int32", "int32", "uint32", "uint32", "uint32", "uint32", "int8", "int8", "int8", "int8", "uint8",
|
||||
"uint8", "uint8", "uint8", "int16", "int16", "int16", "int16", "uint16", "uint16", "uint16",
|
||||
"uint16", "int64", "int64", "int64", "int64", "uint64", "uint64", "uint64", "uint64", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat",
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat",
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat",
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "FRACTAL_NZ"
|
||||
],
|
||||
"name": "value",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float16", "float16", "float16", "float", "float", "float", "float", "int32",
|
||||
"int32", "int32", "int32", "uint32", "uint32", "uint32", "uint32", "int8", "int8", "int8",
|
||||
"int8", "uint8", "uint8", "uint8", "uint8", "int16", "int16", "int16", "int16", "uint16",
|
||||
"uint16", "uint16", "uint16", "int64", "int64", "int64", "int64",
|
||||
"uint64", "uint64", "uint64", "uint64", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat",
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat",
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat",
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "FRACTAL_NZ"
|
||||
],
|
||||
"name": "y",
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(assign_op_info)
|
||||
def _assign_tbe():
|
||||
"""Assign TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,80 +14,34 @@
|
|||
# ============================================================================
|
||||
|
||||
"""AssignAdd op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
assign_add_op_info = TBERegOp("AssignAdd") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("assignadd.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("assignadd") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "ref", False, "required", "all") \
|
||||
.input(1, "value", False, "required", "all") \
|
||||
.output(0, "output_ref", False, "required", "all") \
|
||||
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD) \
|
||||
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.U8_5HD) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.I64_5HD, DataType.I64_5HD, DataType.I64_5HD) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "AssignAdd",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "assignadd.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "assignadd",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype":[
|
||||
"float16", "float16", "float16", "float16", "float", "float", "float", "float", "int32", "int32",
|
||||
"int32", "int32", "int8", "int8", "int8", "int8", "uint8", "uint8", "uint8", "uint8", "int64",
|
||||
"int64", "int64", "int64"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat",
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "ref",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype":[
|
||||
"float16", "float16", "float16", "float16", "float", "float", "float", "float", "int32", "int32",
|
||||
"int32", "int32", "int8", "int8", "int8", "int8", "uint8", "uint8", "uint8", "uint8", "int64",
|
||||
"int64", "int64", "int64"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat",
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "value",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype":[
|
||||
"float16", "float16", "float16", "float16", "float", "float", "float", "float", "int32", "int32",
|
||||
"int32", "int32", "int8", "int8", "int8", "int8", "uint8", "uint8", "uint8", "uint8", "int64",
|
||||
"int64", "int64", "int64"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat",
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "output_ref",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(assign_add_op_info)
|
||||
def _assign_add_tbe():
|
||||
"""AssignAdd TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,65 +14,27 @@
|
|||
# ============================================================================
|
||||
|
||||
"""AssignSub op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
assign_sub_op_info = TBERegOp("AssignSub") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("assign_sub.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("assign_sub") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "var", False, "required", "all") \
|
||||
.input(1, "value", False, "required", "all") \
|
||||
.output(0, "output_ref", False, "required", "all") \
|
||||
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "AssignSub",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "assign_sub.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "assign_sub",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float", "int32", "int8", "uint8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "var",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16", "float", "int32", "int8", "uint8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "value",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float", "int32", "int8", "uint8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "out_ref",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(assign_sub_op_info)
|
||||
def _assign_sub_tbe():
|
||||
"""AssignSub TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,31 +14,20 @@
|
|||
# ============================================================================
|
||||
|
||||
"""AtomicAddrClean op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp
|
||||
|
||||
atomic_addr_clean_op_info = TBERegOp("AtomicAddrClean") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("atomic_addr_clean.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("atomic_addr_clean") \
|
||||
.partial_flag(True) \
|
||||
.attr("automic_add_mem_size", "required", "listInt", "all") \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "AtomicAddrClean",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "ELEMWISE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "atomic_addr_clean.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "atomic_addr_clean",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "automic_add_mem_size",
|
||||
"param_type": "required",
|
||||
"type": "listInt",
|
||||
"value": "all"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
],
|
||||
"outputs": [
|
||||
]
|
||||
}""")
|
||||
@op_info_register(atomic_addr_clean_op_info)
|
||||
def _atomic_addr_clean_tbe():
|
||||
"""AtomicAddrClean TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,88 +14,29 @@
|
|||
# ============================================================================
|
||||
|
||||
"""BatchMatMul op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
batch_matmul_op_info = TBERegOp("BatchMatMul") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("batch_matmul.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("batch_matmul") \
|
||||
.attr("transpose_x1", "required", "bool", "all") \
|
||||
.attr("transpose_x2", "required", "bool", "all") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "x1", False, "required", "all") \
|
||||
.input(1, "x2", False, "required", "all") \
|
||||
.input(2, "bias", False, "optional", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_Default, DataType.F16_FracNZ) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "BatchMatMul",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "batch_matmul.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "batch_matmul",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "transpose_x1",
|
||||
"param_type": "required",
|
||||
"type": "bool",
|
||||
"value": "all"
|
||||
},
|
||||
{
|
||||
"name": "transpose_x2",
|
||||
"param_type": "required",
|
||||
"type": "bool",
|
||||
"value": "all"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float","float","int32","int32"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","FRACTAL_NZ","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
|
||||
],
|
||||
"name": "x1",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float","float","int32","int32"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","FRACTAL_NZ","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
|
||||
],
|
||||
"name": "x2",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float","float","int32","int32"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
|
||||
],
|
||||
"name": "bias",
|
||||
"need_compile": false,
|
||||
"param_type": "optional",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float","float","int32","int32"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","FRACTAL_NZ","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
|
||||
],
|
||||
"name": "y",
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(batch_matmul_op_info)
|
||||
def _batch_matmul_tbe():
|
||||
"""BatchMatMul TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,174 +14,45 @@
|
|||
# ============================================================================
|
||||
|
||||
"""BatchNorm op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
batch_norm_op_info = TBERegOp("BatchNorm") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("batch_norm.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("batch_norm") \
|
||||
.partial_flag(True) \
|
||||
.attr("epsilon", "optional", "float", "all") \
|
||||
.attr("data_format", "optional", "str", "all") \
|
||||
.attr("is_training", "optional", "bool", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.input(1, "scale", False, "required", "all") \
|
||||
.input(2, "offset", False, "required", "all") \
|
||||
.input(3, "mean", False, "optional", "all") \
|
||||
.input(4, "variance", False, "optional", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.output(1, "batch_mean", False, "required", "all") \
|
||||
.output(2, "batch_variance", False, "required", "all") \
|
||||
.output(3, "reserve_space_1", False, "optional", "all") \
|
||||
.output(4, "reserve_space_2", False, "optional", "all") \
|
||||
.output(5, "reserve_space_3", False, "optional", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F16_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||
DataType.F32_5HD, DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "BatchNorm",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "batch_norm.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "batch_norm",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "epsilon",
|
||||
"param_type": "required",
|
||||
"type": "float",
|
||||
"value": "all"
|
||||
},
|
||||
{
|
||||
"name": "data_format",
|
||||
"param_type": "required",
|
||||
"type": "str",
|
||||
"value": "all"
|
||||
},
|
||||
{
|
||||
"name": "is_training",
|
||||
"param_type": "required",
|
||||
"type": "bool",
|
||||
"value": "all"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float16","float","float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","NC1HWC0", "DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","NC1HWC0","DefaultFormat", "NC1HWC0"
|
||||
],
|
||||
"name": "scale",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"dtype": [
|
||||
"float","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "offset",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 3,
|
||||
"dtype": [
|
||||
"float","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "mean",
|
||||
"need_compile": false,
|
||||
"param_type": "optional",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 4,
|
||||
"dtype": [
|
||||
"float","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "variance",
|
||||
"need_compile": false,
|
||||
"param_type": "optional",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float16", "float", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","NC1HWC0", "DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "y",
|
||||
"param_type": "required"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "batch_mean",
|
||||
"param_type": "required"
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"dtype": [
|
||||
"float", "float", "float", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "batch_variance",
|
||||
"param_type": "required"
|
||||
},
|
||||
{
|
||||
"index": 3,
|
||||
"dtype": [
|
||||
"float", "float", "float", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "reserve_space_1",
|
||||
"param_type": "optional"
|
||||
},
|
||||
{
|
||||
"index": 4,
|
||||
"dtype": [
|
||||
"float", "float", "float", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "reserve_space_2",
|
||||
"param_type": "optional"
|
||||
},
|
||||
{
|
||||
"index": 5,
|
||||
"dtype": [
|
||||
"float", "float", "float", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "reserve_space_3",
|
||||
"param_type": "optional"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(batch_norm_op_info)
|
||||
def _batch_norm_tbe():
|
||||
"""BatchNorm TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,181 +14,45 @@
|
|||
# ============================================================================
|
||||
|
||||
"""BatchNormGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
batch_norm_grad_op_info = TBERegOp("BatchNormGrad") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("batchnormgrad.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("batchnormgrad") \
|
||||
.partial_flag(True) \
|
||||
.attr("epsilon", "optional", "float", "all") \
|
||||
.attr("data_format", "optional", "str", "all") \
|
||||
.attr("is_training", "optional", "bool", "all") \
|
||||
.input(0, "y_backprop", False, "required", "all") \
|
||||
.input(1, "x", False, "required", "all") \
|
||||
.input(2, "scale", False, "required", "all") \
|
||||
.input(3, "reserve_space_1", False, "required", "all") \
|
||||
.input(4, "reserve_space_2", False, "required", "all") \
|
||||
.input(5, "reserve_space_3", False, "required", "all") \
|
||||
.output(0, "x_backprop", False, "required", "all") \
|
||||
.output(1, "scale_backprop", False, "required", "all") \
|
||||
.output(2, "offset_backprop", False, "required", "all") \
|
||||
.output(3, "reserve_space_4", False, "optional", "all") \
|
||||
.output(4, "reserve_space_5", False, "optional", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F16_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||
DataType.F32_5HD, DataType.F32_5HD, DataType.F16_5HD, DataType.F32_5HD,
|
||||
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "BatchNormGrad",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "batchnormgrad.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "batchnormgrad",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "epsilon",
|
||||
"param_type": "optional",
|
||||
"type": "float",
|
||||
"value": "all"
|
||||
},
|
||||
{
|
||||
"name": "data_format",
|
||||
"param_type": "optional",
|
||||
"type": "str",
|
||||
"value": "all"
|
||||
},
|
||||
{
|
||||
"name": "is_training",
|
||||
"param_type": "optional",
|
||||
"type": "bool",
|
||||
"value": "all"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "y_backprop",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"dtype": [
|
||||
"float","float","float","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "scale",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 3,
|
||||
"dtype": [
|
||||
"float","float","float","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "reserve_space_1",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 4,
|
||||
"dtype": [
|
||||
"float","float","float","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "reserve_space_2",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 5,
|
||||
"dtype": [
|
||||
"float","float","float","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "reserve_space_3",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "x_backprop",
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float","float","float","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "scale_backprop",
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"dtype": [
|
||||
"float","float","float","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "offset_backprop",
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 3,
|
||||
"dtype": [
|
||||
"float","float","float","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "reserve_space_4",
|
||||
"param_type": "optional",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 4,
|
||||
"dtype": [
|
||||
"float","float","float","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "reserve_space_5",
|
||||
"param_type": "optional",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(batch_norm_grad_op_info)
|
||||
def _batch_norm_grad_tbe():
|
||||
"""BatchNormGrad TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,70 +14,26 @@
|
|||
# ============================================================================
|
||||
|
||||
"""BiasAdd op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
bias_add_grad_op_info = TBERegOp("BiasAdd") \
|
||||
.fusion_type("COMMREDUCE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("bias_add.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("bias_add") \
|
||||
.partial_flag(True) \
|
||||
.attr("data_format", "required", "str", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.input(1, "bias", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "BiasAdd",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "COMMREDUCE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "bias_add.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "bias_add",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "data_format",
|
||||
"param_type": "required",
|
||||
"type": "str",
|
||||
"value": "all"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"int32", "float16", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"int32", "float16", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "bias",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"int32", "float16", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(bias_add_grad_op_info)
|
||||
def _bias_add_tbe():
|
||||
"""BiasAdd TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,57 +14,26 @@
|
|||
# ============================================================================
|
||||
|
||||
"""BiasAddGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
bias_add_grad_op_info = TBERegOp("BiasAddGrad") \
|
||||
.fusion_type("COMMREDUCE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("biasaddgrad.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("biasaddgrad") \
|
||||
.partial_flag(True) \
|
||||
.attr("data_format", "required", "str", "all") \
|
||||
.input(0, "output_backprop", False, "required", "all") \
|
||||
.output(0, "output", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_FracNZ, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_FracNZ, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "BiasAddGrad",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "COMMREDUCE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "biasaddgrad.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "biasaddgrad",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "data_format",
|
||||
"param_type": "required",
|
||||
"type": "str",
|
||||
"value": "all"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float16","float","float"
|
||||
],
|
||||
"format": [
|
||||
"FRACTAL_NZ","DefaultFormat","FRACTAL_NZ","DefaultFormat"
|
||||
],
|
||||
"name": "out_backprop",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float16","float","float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
|
||||
],
|
||||
"name": "output",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(bias_add_grad_op_info)
|
||||
def _bias_add_grad_tbe():
|
||||
"""BiasAddGrad TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,60 +14,24 @@
|
|||
# ============================================================================
|
||||
|
||||
"""BatchNorm op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
bn_training_reduce_op_info = TBERegOp("BNTrainingReduce") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("bn_training_reduce.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("bn_training_reduce") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "sum", False, "required", "all") \
|
||||
.output(1, "square_sum", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "BNTrainingReduce",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "ELEMWISE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "bn_training_reduce.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "bn_training_reduce",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0", "NC1HWC0"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float","float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0", "NC1HWC0"
|
||||
],
|
||||
"name": "sum",
|
||||
"param_type": "required"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float","float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0", "NC1HWC0"
|
||||
],
|
||||
"name": "square_sum",
|
||||
"param_type": "required"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(bn_training_reduce_op_info)
|
||||
def _bn_training_reduce_tbe():
|
||||
"""BNTrainingReduce TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,134 +14,32 @@
|
|||
# ============================================================================
|
||||
|
||||
"""BatchNormGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
bn_training_reduce_grad_op_info = TBERegOp("BNTrainingReduceGrad") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("bn_training_reduce_grad.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("bn_training_reduce_grad") \
|
||||
.partial_flag(True) \
|
||||
.attr("epsilon", "optional", "float", "all") \
|
||||
.input(0, "grads", False, "required", "all") \
|
||||
.input(1, "x_norm", False, "required", "all") \
|
||||
.input(2, "diff_scale", False, "required", "all") \
|
||||
.input(3, "diff_offset", False, "required", "all") \
|
||||
.input(4, "scale", False, "required", "all") \
|
||||
.input(5, "batch_mean", False, "required", "all") \
|
||||
.input(6, "batch_variance", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "BNTrainingReduceGrad",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "bn_training_reduce_grad.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "bn_training_reduce_grad",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "epsilon",
|
||||
"param_type": "optional",
|
||||
"type": "float",
|
||||
"value": "all"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0","NC1HWC0"
|
||||
],
|
||||
"name": "grads",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float", "float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0","NC1HWC0"
|
||||
],
|
||||
"name": "x_norm",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"dtype": [
|
||||
"float", "float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0","NC1HWC0"
|
||||
],
|
||||
"name": "diff_scale",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 3,
|
||||
"dtype": [
|
||||
"float", "float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0","NC1HWC0"
|
||||
],
|
||||
"name": "diff_offset",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 4,
|
||||
"dtype": [
|
||||
"float", "float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0","NC1HWC0"
|
||||
],
|
||||
"name": "scale",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 5,
|
||||
"dtype": [
|
||||
"float", "float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0","NC1HWC0"
|
||||
],
|
||||
"name": "batch_mean",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 6,
|
||||
"dtype": [
|
||||
"float", "float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0","NC1HWC0"
|
||||
],
|
||||
"name": "batch_variance",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0","NC1HWC0"
|
||||
],
|
||||
"name": "y",
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(bn_training_reduce_grad_op_info)
|
||||
def _bn_training_reduce_grad_tbe():
|
||||
"""BNTrainingReduceGrad TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,200 +14,40 @@
|
|||
# ============================================================================
|
||||
|
||||
"""BatchNormGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
bn_training_update_op_info = TBERegOp("BNTrainingUpdate") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("bn_training_update.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("bn_training_update") \
|
||||
.partial_flag(True) \
|
||||
.attr("factor", "optional", "float", "all") \
|
||||
.attr("epsilon", "optional", "float", "all") \
|
||||
.attr("isRef", "optional", "bool", "all", "true") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.input(1, "sum", False, "required", "all") \
|
||||
.input(2, "square_sum", False, "required", "all") \
|
||||
.input(3, "scale", False, "required", "all") \
|
||||
.input(4, "offset", False, "required", "all") \
|
||||
.input(5, "mean", False, "required", "all") \
|
||||
.input(6, "variance", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.output(1, "mean", False, "required", "all") \
|
||||
.output(2, "variance", False, "required", "all") \
|
||||
.output(3, "batch_mean", False, "required", "all") \
|
||||
.output(4, "batch_variance", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F16_5HD,
|
||||
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "BNTrainingUpdate",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "bn_training_update.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "bn_training_update",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "factor",
|
||||
"param_type": "optional",
|
||||
"type": "float",
|
||||
"value": "all"
|
||||
},
|
||||
{
|
||||
"name": "epsilon",
|
||||
"param_type": "optional",
|
||||
"type": "float",
|
||||
"value": "all"
|
||||
},
|
||||
{
|
||||
"name": "isRef",
|
||||
"param_type": "optional",
|
||||
"type": "bool",
|
||||
"default_value":"true",
|
||||
"value": "all"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0","NC1HWC0"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float", "float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0","NC1HWC0"
|
||||
],
|
||||
"name": "sum",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"dtype": [
|
||||
"float", "float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0","NC1HWC0"
|
||||
],
|
||||
"name": "square_sum",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 3,
|
||||
"dtype": [
|
||||
"float", "float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0","NC1HWC0"
|
||||
],
|
||||
"name": "scale",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 4,
|
||||
"dtype": [
|
||||
"float", "float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0","NC1HWC0"
|
||||
],
|
||||
"name": "offset",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 5,
|
||||
"dtype": [
|
||||
"float", "float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0","NC1HWC0"
|
||||
],
|
||||
"name": "mean",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 6,
|
||||
"dtype": [
|
||||
"float", "float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0","NC1HWC0"
|
||||
],
|
||||
"name": "variance",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0","NC1HWC0"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float", "float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0","NC1HWC0"
|
||||
],
|
||||
"name": "mean",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"dtype": [
|
||||
"float", "float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0","NC1HWC0"
|
||||
],
|
||||
"name": "variance",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 3,
|
||||
"dtype": [
|
||||
"float", "float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0","NC1HWC0"
|
||||
],
|
||||
"name": "batch_mean",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 4,
|
||||
"dtype": [
|
||||
"float", "float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0","NC1HWC0"
|
||||
],
|
||||
"name": "batch_variance",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(bn_training_update_op_info)
|
||||
def _bn_training_update_tbe():
|
||||
"""BNTrainingUpdate TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,109 +14,30 @@
|
|||
# ============================================================================
|
||||
|
||||
"""BatchNormGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
bn_training_update_grad_op_info = TBERegOp("BNTrainingUpdateGrad") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("bn_training_update_grad.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("bn_training_update_grad") \
|
||||
.partial_flag(True) \
|
||||
.attr("epsilon", "optional", "float", "all") \
|
||||
.input(0, "grads", False, "required", "all") \
|
||||
.input(1, "x", False, "required", "all") \
|
||||
.input(2, "batch_mean", False, "required", "all") \
|
||||
.input(3, "batch_variance", False, "required", "all") \
|
||||
.output(0, "diff_scale", False, "required", "all") \
|
||||
.output(1, "diff_offset", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||
DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||
DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "BNTrainingUpdateGrad",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "bn_training_update_grad.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "bn_training_update_grad",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "epsilon",
|
||||
"param_type": "optional",
|
||||
"type": "float",
|
||||
"value": "all"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0","NC1HWC0"
|
||||
],
|
||||
"name": "grads",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16", "float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0","NC1HWC0"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"dtype": [
|
||||
"float", "float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0","NC1HWC0"
|
||||
],
|
||||
"name": "batch_mean",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 3,
|
||||
"dtype": [
|
||||
"float", "float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0","NC1HWC0"
|
||||
],
|
||||
"name": "batch_variance",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float", "float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0","NC1HWC0"
|
||||
],
|
||||
"name": "diff_scale",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float", "float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0","NC1HWC0"
|
||||
],
|
||||
"name": "diff_offset",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(bn_training_update_grad_op_info)
|
||||
def _bn_training_update_grad_tbe():
|
||||
"""BNTrainingUpdateGrad TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,69 +14,42 @@
|
|||
# ============================================================================
|
||||
|
||||
"""Cast op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
cast_op_info = TBERegOp("Cast") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("cast.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("cast") \
|
||||
.partial_flag(True) \
|
||||
.attr("dst_type", "required", "int", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I8_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.I8_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I8_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.I32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "Cast",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "cast.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "cast",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "dst_type",
|
||||
"param_type": "required",
|
||||
"type": "int",
|
||||
"value": "all"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float16", "float", "float",
|
||||
"int32", "int32", "int32", "int32", "int32",
|
||||
"int8", "int8", "int8", "uint8", "uint8", "uint8",
|
||||
"bool", "bool", "bool", "bool", "float16"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat",
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat",
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat",
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float", "int32", "float16", "int32",
|
||||
"float16", "float", "int8", "uint8", "bool",
|
||||
"float16", "float", "int32", "float16", "float", "int32",
|
||||
"float16", "float", "int32", "uint8", "uint8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat",
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat",
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat",
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": true,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(cast_op_info)
|
||||
def _cast_tbe():
|
||||
"""Cast TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,90 +14,28 @@
|
|||
# ============================================================================
|
||||
|
||||
"""ClipByNormNoDivSum op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
clip_by_norm_no_div_sum_op_info = TBERegOp("ClipByNormNoDivSum") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("clip_by_norm_no_div_sum.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("clip_by_norm_no_div_sum") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "input_x", False, "required", "all") \
|
||||
.input(1, "input1", False, "required", "all") \
|
||||
.input(2, "input2", False, "required", "all") \
|
||||
.input(3, "input3", False, "required", "all") \
|
||||
.output(0, "output_y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "ClipByNormNoDivSum",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "ELEMWISE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "clip_by_norm_no_div_sum.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "clip_by_norm_no_div_sum",
|
||||
"partial_flag": true,
|
||||
"attr":[
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float32"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "input_x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16", "float32"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "input1",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"dtype": [
|
||||
"float16", "float32"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "input2",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 3,
|
||||
"dtype": [
|
||||
"float16", "float32"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "input3",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float32"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "output_y",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(clip_by_norm_no_div_sum_op_info)
|
||||
def _clip_by_norm_no_div_sum_tbe():
|
||||
"""ClipByNormNoDivSum TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,85 +14,30 @@
|
|||
# ============================================================================
|
||||
|
||||
"""ClipByValue op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
clip_by_value_op_info = TBERegOp("ClipByValue") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("clip_by_value.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("clip_by_value") \
|
||||
.partial_flag(True) \
|
||||
.attr("dst_type", "required", "int", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.input(1, "clip_value_min", False, "required", "all") \
|
||||
.input(2, "clip_value_max", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "ClipByValue",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "ELEMWISE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "clip_by_value.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "clip_by_value",
|
||||
"partial_flag": true,
|
||||
"attr":[
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float16", "float16", "float16", "float", "float", "float", "float", "int32",
|
||||
"int32", "int32", "int32"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16", "float16", "float16", "float16", "float", "float", "float", "float", "int32",
|
||||
"int32", "int32", "int32"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "clip_value_min",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"dtype": [
|
||||
"float16", "float16", "float16", "float16", "float", "float", "float", "float", "int32",
|
||||
"int32", "int32", "int32"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "clip_value_max",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float16", "float16", "float16", "float", "float", "float", "float", "int32",
|
||||
"int32", "int32", "int32"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(clip_by_value_op_info)
|
||||
def _clip_by_value_tbe():
|
||||
"""ClipByValue TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,141 +14,44 @@
|
|||
# ============================================================================
|
||||
|
||||
"""Concat op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
concat_op_info = TBERegOp("Concat") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("concat_d.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("concat_d") \
|
||||
.partial_flag(True) \
|
||||
.attr("axis", "required", "int", "all") \
|
||||
.input(0, "input_values", False, "dynamic", "all") \
|
||||
.output(0, "output_data", False, "required", "all") \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.BOOL_5HD, DataType.BOOL_5HD) \
|
||||
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.I8_5HD, DataType.I8_5HD) \
|
||||
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.U8_5HD, DataType.U8_5HD) \
|
||||
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
|
||||
.dtype_format(DataType.I16_5HD, DataType.I16_5HD) \
|
||||
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
|
||||
.dtype_format(DataType.U16_5HD, DataType.U16_5HD) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I32_5HD, DataType.I32_5HD) \
|
||||
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
|
||||
.dtype_format(DataType.U32_5HD, DataType.U32_5HD) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.I64_5HD, DataType.I64_5HD) \
|
||||
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
|
||||
.dtype_format(DataType.U64_5HD, DataType.U64_5HD) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "Concat",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "concat_d.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "concat_d",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "axis",
|
||||
"param_type": "required",
|
||||
"type": "int",
|
||||
"value": "all"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16",
|
||||
"float16",
|
||||
"float",
|
||||
"float",
|
||||
"int32",
|
||||
"int32",
|
||||
"int8",
|
||||
"int8",
|
||||
"int16",
|
||||
"int16",
|
||||
"int64",
|
||||
"int64",
|
||||
"uint8",
|
||||
"uint8",
|
||||
"uint16",
|
||||
"uint16",
|
||||
"uint32",
|
||||
"uint32",
|
||||
"uint64",
|
||||
"uint64",
|
||||
"bool",
|
||||
"bool"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat",
|
||||
"NC1HWC0",
|
||||
"DefaultFormat",
|
||||
"NC1HWC0",
|
||||
"DefaultFormat",
|
||||
"NC1HWC0",
|
||||
"DefaultFormat",
|
||||
"NC1HWC0",
|
||||
"DefaultFormat",
|
||||
"NC1HWC0",
|
||||
"DefaultFormat",
|
||||
"NC1HWC0",
|
||||
"DefaultFormat",
|
||||
"NC1HWC0",
|
||||
"DefaultFormat",
|
||||
"NC1HWC0",
|
||||
"DefaultFormat",
|
||||
"NC1HWC0",
|
||||
"DefaultFormat",
|
||||
"NC1HWC0",
|
||||
"DefaultFormat",
|
||||
"NC1HWC0"
|
||||
],
|
||||
"name": "input_values",
|
||||
"need_compile": false,
|
||||
"param_type": "dynamic",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16",
|
||||
"float16",
|
||||
"float",
|
||||
"float",
|
||||
"int32",
|
||||
"int32",
|
||||
"int8",
|
||||
"int8",
|
||||
"int16",
|
||||
"int16",
|
||||
"int64",
|
||||
"int64",
|
||||
"uint8",
|
||||
"uint8",
|
||||
"uint16",
|
||||
"uint16",
|
||||
"uint32",
|
||||
"uint32",
|
||||
"uint64",
|
||||
"uint64",
|
||||
"bool",
|
||||
"bool"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat",
|
||||
"NC1HWC0",
|
||||
"DefaultFormat",
|
||||
"NC1HWC0",
|
||||
"DefaultFormat",
|
||||
"NC1HWC0",
|
||||
"DefaultFormat",
|
||||
"NC1HWC0",
|
||||
"DefaultFormat",
|
||||
"NC1HWC0",
|
||||
"DefaultFormat",
|
||||
"NC1HWC0",
|
||||
"DefaultFormat",
|
||||
"NC1HWC0",
|
||||
"DefaultFormat",
|
||||
"NC1HWC0",
|
||||
"DefaultFormat",
|
||||
"NC1HWC0",
|
||||
"DefaultFormat",
|
||||
"NC1HWC0",
|
||||
"DefaultFormat",
|
||||
"NC1HWC0"
|
||||
],
|
||||
"name": "output_data",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(concat_op_info)
|
||||
def _concat_tbe():
|
||||
"""Concat TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,65 +14,28 @@
|
|||
# ============================================================================
|
||||
|
||||
"""ConfusionSoftmaxGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
confusion_softmax_grad_op_info = TBERegOp("ConfusionSoftmaxGrad") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("confusion_softmax_grad.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("confusion_softmax_grad") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "grad", False, "required", "all") \
|
||||
.input(1, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "ConfusionSoftmaxGrad",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "confusion_softmax_grad.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "confusion_softmax_grad",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float16", "float16", "float", "float", "float"
|
||||
],
|
||||
"format": [
|
||||
"FRACTAL_NZ", "DefaultFormat", "NC1HWC0", "FRACTAL_NZ", "DefaultFormat", "NC1HWC0"
|
||||
],
|
||||
"name": "grad",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16", "float16", "float16", "float", "float", "float"
|
||||
],
|
||||
"format": [
|
||||
"FRACTAL_NZ", "DefaultFormat", "NC1HWC0", "FRACTAL_NZ", "DefaultFormat", "NC1HWC0"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float16", "float16", "float", "float", "float"
|
||||
],
|
||||
"format": [
|
||||
"FRACTAL_NZ", "DefaultFormat", "NC1HWC0", "FRACTAL_NZ", "DefaultFormat", "NC1HWC0"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": true,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(confusion_softmax_grad_op_info)
|
||||
def _confusion_softmax_grad_tbe():
|
||||
"""ConfusionSoftmaxGrad TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,79 +14,44 @@
|
|||
# ============================================================================
|
||||
|
||||
"""ConfusionTransposeD op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
confusion_transpose_d_op_info = TBERegOp("ConfusionTransposeD") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("confusion_transpose_d.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("confusion_transpose_d") \
|
||||
.partial_flag(True) \
|
||||
.attr("perm", "required", "listInt", "all") \
|
||||
.attr("shape", "required", "listInt", "all") \
|
||||
.attr("transpose_first", "required", "bool", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.I8_FracNZ, DataType.I8_FracNZ) \
|
||||
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.U8_FracNZ, DataType.U8_FracNZ) \
|
||||
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.I16_FracNZ, DataType.I16_FracNZ) \
|
||||
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
|
||||
.dtype_format(DataType.U16_FracNZ, DataType.U16_FracNZ) \
|
||||
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
|
||||
.dtype_format(DataType.I32_FracNZ, DataType.I32_FracNZ) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.U32_FracNZ, DataType.U32_FracNZ) \
|
||||
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
|
||||
.dtype_format(DataType.I64_FracNZ, DataType.I64_FracNZ) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.U64_FracNZ, DataType.U64_FracNZ) \
|
||||
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
|
||||
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "ConfusionTransposeD",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "confusion_transpose_d.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "confusion_transpose_d",
|
||||
"partial_flag": true,
|
||||
"attr":[
|
||||
{
|
||||
"name":"perm",
|
||||
"param_type":"required",
|
||||
"type":"listInt",
|
||||
"value":"all"
|
||||
},
|
||||
{
|
||||
"name":"shape",
|
||||
"param_type":"required",
|
||||
"type":"listInt",
|
||||
"value":"all"
|
||||
},
|
||||
{
|
||||
"name":"transpose_first",
|
||||
"param_type":"required",
|
||||
"type":"bool",
|
||||
"value":"all"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float", "int8", "int16", "int32", "int64", "uint8", "uint16", "uint32",
|
||||
"uint64", "float16", "float", "int8", "int16", "int32", "int64", "uint8", "uint16",
|
||||
"uint32", "uint64"
|
||||
],
|
||||
"format": [
|
||||
"FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ",
|
||||
"FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ", "DefaultFormat", "DefaultFormat",
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat",
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float", "int8", "int16", "int32", "int64", "uint8", "uint16", "uint32",
|
||||
"uint64", "float16", "float", "int8", "int16", "int32", "int64", "uint8", "uint16",
|
||||
"uint32", "uint64"
|
||||
],
|
||||
"format": [
|
||||
"FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ",
|
||||
"FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ", "DefaultFormat", "DefaultFormat",
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat",
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(confusion_transpose_d_op_info)
|
||||
def _confusion_transpose_d_tbe():
|
||||
"""ConfusionTransposeD TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,114 +14,30 @@
|
|||
# ============================================================================
|
||||
|
||||
"""Conv2D op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
conv2d_op_info = TBERegOp("Conv2D") \
|
||||
.fusion_type("CONVLUTION") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("conv2d.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("conv2d") \
|
||||
.partial_flag(True) \
|
||||
.attr("stride", "required", "listInt", "all") \
|
||||
.attr("pad_list", "required", "listInt", "all") \
|
||||
.attr("dilation", "required", "listInt", "all") \
|
||||
.attr("offset_a", "optional", "int", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.input(1, "filter", False, "required", "all") \
|
||||
.input(2, "bias", False, "optional", "all") \
|
||||
.input(3, "offset_w", False, "optional", "all") \
|
||||
.output(0, "y", True, "required", "all") \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_FracZ, DataType.F16_Default, DataType.I8_Default,
|
||||
DataType.F16_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "Conv2D",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "CONVLUTION",
|
||||
"async_flag": false,
|
||||
"binfile_name": "conv2d.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "conv2d",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "stride",
|
||||
"param_type": "required",
|
||||
"type": "listInt",
|
||||
"value": "all"
|
||||
},
|
||||
{
|
||||
"name": "pad_list",
|
||||
"param_type": "required",
|
||||
"type": "listInt",
|
||||
"value": "all"
|
||||
},
|
||||
{
|
||||
"name": "dilation",
|
||||
"param_type": "required",
|
||||
"type": "listInt",
|
||||
"value": "all"
|
||||
},
|
||||
{
|
||||
"name": "offset_a",
|
||||
"param_type": "optional",
|
||||
"type": "int",
|
||||
"value": "all"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16"
|
||||
],
|
||||
"format": [
|
||||
"FracZ"
|
||||
],
|
||||
"name": "filter",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"dtype": [
|
||||
"float16"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name": "bias",
|
||||
"need_compile": false,
|
||||
"param_type": "optional",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 3,
|
||||
"dtype": [
|
||||
"int8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name": "offset_w",
|
||||
"need_compile": false,
|
||||
"param_type": "optional",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": true,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(conv2d_op_info)
|
||||
def _conv2d_tbe():
|
||||
"""Conv2D TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,89 +14,27 @@
|
|||
# ============================================================================
|
||||
|
||||
"""Conv2DBackpropFilter op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
conv2d_backprop_filter_op_info = TBERegOp("Conv2DBackpropFilter") \
|
||||
.fusion_type("CONVLUTION") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("conv2d_backprop_filter_d.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("conv2d_backprop_filter_d") \
|
||||
.partial_flag(True) \
|
||||
.attr("filter_sizes", "required", "listInt", "all") \
|
||||
.attr("stride", "required", "listInt", "all") \
|
||||
.attr("pad_mode", "required", "str", "all") \
|
||||
.attr("dilation", "required", "listInt", "all") \
|
||||
.input(0, "out_backprop", False, "required", "all") \
|
||||
.input(1, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F32_FracZ) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
# map to tbe kernel name conv2d_backprop_filter_d
|
||||
@op_info_register("""{
|
||||
"op_name": "Conv2DBackpropFilter",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "CONVLUTION",
|
||||
"async_flag": false,
|
||||
"binfile_name": "conv2d_backprop_filter_d.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "conv2d_backprop_filter_d",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "filter_sizes",
|
||||
"param_type": "required",
|
||||
"type": "listInt",
|
||||
"value": "all"
|
||||
},
|
||||
{
|
||||
"name": "stride",
|
||||
"param_type": "required",
|
||||
"type": "listInt",
|
||||
"value": "all"
|
||||
},
|
||||
{
|
||||
"name": "pad_mode",
|
||||
"param_type": "required",
|
||||
"type": "str",
|
||||
"value": "all"
|
||||
},
|
||||
{
|
||||
"name": "dilation",
|
||||
"param_type": "required",
|
||||
"type": "listInt",
|
||||
"value": "all"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0"
|
||||
],
|
||||
"name": "out_backprop",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float32"
|
||||
],
|
||||
"format": [
|
||||
"FracZ"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": true,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(conv2d_backprop_filter_op_info)
|
||||
def _conv2d_backprop_filter_tbe():
|
||||
"""Conv2DBackpropFilter TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,88 +14,27 @@
|
|||
# ============================================================================
|
||||
|
||||
"""Conv2DBackpropInput op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
conv2d_backprop_input_op_info = TBERegOp("Conv2DBackpropInput") \
|
||||
.fusion_type("CONVLUTION") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("conv2d_backprop_input_d.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("conv2d_backprop_input_d") \
|
||||
.partial_flag(True) \
|
||||
.attr("input_sizes", "required", "listInt", "all") \
|
||||
.attr("stride", "required", "listInt", "all") \
|
||||
.attr("pad_mode", "required", "str", "all") \
|
||||
.attr("dilation", "required", "listInt", "all") \
|
||||
.input(0, "out_backprop", False, "required", "all") \
|
||||
.input(1, "filter", False, "required", "all") \
|
||||
.output(0, "y", True, "required", "all") \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_FracZ, DataType.F16_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "Conv2DBackpropInput",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "CONVLUTION",
|
||||
"async_flag": false,
|
||||
"binfile_name": "conv2d_backprop_input_d.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "conv2d_backprop_input_d",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "input_sizes",
|
||||
"param_type": "required",
|
||||
"type": "listInt",
|
||||
"value": "all"
|
||||
},
|
||||
{
|
||||
"name": "stride",
|
||||
"param_type": "required",
|
||||
"type": "listInt",
|
||||
"value": "all"
|
||||
},
|
||||
{
|
||||
"name": "pad_mode",
|
||||
"param_type": "required",
|
||||
"type": "str",
|
||||
"value": "all"
|
||||
},
|
||||
{
|
||||
"name": "dilation",
|
||||
"param_type": "required",
|
||||
"type": "listInt",
|
||||
"value": "all"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0"
|
||||
],
|
||||
"name": "out_backprop",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16"
|
||||
],
|
||||
"format": [
|
||||
"FracZ"
|
||||
],
|
||||
"name": "filter",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": true,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(conv2d_backprop_input_op_info)
|
||||
def _conv2d_backprop_input_tbe():
|
||||
"""Conv2DBackpropInput TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,71 +14,32 @@
|
|||
# ============================================================================
|
||||
|
||||
"""Div op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
div_op_info = TBERegOp("Div") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("div.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("div") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "x1", False, "required", "all") \
|
||||
.input(1, "x2", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD) \
|
||||
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.U8_5HD) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "Div",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "ELEMWISE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "div.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "div",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float16", "float", "float", "int32", "int32", "int8", "int8", "uint8", "uint8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "NC1HWC0"
|
||||
],
|
||||
"name": "x1",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float16", "float", "float", "int32", "int32", "int8", "int8", "uint8", "uint8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "NC1HWC0"
|
||||
],
|
||||
"name": "x2",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float16", "float", "float", "int32", "int32", "int8", "int8", "uint8", "uint8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "NC1HWC0"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(div_op_info)
|
||||
def _div_tbe():
|
||||
"""Div TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,76 +14,25 @@
|
|||
# ============================================================================
|
||||
|
||||
"""DropoutdoMask op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
drop_out_do_mask_op_info = TBERegOp("DropoutDoMask") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("drop_out_do_mask.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("drop_out_do_mask") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.input(1, "mask", False, "required", "all") \
|
||||
.input(2, "keep_prob", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.U8_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.U8_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "DropoutDoMask",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "drop_out_do_mask.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "drop_out_do_mask",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"uint8","uint8","uint8","uint8","uint8","uint8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
|
||||
],
|
||||
"name": "mask",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
|
||||
],
|
||||
"name": "keep_prob",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
|
||||
],
|
||||
"name": "y",
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(drop_out_do_mask_op_info)
|
||||
def _dropout_do_mask_tbe():
|
||||
"""DropoutdoMask TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,66 +14,32 @@
|
|||
# ============================================================================
|
||||
|
||||
"""Equal op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
equal_op_info = TBERegOp("Equal") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("equal.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("equal") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "x1", False, "required", "all") \
|
||||
.input(1, "x2", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.BOOL_5HD) \
|
||||
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.BOOL_5HD) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.BOOL_5HD) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.BOOL_5HD) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.BOOL_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "Equal",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "ELEMWISE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "equal.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "equal",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float16","float","float","int32","int32","int8","int8","uint8","uint8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat",
|
||||
"NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "x1",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16","float16","float","float","int32","int32","int8","int8","uint8","uint8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat",
|
||||
"NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "x2",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"bool","bool","bool","bool","bool","bool","bool","bool","bool","bool"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat",
|
||||
"NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "y",
|
||||
"param_type": "required"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(equal_op_info)
|
||||
def _equal_tbe():
|
||||
"""Equal TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,52 +14,25 @@
|
|||
# ============================================================================
|
||||
|
||||
"""Exp op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
exp_op_info = TBERegOp("Exp") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("exp.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("exp") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "Exp",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "ELEMWISE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "exp.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "exp",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float16", "float", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float16", "float", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(exp_op_info)
|
||||
def _exp_tbe():
|
||||
"""Exp TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,57 +14,25 @@
|
|||
# ============================================================================
|
||||
|
||||
"""ExpandDims op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
expand_dims_op_info = TBERegOp("ExpandDims") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("expand_dims.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("expand_dims") \
|
||||
.partial_flag(True) \
|
||||
.attr("axis", "required", "listInt", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "ExpandDims",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "expand_dims.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "expand_dims",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "axis",
|
||||
"param_type": "required",
|
||||
"type": "listInt",
|
||||
"value": "all"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float32", "int32"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float32", "int32"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(expand_dims_op_info)
|
||||
def _expand_dims_tbe():
|
||||
"""ExpandDims TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,64 +14,27 @@
|
|||
# ============================================================================
|
||||
|
||||
"""FloorDiv op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
floordiv_op_info = TBERegOp("FloorDiv") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("floordiv.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("floordiv") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "x1", False, "required", "all") \
|
||||
.input(1, "x2", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "FloorDiv",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "ELEMWISE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "floordiv.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "floordiv",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float","int32","int8","uint8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
|
||||
],
|
||||
"name": "x1",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16","float","int32","int8","uint8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
|
||||
],
|
||||
"name": "x2",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float","int32","int8","uint8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
|
||||
],
|
||||
"name": "y",
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(floordiv_op_info)
|
||||
def _floor_div_tbe():
|
||||
"""FloorDiv TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,93 +14,38 @@
|
|||
# ============================================================================
|
||||
|
||||
"""FusedMulAdd op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
fused_mul_add_op_info = TBERegOp("FusedMulAdd") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("fused_mul_add.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("fused_mul_add") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "x1", False, "required", "all") \
|
||||
.input(1, "x2", False, "required", "all") \
|
||||
.input(2, "x3", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
|
||||
.dtype_format(DataType.I32_FracZ, DataType.I32_FracZ, DataType.I32_FracZ, DataType.I32_FracZ) \
|
||||
.dtype_format(DataType.I32_FracNZ, DataType.I32_FracNZ, DataType.I32_FracNZ, DataType.I32_FracNZ) \
|
||||
.dtype_format(DataType.I32_C1HWNCoC0, DataType.I32_C1HWNCoC0, DataType.I32_C1HWNCoC0, DataType.I32_C1HWNCoC0) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ) \
|
||||
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ) \
|
||||
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \
|
||||
.dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ) \
|
||||
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "FusedMulAdd",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "fused_mul_add.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "fused_mul_add",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"int32", "int32", "int32", "int32", "int32",
|
||||
"float16", "float16", "float16", "float16", "float16",
|
||||
"float", "float", "float", "float", "float"
|
||||
],
|
||||
"format": [
|
||||
"FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0",
|
||||
"FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0",
|
||||
"FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0"
|
||||
],
|
||||
"name": "x1",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"int32", "int32", "int32", "int32", "int32",
|
||||
"float16", "float16", "float16", "float16", "float16",
|
||||
"float", "float", "float", "float", "float"
|
||||
],
|
||||
"format": [
|
||||
"FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0",
|
||||
"FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0",
|
||||
"FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0"
|
||||
],
|
||||
"name": "x2",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"dtype": [
|
||||
"int32", "int32", "int32", "int32", "int32",
|
||||
"float16", "float16", "float16", "float16", "float16",
|
||||
"float", "float", "float", "float", "float"
|
||||
],
|
||||
"format": [
|
||||
"FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0",
|
||||
"FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0",
|
||||
"FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0"
|
||||
],
|
||||
"name": "x3",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"int32", "int32", "int32", "int32", "int32",
|
||||
"float16", "float16", "float16", "float16", "float16",
|
||||
"float", "float", "float", "float", "float"
|
||||
],
|
||||
"format": [
|
||||
"FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0",
|
||||
"FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0",
|
||||
"FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0"
|
||||
],
|
||||
"name": "y",
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(fused_mul_add_op_info)
|
||||
def _fused_mul_add_tbe():
|
||||
"""FusedMulAdd TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,86 +14,31 @@
|
|||
# ============================================================================
|
||||
|
||||
"""FusedMulAddN op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
fused_mul_add_n_op_info = TBERegOp("FusedMulAddN") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("fused_mul_add_n.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("fused_mul_add_n") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "x1", False, "required", "all") \
|
||||
.input(1, "x2", False, "required", "all") \
|
||||
.input(2, "x3", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_C1HWNCoC0) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_FracZ) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_5HD) \
|
||||
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_C1HWNCoC0) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_FracZ) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "FusedMulAddN",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "fused_mul_add_n.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "fused_mul_add_n",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float16",
|
||||
"float","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ",
|
||||
"NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ"
|
||||
],
|
||||
"name": "x1",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float16",
|
||||
"float","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ",
|
||||
"NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ"
|
||||
],
|
||||
"name": "x2",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float16",
|
||||
"float","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat",
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
|
||||
],
|
||||
"name": "x3",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float16",
|
||||
"float","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ",
|
||||
"NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(fused_mul_add_n_op_info)
|
||||
def _fused_mul_add_n_tbe():
|
||||
"""FusedMulAddN TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,137 +14,43 @@
|
|||
# ============================================================================
|
||||
|
||||
"""FusedMulApplyMomentum op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
fused_mul_apply_momentum_op_info = TBERegOp("FusedMulApplyMomentum") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("fused_mul_apply_momentum.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("fused_mul_apply_momentum") \
|
||||
.partial_flag(True) \
|
||||
.attr("use_nesterov", "optional", "bool", "true,false", "false") \
|
||||
.input(0, "var", False, "required", "all") \
|
||||
.input(1, "accum", False, "required", "all") \
|
||||
.input(2, "lr", False, "required", "all") \
|
||||
.input(3, "x1", False, "required", "all") \
|
||||
.input(4, "momentum", False, "required", "all") \
|
||||
.input(5, "x2", False, "required", "all") \
|
||||
.output(0, "var", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default, DataType.F16_5HD,
|
||||
DataType.F16_Default, DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_C1HWNCoC0,
|
||||
DataType.F16_Default, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_FracZ,
|
||||
DataType.F16_Default, DataType.F16_FracZ, DataType.F16_FracZ) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_5HD,
|
||||
DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_C1HWNCoC0,
|
||||
DataType.F32_Default, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_FracZ,
|
||||
DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "FusedMulApplyMomentum",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "fused_mul_apply_momentum.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "fused_mul_apply_momentum",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "use_nesterov",
|
||||
"param_type": "optional",
|
||||
"type": "bool",
|
||||
"value": "true,false",
|
||||
"default_value":"false"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float16",
|
||||
"float","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ",
|
||||
"NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ"
|
||||
],
|
||||
"name": "var",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float16",
|
||||
"float","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ",
|
||||
"NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ"
|
||||
],
|
||||
"name": "accum",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float16",
|
||||
"float","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat",
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
|
||||
],
|
||||
"name": "lr",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 3,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float16",
|
||||
"float","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ",
|
||||
"NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ"
|
||||
],
|
||||
"name": "x1",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 4,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float16",
|
||||
"float","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat",
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
|
||||
],
|
||||
"name": "momentum",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 5,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float16",
|
||||
"float","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ",
|
||||
"NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ"
|
||||
],
|
||||
"name": "x2",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float16",
|
||||
"float","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ",
|
||||
"NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ"
|
||||
],
|
||||
"name": "var",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(fused_mul_apply_momentum_op_info)
|
||||
def _fused_mul_apply_momentum_tbe():
|
||||
"""FusedMulApplyMomentum TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,94 +14,53 @@
|
|||
# ============================================================================
|
||||
|
||||
"""AddN op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
gather_v2_op_info = TBERegOp("GatherV2") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("gather_v2_d.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("gather_v2_d") \
|
||||
.partial_flag(True) \
|
||||
.attr("axis", "optional", "int", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.input(1, "indices", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.I8_5HD, DataType.I32_5HD, DataType.I8_5HD) \
|
||||
.dtype_format(DataType.I8_5HD, DataType.I64_5HD, DataType.I8_5HD) \
|
||||
.dtype_format(DataType.I8_FracZ, DataType.I32_FracZ, DataType.I8_FracZ) \
|
||||
.dtype_format(DataType.I8_FracZ, DataType.I64_FracZ, DataType.I8_FracZ) \
|
||||
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.U8_5HD, DataType.I32_5HD, DataType.U8_5HD) \
|
||||
.dtype_format(DataType.U8_5HD, DataType.I64_5HD, DataType.U8_5HD) \
|
||||
.dtype_format(DataType.U8_FracZ, DataType.I32_FracZ, DataType.U8_FracZ) \
|
||||
.dtype_format(DataType.U8_FracZ, DataType.I64_FracZ, DataType.U8_FracZ) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
|
||||
.dtype_format(DataType.I32_5HD, DataType.I64_5HD, DataType.I32_5HD) \
|
||||
.dtype_format(DataType.I32_FracZ, DataType.I32_FracZ, DataType.I32_FracZ) \
|
||||
.dtype_format(DataType.I32_FracZ, DataType.I64_FracZ, DataType.I32_FracZ) \
|
||||
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.I32_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.I64_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F16_FracZ, DataType.I32_FracZ, DataType.F16_FracZ) \
|
||||
.dtype_format(DataType.F16_FracZ, DataType.I64_FracZ, DataType.F16_FracZ) \
|
||||
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.I32_5HD, DataType.F32_5HD) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.I64_5HD, DataType.F32_5HD) \
|
||||
.dtype_format(DataType.F32_FracZ, DataType.I32_FracZ, DataType.F32_FracZ) \
|
||||
.dtype_format(DataType.F32_FracZ, DataType.I64_FracZ, DataType.F32_FracZ) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "GatherV2",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "gather_v2_d.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "gather_v2_d",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "axis",
|
||||
"param_type": "optional",
|
||||
"type": "int",
|
||||
"value": "all"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float16","float16","float16",
|
||||
"float","float","float","float","float","float",
|
||||
"int32","int32","int32", "int32","int32","int32",
|
||||
"uint8","uint8","uint8","uint8","uint8","uint8",
|
||||
"int8","int8", "int8","int8","int8", "int8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","NC1HWC0","FracZ","DefaultFormat","NC1HWC0","FracZ",
|
||||
"DefaultFormat","NC1HWC0","FracZ","DefaultFormat","NC1HWC0","FracZ",
|
||||
"DefaultFormat","NC1HWC0","FracZ","DefaultFormat","NC1HWC0","FracZ",
|
||||
"DefaultFormat","NC1HWC0","FracZ","DefaultFormat","NC1HWC0","FracZ",
|
||||
"DefaultFormat","NC1HWC0","FracZ","DefaultFormat","NC1HWC0","FracZ"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"int32","int32","int32","int64","int64","int64",
|
||||
"int32","int32","int32","int64","int64","int64",
|
||||
"int32","int32","int32","int64","int64","int64",
|
||||
"int32","int32","int32","int64","int64","int64",
|
||||
"int32","int32","int32","int64","int64","int64"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","NC1HWC0","FracZ","DefaultFormat","NC1HWC0","FracZ",
|
||||
"DefaultFormat","NC1HWC0","FracZ","DefaultFormat","NC1HWC0","FracZ",
|
||||
"DefaultFormat","NC1HWC0","FracZ","DefaultFormat","NC1HWC0","FracZ",
|
||||
"DefaultFormat","NC1HWC0","FracZ","DefaultFormat","NC1HWC0","FracZ",
|
||||
"DefaultFormat","NC1HWC0","FracZ","DefaultFormat","NC1HWC0","FracZ"
|
||||
],
|
||||
"name": "indices",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float16","float16","float16",
|
||||
"float","float","float","float","float","float",
|
||||
"int32","int32","int32", "int32","int32","int32",
|
||||
"uint8","uint8","uint8","uint8","uint8","uint8",
|
||||
"int8","int8", "int8","int8","int8", "int8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","NC1HWC0","FracZ","DefaultFormat","NC1HWC0","FracZ",
|
||||
"DefaultFormat","NC1HWC0","FracZ","DefaultFormat","NC1HWC0","FracZ",
|
||||
"DefaultFormat","NC1HWC0","FracZ","DefaultFormat","NC1HWC0","FracZ",
|
||||
"DefaultFormat","NC1HWC0","FracZ","DefaultFormat","NC1HWC0","FracZ",
|
||||
"DefaultFormat","NC1HWC0","FracZ","DefaultFormat","NC1HWC0","FracZ"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(gather_v2_op_info)
|
||||
def _gather_v2_tbe():
|
||||
"""GatherV2 TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,51 +14,29 @@
|
|||
# ============================================================================
|
||||
|
||||
"""Gelu op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
gelu_op_info = TBERegOp("Gelu") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("gelu.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("gelu") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ) \
|
||||
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ) \
|
||||
.dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "Gelu",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "ELEMWISE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "gelu.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "gelu",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float","float16","float","float16","float16","float16","float16","float","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"FRACTAL_NZ","FRACTAL_NZ","FracZ","FracZ","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float","float16","float","float16","float16","float16","float16","float","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"FRACTAL_NZ","FRACTAL_NZ","FracZ","FracZ","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": true,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(gelu_op_info)
|
||||
def _gelu_tbe():
|
||||
"""Gelu TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,77 +14,29 @@
|
|||
# ============================================================================
|
||||
|
||||
"""GeluGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
gelu_grad_op_info = TBERegOp("GeluGrad") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("gelu_grad.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("gelu_grad") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "dy", False, "required", "all") \
|
||||
.input(1, "x", False, "required", "all") \
|
||||
.input(2, "y", False, "required", "all") \
|
||||
.output(0, "z", True, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "GeluGrad",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "ELEMWISE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "gelu_grad.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "gelu_grad",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"FRACTAL_NZ","DefaultFormat","NC1HWC0","FRACTAL_NZ","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "dy",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"FRACTAL_NZ","DefaultFormat","NC1HWC0","FRACTAL_NZ","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"FRACTAL_NZ","DefaultFormat","NC1HWC0","FRACTAL_NZ","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"FRACTAL_NZ","DefaultFormat","NC1HWC0","FRACTAL_NZ","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "z",
|
||||
"need_compile": true,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(gelu_grad_op_info)
|
||||
def _gelu_grad_tbe():
|
||||
"""GeluGrad TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,68 +14,32 @@
|
|||
# ============================================================================
|
||||
|
||||
"""Greater op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
greater_op_info = TBERegOp("Greater") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("greater.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("greater") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "x1", False, "required", "all") \
|
||||
.input(1, "x2", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.BOOL_5HD) \
|
||||
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.BOOL_5HD) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.BOOL_5HD) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.BOOL_5HD) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.BOOL_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "Greater",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "greater.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "greater",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float16","float","float","int32","int32","int8","int8","uint8","uint8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat",
|
||||
"NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "x1",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16","float16","float","float","int32","int32","int8","int8","uint8","uint8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat",
|
||||
"NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "x2",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"bool","bool","bool","bool","bool","bool","bool","bool","bool","bool"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat",
|
||||
"NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(greater_op_info)
|
||||
def _greater_tbe():
|
||||
"""Greater TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,279 +14,46 @@
|
|||
# ============================================================================
|
||||
|
||||
"""LambNextMV op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
lamb_next_mv_op_info = TBERegOp("LambNextMV") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("lamb_next_m_v.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("lamb_next_m_v") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "input1", False, "required", "all") \
|
||||
.input(1, "input2", False, "required", "all") \
|
||||
.input(2, "input3", False, "required", "all") \
|
||||
.input(3, "input4", False, "required", "all") \
|
||||
.input(4, "input5", False, "required", "all") \
|
||||
.input(5, "input6", False, "required", "all") \
|
||||
.input(6, "input7", False, "required", "all") \
|
||||
.input(7, "input8", False, "required", "all") \
|
||||
.input(8, "input9", False, "required", "all") \
|
||||
.input(9, "inputx0", False, "required", "all") \
|
||||
.input(10, "inputx1", False, "required", "all") \
|
||||
.input(11, "inputx2", False, "required", "all") \
|
||||
.input(12, "inputx3", False, "required", "all") \
|
||||
.output(0, "output1", False, "required", "all") \
|
||||
.output(1, "output2", False, "required", "all") \
|
||||
.output(2, "output3", False, "required", "all") \
|
||||
.output(3, "output4", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name":"LambNextMV",
|
||||
"imply_type":"TBE",
|
||||
"fusion_type":"ELEMWISE",
|
||||
"async_flag":false,
|
||||
"binfile_name":"lamb_next_m_v.so",
|
||||
"compute_cost":10,
|
||||
"kernel_name":"lamb_next_m_v",
|
||||
"partial_flag":true,
|
||||
"attr":[],
|
||||
"inputs":[
|
||||
{
|
||||
"index":0,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"input1",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":1,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"input2",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":2,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"input3",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":3,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"input4",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":4,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"input5",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":5,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"input6",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":6,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"input7",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":7,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"input8",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":8,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"input9",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":9,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"inputx0",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":10,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"inputx1",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":11,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"inputx2",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":12,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"inputx3",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
}
|
||||
],
|
||||
"outputs":[
|
||||
{
|
||||
"index":0,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"output1",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":1,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"output2",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":2,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"output3",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":3,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"output4",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(lamb_next_mv_op_info)
|
||||
def _lamb_next_mv_tbe():
|
||||
"""LambNextMV TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,279 +14,46 @@
|
|||
# ============================================================================
|
||||
|
||||
"""LambNextMVWithDecayV1 op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
lamb_next_m_v_with_decay_v1_op_info = TBERegOp("LambNextMVWithDecayV1") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("lamb_next_m_v_with_decay_v1.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("lamb_next_m_v_with_decay_v1") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "input1", False, "required", "all") \
|
||||
.input(1, "input2", False, "required", "all") \
|
||||
.input(2, "input3", False, "required", "all") \
|
||||
.input(3, "input4", False, "required", "all") \
|
||||
.input(4, "input5", False, "required", "all") \
|
||||
.input(5, "input6", False, "required", "all") \
|
||||
.input(6, "input7", False, "required", "all") \
|
||||
.input(7, "input8", False, "required", "all") \
|
||||
.input(8, "input9", False, "required", "all") \
|
||||
.input(9, "inputx0", False, "required", "all") \
|
||||
.input(10, "inputx1", False, "required", "all") \
|
||||
.input(11, "inputx2", False, "required", "all") \
|
||||
.input(12, "inputx3", False, "required", "all") \
|
||||
.output(0, "output1", False, "required", "all") \
|
||||
.output(1, "output2", False, "required", "all") \
|
||||
.output(2, "output3", False, "required", "all") \
|
||||
.output(3, "output4", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name":"LambNextMVWithDecayV1",
|
||||
"imply_type":"TBE",
|
||||
"fusion_type":"OPAQUE",
|
||||
"async_flag":false,
|
||||
"binfile_name":"lamb_next_m_v_with_decay_v1.so",
|
||||
"compute_cost":10,
|
||||
"kernel_name":"lamb_next_m_v_with_decay_v1",
|
||||
"partial_flag":true,
|
||||
"attr":[],
|
||||
"inputs":[
|
||||
{
|
||||
"index":0,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"input1",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":1,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"input2",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":2,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"input3",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":3,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"input4",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":4,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"input5",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":5,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"input6",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":6,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"input7",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":7,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"input8",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":8,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"input9",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":9,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"inputx0",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":10,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"inputx1",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":11,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"inputx2",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":12,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"inputx3",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
}
|
||||
],
|
||||
"outputs":[
|
||||
{
|
||||
"index":0,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"output1",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":1,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"output2",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":2,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"output3",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":3,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"output4",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(lamb_next_m_v_with_decay_v1_op_info)
|
||||
def _lamb_next_mv_with_decay_v1_tbe():
|
||||
"""LambNextMVWithDecayV1 TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,174 +14,35 @@
|
|||
# ============================================================================
|
||||
|
||||
"""LambUpdateWithLr op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
lamb_update_with_lr_op_info = TBERegOp("LambUpdateWithLR") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("lamb_update_with_lr.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("lamb_update_with_lr") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "input1", False, "required", "all") \
|
||||
.input(1, "input2", False, "required", "all") \
|
||||
.input(2, "input3", False, "required", "all") \
|
||||
.input(3, "input4", False, "required", "all") \
|
||||
.input(4, "input5", False, "required", "all") \
|
||||
.input(5, "input6", False, "required", "all") \
|
||||
.input(6, "input7", False, "required", "all") \
|
||||
.input(7, "input8", False, "required", "all") \
|
||||
.input(8, "input9", False, "required", "all") \
|
||||
.output(0, "output_y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name":"LambUpdateWithLR",
|
||||
"imply_type":"TBE",
|
||||
"fusion_type":"ELEMWISE",
|
||||
"async_flag":false,
|
||||
"binfile_name":"lamb_update_with_lr.so",
|
||||
"compute_cost":10,
|
||||
"kernel_name":"lamb_update_with_lr",
|
||||
"partial_flag":true,
|
||||
"attr":[],
|
||||
"inputs":[
|
||||
{
|
||||
"index":0,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"input1",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":1,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"input2",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":2,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"input3",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":3,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"input4",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":4,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"input5",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":5,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"input6",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":6,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"input7",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":7,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"input8",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":8,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"input9",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
}
|
||||
],
|
||||
"outputs":[
|
||||
{
|
||||
"index":0,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"output_y",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(lamb_update_with_lr_op_info)
|
||||
def _lamb_update_with_lr_tbe():
|
||||
"""LambUpdateWithLr TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,144 +14,31 @@
|
|||
# ============================================================================
|
||||
|
||||
"""LambUpdateWithLrV2 op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
lamb_update_with_lr_v2_op_info = TBERegOp("LambUpdateWithLrV2") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("lamb_update_with_lr_v2.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("lamb_update_with_lr_v2") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "x1", False, "required", "all") \
|
||||
.input(1, "x2", False, "required", "all") \
|
||||
.input(2, "x3", False, "required", "all") \
|
||||
.input(3, "x4", False, "required", "all") \
|
||||
.input(4, "x5", False, "required", "all") \
|
||||
.input(5, "greater_y", False, "required", "all") \
|
||||
.input(6, "select_e", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name":"LambUpdateWithLrV2",
|
||||
"imply_type":"TBE",
|
||||
"fusion_type":"ELEMWISE",
|
||||
"async_flag":false,
|
||||
"binfile_name":"lamb_update_with_lr_v2.so",
|
||||
"compute_cost":10,
|
||||
"kernel_name":"lamb_update_with_lr_v2",
|
||||
"partial_flag":true,
|
||||
"attr":[],
|
||||
"inputs":[
|
||||
{
|
||||
"index":0,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"x1",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":1,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"x2",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":2,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"x3",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":3,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"x4",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":4,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"x5",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":5,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"greater_y",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":6,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"select_e",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
}
|
||||
],
|
||||
"outputs":[
|
||||
{
|
||||
"index":0,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"y",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(lamb_update_with_lr_v2_op_info)
|
||||
def _lamb_update_with_lr_v2_tbe():
|
||||
"""LambUpdateWithLrV2 TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,111 +14,39 @@
|
|||
# ============================================================================
|
||||
|
||||
"""LayerNorm op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
layer_norm_op_info = TBERegOp("LayerNorm") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("layer_norm.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("layer_norm") \
|
||||
.partial_flag(True) \
|
||||
.attr("begin_norm_axis", "required", "int", "all") \
|
||||
.attr("begin_params_axis", "required", "int", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.input(1, "gamma", False, "required", "all") \
|
||||
.input(2, "beta", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.output(1, "mean", False, "required", "all") \
|
||||
.output(2, "variance", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
|
||||
DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F16_FracNZ, DataType.F16_Default, DataType.F16_Default, DataType.F16_FracNZ,
|
||||
DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||
DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.dtype_format(DataType.F32_FracNZ, DataType.F32_Default, DataType.F32_Default, DataType.F32_FracNZ,
|
||||
DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "LayerNorm",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "layer_norm.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "layer_norm",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "begin_norm_axis",
|
||||
"param_type": "required",
|
||||
"type": "int",
|
||||
"value": "all"
|
||||
},
|
||||
{
|
||||
"name": "begin_params_axis",
|
||||
"param_type": "required",
|
||||
"type": "int",
|
||||
"value": "all"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"FRACTAL_NZ","DefaultFormat","NC1HWC0","FRACTAL_NZ","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "gamma",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "beta",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"FRACTAL_NZ","DefaultFormat","NC1HWC0","FRACTAL_NZ","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "y",
|
||||
"param_type": "required"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0"
|
||||
|
||||
],
|
||||
"name": "mean",
|
||||
"param_type": "required"
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0"
|
||||
|
||||
],
|
||||
"name": "variance",
|
||||
"param_type": "required"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(layer_norm_op_info)
|
||||
def _layer_norm_tbe():
|
||||
"""LayerNorm TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,105 +14,38 @@
|
|||
# ============================================================================
|
||||
|
||||
"""LayerNormBetaGammaBackprop op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
layer_norm_beta_gamma_backprop_op_info = TBERegOp("LayerNormBetaGammaBackprop") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("layer_norm_beta_gamma_backprop.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("layer_norm_beta_gamma_backprop") \
|
||||
.partial_flag(True) \
|
||||
.attr("shape_gamma", "required", "listInt", "all") \
|
||||
.input(0, "dy", False, "required", "all") \
|
||||
.input(1, "x", False, "required", "all") \
|
||||
.input(2, "variance", False, "required", "all") \
|
||||
.input(3, "mean", False, "required", "all") \
|
||||
.output(0, "pd_gamma", False, "required", "all") \
|
||||
.output(1, "pd_beta", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
|
||||
DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||
DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "LayerNormBetaGammaBackprop",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "layer_norm_beta_gamma_backprop.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "layer_norm_beta_gamma_backprop",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "shape_gamma",
|
||||
"param_type": "required",
|
||||
"type": "listInt",
|
||||
"value": "all"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"FRACTAL_NZ","DefaultFormat","NC1HWC0","FRACTAL_NZ","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "dy",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"FRACTAL_NZ","DefaultFormat","NC1HWC0","FRACTAL_NZ","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "variance",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 3,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "mean",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float","float","float","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "pd_gamma",
|
||||
"param_type": "required"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float","float","float","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "pd_beta",
|
||||
"param_type": "required"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(layer_norm_beta_gamma_backprop_op_info)
|
||||
def _layer_norm_beta_gamma_backprop_tbe():
|
||||
"""LayerNormBetaGammaBackprop TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,124 +14,35 @@
|
|||
# ============================================================================
|
||||
|
||||
"""LayerNormGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
layer_norm_grad_op_info = TBERegOp("LayerNormGrad") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("layer_norm_grad.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("layer_norm_grad") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "dy", False, "required", "all") \
|
||||
.input(1, "x", False, "required", "all") \
|
||||
.input(2, "variance", False, "required", "all") \
|
||||
.input(3, "mean", False, "required", "all") \
|
||||
.input(4, "gamma", False, "required", "all") \
|
||||
.output(0, "pd_x", False, "required", "all") \
|
||||
.output(1, "pd_gamma", False, "required", "all") \
|
||||
.output(2, "pd_beta", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
|
||||
DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "LayerNormGrad",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "layer_norm_grad.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "layer_norm_grad",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float16","float","float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "dy",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16","float16","float","float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"dtype": [
|
||||
"float16","float16","float","float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "variance",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 3,
|
||||
"dtype": [
|
||||
"float16","float16","float","float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "mean",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 4,
|
||||
"dtype": [
|
||||
"float16","float16","float","float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "gamma",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float16","float","float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "pd_x",
|
||||
"param_type": "required"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16","float16","float","float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "pd_gamma",
|
||||
"param_type": "required"
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"dtype": [
|
||||
"float16","float16","float","float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "pd_beta",
|
||||
"param_type": "required"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(layer_norm_grad_op_info)
|
||||
def _layer_norm_grad_tbe():
|
||||
"""LayerNormGrad TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,102 +14,37 @@
|
|||
# ============================================================================
|
||||
|
||||
"""LayerNormXBackprop op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
layer_norm_x_backprop_op_info = TBERegOp("LayerNormXBackprop") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("layer_norm_x_backprop.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("layer_norm_x_backprop") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "dy", False, "required", "all") \
|
||||
.input(1, "x", False, "required", "all") \
|
||||
.input(2, "variance", False, "required", "all") \
|
||||
.input(3, "mean", False, "required", "all") \
|
||||
.input(4, "gamma", False, "required", "all") \
|
||||
.output(0, "pd_x", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
|
||||
DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_FracNZ) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||
DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_FracNZ) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "LayerNormXBackprop",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "layer_norm_x_backprop.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "layer_norm_x_backprop",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"FRACTAL_NZ","DefaultFormat","NC1HWC0","FRACTAL_NZ","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "dy",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"FRACTAL_NZ","DefaultFormat","NC1HWC0","FRACTAL_NZ","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "variance",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 3,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "mean",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 4,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "gamma",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float16","float16","float","float","float"
|
||||
],
|
||||
"format": [
|
||||
"FRACTAL_NZ","DefaultFormat","NC1HWC0","FRACTAL_NZ","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "pd_x",
|
||||
"param_type": "required"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(layer_norm_x_backprop_op_info)
|
||||
def _layer_norm_x_backprop_tbe():
|
||||
"""LayerNormXBackprop TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,67 +14,32 @@
|
|||
# ============================================================================
|
||||
|
||||
"""Less op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
less_op_info = TBERegOp("Less") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("less.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("less") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "x1", False, "required", "all") \
|
||||
.input(1, "x2", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.BOOL_5HD) \
|
||||
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.BOOL_5HD) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.BOOL_5HD) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.BOOL_5HD) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.BOOL_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "Less",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "less.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "less",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float16","float","float","int32","int32","int8","int8","uint8","uint8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat",
|
||||
"NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "x1",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16","float16","float","float","int32","int32","int8","int8","uint8","uint8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat",
|
||||
"NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "x2",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"bool","bool","bool","bool","bool","bool","bool","bool","bool","bool"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat",
|
||||
"NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "y",
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(less_op_info)
|
||||
def _less_tbe():
|
||||
"""Less TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,67 +14,34 @@
|
|||
# ============================================================================
|
||||
|
||||
"""LessEqual op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
less_equal_op_info = TBERegOp("LessEqual") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("less_equal.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("less_equal") \
|
||||
.partial_flag(True) \
|
||||
.attr("begin_norm_axis", "required", "int", "all") \
|
||||
.attr("begin_params_axis", "required", "int", "all") \
|
||||
.input(0, "x1", False, "required", "all") \
|
||||
.input(1, "x2", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.BOOL_5HD) \
|
||||
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.BOOL_5HD) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.BOOL_5HD) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.BOOL_5HD) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.BOOL_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "LessEqual",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "less_equal.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "less_equal",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float16","float","float","int32","int32","int8","int8","uint8","uint8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat",
|
||||
"NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "x1",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16","float16","float","float","int32","int32","int8","int8","uint8","uint8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat",
|
||||
"NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "x2",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"bool","bool","bool","bool","bool","bool","bool","bool","bool","bool"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat",
|
||||
"NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "y",
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(less_equal_op_info)
|
||||
def _less_equal_tbe():
|
||||
"""LessEqual TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,52 +14,25 @@
|
|||
# ============================================================================
|
||||
|
||||
"""Log op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
log_op_info = TBERegOp("Log") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("log.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("log") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "Log",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "ELEMWISE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "log.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "log",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float16", "float", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float16", "float", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(log_op_info)
|
||||
def _log_tbe():
|
||||
"""Log TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,65 +14,26 @@
|
|||
# ============================================================================
|
||||
|
||||
"""LogicalAnd op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
logical_and_op_info = TBERegOp("LogicalAnd") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("logical_and.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("logical_and") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "x1", False, "required", "all") \
|
||||
.input(1, "x2", False, "required", "all") \
|
||||
.output(0, "y", True, "required", "all") \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.BOOL_FracZ, DataType.BOOL_FracZ, DataType.BOOL_FracZ) \
|
||||
.dtype_format(DataType.BOOL_C1HWNCoC0, DataType.BOOL_C1HWNCoC0, DataType.BOOL_C1HWNCoC0) \
|
||||
.dtype_format(DataType.BOOL_5HD, DataType.BOOL_5HD, DataType.BOOL_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "LogicalAnd",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "ELEMWISE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "logical_and.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "logical_and",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"bool", "bool", "bool", "bool"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0"
|
||||
],
|
||||
"name": "x1",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"bool", "bool", "bool", "bool"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0"
|
||||
],
|
||||
"name": "x2",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"bool", "bool", "bool", "bool"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": true,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(logical_and_op_info)
|
||||
def _logical_and_tbe():
|
||||
"""LogicalAnd TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,52 +14,25 @@
|
|||
# ============================================================================
|
||||
|
||||
"""LogicalNot op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
logical_not_op_info = TBERegOp("LogicalNot") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("logical_not.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("logical_not") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", True, "required", "all") \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.BOOL_FracZ, DataType.BOOL_FracZ) \
|
||||
.dtype_format(DataType.BOOL_C1HWNCoC0, DataType.BOOL_C1HWNCoC0) \
|
||||
.dtype_format(DataType.BOOL_5HD, DataType.BOOL_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "LogicalNot",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "ELEMWISE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "logical_not.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "logical_not",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"bool", "bool", "bool", "bool"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"bool", "bool", "bool", "bool"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": true,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(logical_not_op_info)
|
||||
def _logical_not_tbe():
|
||||
"""LogicalNot TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,65 +14,26 @@
|
|||
# ============================================================================
|
||||
|
||||
"""LogicalOr op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
logical_or_op_info = TBERegOp("LogicalOr") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("logical_or.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("logical_or") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "x1", False, "required", "all") \
|
||||
.input(1, "x2", False, "required", "all") \
|
||||
.output(0, "y", True, "required", "all") \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.BOOL_FracZ, DataType.BOOL_FracZ, DataType.BOOL_FracZ) \
|
||||
.dtype_format(DataType.BOOL_C1HWNCoC0, DataType.BOOL_C1HWNCoC0, DataType.BOOL_C1HWNCoC0) \
|
||||
.dtype_format(DataType.BOOL_5HD, DataType.BOOL_5HD, DataType.BOOL_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "LogicalOr",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "ELEMWISE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "logical_or.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "logical_or",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"bool", "bool", "bool", "bool"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0"
|
||||
],
|
||||
"name": "x1",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"bool", "bool", "bool", "bool"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0"
|
||||
],
|
||||
"name": "x2",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"bool", "bool", "bool", "bool"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": true,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(logical_or_op_info)
|
||||
def _logical_or_tbe():
|
||||
"""LogicalOr TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,57 +14,24 @@
|
|||
# ============================================================================
|
||||
|
||||
"""LogSoftmax op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
log_softmax_op_info = TBERegOp("LogSoftmax") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("log_softmax.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("log_softmax") \
|
||||
.partial_flag(True) \
|
||||
.attr("axis", "optional", "listInt", "all") \
|
||||
.input(0, "logits", False, "required", "all") \
|
||||
.output(0, "logsoftmax", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "LogSoftmax",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "log_softmax.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "log_softmax",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "axis",
|
||||
"param_type": "optional",
|
||||
"type": "listInt",
|
||||
"value": "all"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "logits",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "logsoftmax",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(log_softmax_op_info)
|
||||
def _logsoftmax_tbe():
|
||||
"""LogSoftMaxGrad TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,70 +14,25 @@
|
|||
# ============================================================================
|
||||
|
||||
"""LogSoftmaxGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
log_softmax_grad_op_info = TBERegOp("LogSoftmaxGrad") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("log_softmax_grad.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("log_softmax_grad") \
|
||||
.partial_flag(True) \
|
||||
.attr("axis", "optional", "listInt", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.input(1, "grad", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "LogSoftmaxGrad",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "log_softmax_grad.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "log_softmax_grad",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "axis",
|
||||
"param_type": "optional",
|
||||
"type": "listInt",
|
||||
"value": "all"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "grad",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(log_softmax_grad_op_info)
|
||||
def _logsoftmax_grad_tbe():
|
||||
"""LogSoftMaxGrad TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,89 +14,29 @@
|
|||
# ============================================================================
|
||||
|
||||
"""MatMul op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
matmul_op_info = TBERegOp("MatMul") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("matmul.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("matmul") \
|
||||
.partial_flag(True) \
|
||||
.attr("transpose_a", "required", "bool", "all") \
|
||||
.attr("transpose_b", "required", "bool", "all") \
|
||||
.input(0, "x1", False, "required", "all") \
|
||||
.input(1, "x2", False, "required", "all") \
|
||||
.input(2, "x3", False, "optional", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_Default, DataType.F16_FracNZ) \
|
||||
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F32_Default, DataType.F32_FracNZ) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "MatMul",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "matmul.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "matmul",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "transpose_a",
|
||||
"param_type": "required",
|
||||
"type": "bool",
|
||||
"value": "all"
|
||||
},
|
||||
{
|
||||
"name": "transpose_b",
|
||||
"param_type": "required",
|
||||
"type": "bool",
|
||||
"value": "all"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float16","float","int32"
|
||||
],
|
||||
"format": [
|
||||
"FRACTAL_NZ","FRACTAL_NZ","DefaultFormat","DefaultFormat"
|
||||
],
|
||||
"name": "x1",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16","float16","float","int32"
|
||||
],
|
||||
"format": [
|
||||
"FRACTAL_NZ","FRACTAL_NZ","DefaultFormat","DefaultFormat"
|
||||
],
|
||||
"name": "x2",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"dtype": [
|
||||
"float16","float","float","int32"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
|
||||
],
|
||||
"name": "x3",
|
||||
"need_compile": false,
|
||||
"param_type": "optional",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float","float","int32"
|
||||
],
|
||||
"format": [
|
||||
"FRACTAL_NZ","FRACTAL_NZ","DefaultFormat","DefaultFormat"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(matmul_op_info)
|
||||
def _matmul_tbe():
|
||||
"""Mul TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,74 +14,26 @@
|
|||
# ============================================================================
|
||||
|
||||
"""MaxPool op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
max_pool_op_info = TBERegOp("MaxPool") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("max_pool.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("max_pool") \
|
||||
.partial_flag(True) \
|
||||
.attr("ksize", "required", "listInt", "all") \
|
||||
.attr("strides", "required", "listInt", "all") \
|
||||
.attr("padding", "required", "str", "all") \
|
||||
.attr("data_format", "required", "str", "all") \
|
||||
.input(0, "input_data", False, "required", "all") \
|
||||
.output(0, "output_data", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "MaxPool",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "max_pool.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "max_pool",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "ksize",
|
||||
"param_type": "required",
|
||||
"type": "listInt",
|
||||
"value": "all"
|
||||
},
|
||||
{
|
||||
"name": "strides",
|
||||
"param_type": "required",
|
||||
"type": "listInt",
|
||||
"value": "all"
|
||||
},
|
||||
{
|
||||
"name": "padding",
|
||||
"param_type": "required",
|
||||
"type": "str",
|
||||
"value": "all"
|
||||
},
|
||||
{
|
||||
"name": "data_format",
|
||||
"param_type": "required",
|
||||
"type": "str",
|
||||
"value": "all"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0"
|
||||
],
|
||||
"name": "input_data",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0"
|
||||
],
|
||||
"name": "output_data",
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(max_pool_op_info)
|
||||
def _max_pool_tbe():
|
||||
"""MaxPool TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,93 +14,27 @@
|
|||
# ============================================================================
|
||||
|
||||
"""MaxPoolGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
max_pool_grad_op_info = TBERegOp("MaxPoolGrad") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("max_pool_grad.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("max_pool_grad") \
|
||||
.partial_flag(True) \
|
||||
.attr("ksize", "required", "listInt", "all") \
|
||||
.attr("strides", "required", "listInt", "all") \
|
||||
.attr("padding", "required", "str", "all") \
|
||||
.input(0, "x1", False, "required", "all") \
|
||||
.input(1, "x2", False, "required", "all") \
|
||||
.input(2, "grad", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "MaxPoolGrad",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "max_pool_grad.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "max_pool_grad",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "ksize",
|
||||
"param_type": "required",
|
||||
"type": "listInt",
|
||||
"value": "all"
|
||||
},
|
||||
{
|
||||
"name": "strides",
|
||||
"param_type": "required",
|
||||
"type": "listInt",
|
||||
"value": "all"
|
||||
},
|
||||
{
|
||||
"name": "padding",
|
||||
"param_type": "required",
|
||||
"type": "str",
|
||||
"value": "all"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0"
|
||||
],
|
||||
"name": "x1",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0"
|
||||
],
|
||||
"name": "x2",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"dtype": [
|
||||
"float16"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0"
|
||||
],
|
||||
"name": "grad",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0"
|
||||
],
|
||||
"name": "y",
|
||||
"param_type": "required"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(max_pool_grad_op_info)
|
||||
def _max_pool_grad_tbe():
|
||||
"""MaxPoolGrad TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,95 +14,28 @@
|
|||
# ============================================================================
|
||||
|
||||
"""MaxPoolGradWithArgmax op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
max_pool_grad_with_argmax_op_info = TBERegOp("MaxPoolGradWithArgmax") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("max_pool_grad_with_argmax.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("max_pool_grad_with_argmax") \
|
||||
.partial_flag(True) \
|
||||
.attr("ksize", "required", "listInt", "all") \
|
||||
.attr("strides", "required", "listInt", "all") \
|
||||
.attr("padding", "required", "str", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.input(1, "grad", False, "required", "all") \
|
||||
.input(2, "argmax", False, "optional", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.U16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.I64_5HD, DataType.F16_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "MaxPoolGradWithArgmax",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "max_pool_grad_with_argmax.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "max_pool_grad_with_argmax",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "ksize",
|
||||
"param_type": "required",
|
||||
"type": "listInt",
|
||||
"value": "all"
|
||||
},
|
||||
{
|
||||
"name": "strides",
|
||||
"param_type": "required",
|
||||
"type": "listInt",
|
||||
"value": "all"
|
||||
},
|
||||
{
|
||||
"name": "padding",
|
||||
"param_type": "required",
|
||||
"type": "str",
|
||||
"value": "all"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float16"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0", "NC1HWC0"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16", "float16"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0", "NC1HWC0"
|
||||
],
|
||||
"name": "grad",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"dtype": [
|
||||
"uint16", "int64"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0", "NC1HWC0"
|
||||
],
|
||||
"name": "argmax",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float16"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0", "NC1HWC0"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(max_pool_grad_with_argmax_op_info)
|
||||
def _max_pool_grad_with_argmax_tbe():
|
||||
"""MaxPoolGradWithArgmax TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,82 +14,26 @@
|
|||
# ============================================================================
|
||||
|
||||
"""MaxPoolWithArgmax op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
max_pool_with_argmax_op_info = TBERegOp("MaxPoolWithArgmax") \
|
||||
.fusion_type("CONVLUTION") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("max_pool_with_argmax.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("max_pool_with_argmax") \
|
||||
.partial_flag(True) \
|
||||
.attr("ksize", "required", "listInt", "all") \
|
||||
.attr("strides", "required", "listInt", "all") \
|
||||
.attr("padding", "required", "str", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.output(1, "argmax", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.U16_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "MaxPoolWithArgmax",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "CONVLUTION",
|
||||
"async_flag": false,
|
||||
"binfile_name": "max_pool_with_argmax.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "max_pool_with_argmax",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "ksize",
|
||||
"param_type": "required",
|
||||
"type": "listInt",
|
||||
"value": "all"
|
||||
},
|
||||
{
|
||||
"name": "strides",
|
||||
"param_type": "required",
|
||||
"type": "listInt",
|
||||
"value": "all"
|
||||
},
|
||||
{
|
||||
"name": "padding",
|
||||
"param_type": "required",
|
||||
"type": "str",
|
||||
"value": "all"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"uint16"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0"
|
||||
],
|
||||
"name": "argmax",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(max_pool_with_argmax_op_info)
|
||||
def _max_pool_with_argmax_tbe():
|
||||
"""MaxPoolWithArgmax TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,69 +14,28 @@
|
|||
# ============================================================================
|
||||
|
||||
"""Maximum op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
maximum_op_info = TBERegOp("Maximum") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("maximum.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("maximum") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "x1", False, "required", "all") \
|
||||
.input(1, "x2", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name":"Maximum",
|
||||
"imply_type":"TBE",
|
||||
"fusion_type":"ELEMWISE",
|
||||
"async_flag":false,
|
||||
"binfile_name":"maximum.so",
|
||||
"compute_cost":10,
|
||||
"kernel_name":"maximum",
|
||||
"partial_flag":true,
|
||||
"attr":[],
|
||||
"inputs":[
|
||||
{
|
||||
"index":0,
|
||||
"dtype":[
|
||||
"float16", "float16", "float16", "float16", "float", "float", "float", "float", "int32",
|
||||
"int32", "int32", "int32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name":"x1",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":1,
|
||||
"dtype":[
|
||||
"float16", "float16", "float16", "float16", "float", "float", "float", "float", "int32",
|
||||
"int32", "int32", "int32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name":"x2",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
}
|
||||
],
|
||||
"outputs":[
|
||||
{
|
||||
"index":0,
|
||||
"dtype":[
|
||||
"float16", "float16", "float16", "float16", "float", "float", "float", "float", "int32",
|
||||
"int32", "int32", "int32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name":"y",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(maximum_op_info)
|
||||
def _maximum_tbe():
|
||||
"""Maximum TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,112 +14,38 @@
|
|||
# ============================================================================
|
||||
|
||||
"""MaximumGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
maximum_grad_op_info = TBERegOp("MaximumGrad") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("maximum_grad.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("maximum_grad") \
|
||||
.partial_flag(True) \
|
||||
.attr("grad_x", "optional", "bool", "all") \
|
||||
.attr("grad_y", "optional", "bool", "all") \
|
||||
.input(0, "grads", False, "required", "all") \
|
||||
.input(1, "x1", False, "required", "all") \
|
||||
.input(2, "x2", False, "required", "all") \
|
||||
.output(0, "y1", False, "required", "all") \
|
||||
.output(1, "y2", False, "required", "all") \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
|
||||
DataType.I32_Default) \
|
||||
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD,
|
||||
DataType.I32_5HD) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
|
||||
DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||
DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name":"MaximumGrad",
|
||||
"imply_type":"TBE",
|
||||
"fusion_type":"OPAQUE",
|
||||
"async_flag":false,
|
||||
"binfile_name":"maximum_grad.so",
|
||||
"compute_cost":10,
|
||||
"kernel_name":"maximum_grad",
|
||||
"partial_flag":true,
|
||||
"attr":[
|
||||
{
|
||||
"name":"grad_x",
|
||||
"param_type":"optional",
|
||||
"type":"bool",
|
||||
"value":"all"
|
||||
},
|
||||
{
|
||||
"name":"grad_y",
|
||||
"param_type":"optional",
|
||||
"type":"bool",
|
||||
"value":"all"
|
||||
}
|
||||
],
|
||||
"inputs":[
|
||||
{
|
||||
"index":0,
|
||||
"dtype":[
|
||||
"float16", "float16", "float16", "float16", "float", "float", "float", "float",
|
||||
"int32", "int32", "int32", "int32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name":"grads",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":1,
|
||||
"dtype":[
|
||||
"float16", "float16", "float16", "float16", "float", "float", "float", "float",
|
||||
"int32", "int32", "int32", "int32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name":"x1",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":2,
|
||||
"dtype":[
|
||||
"float16", "float16", "float16", "float16", "float", "float", "float", "float",
|
||||
"int32", "int32", "int32", "int32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name":"x2",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
}
|
||||
],
|
||||
"outputs":[
|
||||
{
|
||||
"index":0,
|
||||
"dtype":[
|
||||
"float16", "float16", "float16", "float16", "float", "float", "float", "float",
|
||||
"int32", "int32", "int32", "int32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name":"y1",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":1,
|
||||
"dtype":[
|
||||
"float16", "float16", "float16", "float16", "float", "float", "float", "float",
|
||||
"int32", "int32", "int32", "int32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name":"y2",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(maximum_grad_op_info)
|
||||
def _maximum_grad_tbe():
|
||||
"""MaximumGrad TBE register"""
|
||||
return
|
||||
|
|
|
@ -15,74 +15,28 @@
|
|||
|
||||
|
||||
"""Minimum op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
minimum_op_info = TBERegOp("Minimum") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("minimum.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("minimum") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "x1", False, "required", "all") \
|
||||
.input(1, "x2", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "Minimum",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "ELEMWISE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "minimum.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "minimum",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float16", "float16", "float16", "float", "float", "float", "float",
|
||||
"int32", "int32", "int32", "int32"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat",
|
||||
"NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "x1",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16", "float16", "float16", "float16", "float", "float", "float", "float",
|
||||
"int32", "int32", "int32", "int32"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat",
|
||||
"NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "x2",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float16", "float16", "float16", "float", "float", "float", "float",
|
||||
"int32", "int32", "int32", "int32"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(minimum_op_info)
|
||||
def _minimum_tbe():
|
||||
"""Minimum TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,112 +14,38 @@
|
|||
# ============================================================================
|
||||
|
||||
"""MinimumGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
minimum_grad_op_info = TBERegOp("MinimumGrad") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("minimum_grad.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("minimum_grad") \
|
||||
.partial_flag(True) \
|
||||
.attr("grad_x", "optional", "bool", "all") \
|
||||
.attr("grad_y", "optional", "bool", "all") \
|
||||
.input(0, "grads", False, "required", "all") \
|
||||
.input(1, "x1", False, "required", "all") \
|
||||
.input(2, "x2", False, "required", "all") \
|
||||
.output(0, "y1", False, "required", "all") \
|
||||
.output(1, "y2", False, "required", "all") \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
|
||||
DataType.I32_Default) \
|
||||
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD,
|
||||
DataType.I32_5HD) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
|
||||
DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||
DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name":"MinimumGrad",
|
||||
"imply_type":"TBE",
|
||||
"fusion_type":"OPAQUE",
|
||||
"async_flag":false,
|
||||
"binfile_name":"minimum_grad.so",
|
||||
"compute_cost":10,
|
||||
"kernel_name":"minimum_grad",
|
||||
"partial_flag":true,
|
||||
"attr":[
|
||||
{
|
||||
"name":"grad_x",
|
||||
"param_type":"optional",
|
||||
"type":"bool",
|
||||
"value":"all"
|
||||
},
|
||||
{
|
||||
"name":"grad_y",
|
||||
"param_type":"optional",
|
||||
"type":"bool",
|
||||
"value":"all"
|
||||
}
|
||||
],
|
||||
"inputs":[
|
||||
{
|
||||
"index":0,
|
||||
"dtype":[
|
||||
"float16", "float16", "float16", "float16", "float", "float", "float", "float",
|
||||
"int32", "int32", "int32", "int32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name":"grads",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":1,
|
||||
"dtype":[
|
||||
"float16", "float16", "float16", "float16", "float", "float", "float", "float",
|
||||
"int32", "int32", "int32", "int32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name":"x1",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":2,
|
||||
"dtype":[
|
||||
"float16", "float16", "float16", "float16", "float", "float", "float", "float",
|
||||
"int32", "int32", "int32", "int32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name":"x2",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
}
|
||||
],
|
||||
"outputs":[
|
||||
{
|
||||
"index":0,
|
||||
"dtype":[
|
||||
"float16", "float16", "float16", "float16", "float", "float", "float", "float",
|
||||
"int32", "int32", "int32", "int32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name":"y1",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
},
|
||||
{
|
||||
"index":1,
|
||||
"dtype":[
|
||||
"float16", "float16", "float16", "float16", "float", "float", "float", "float",
|
||||
"int32", "int32", "int32", "int32"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name":"y2",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(minimum_grad_op_info)
|
||||
def _minimum_grad_tbe():
|
||||
"""MinimumGrad TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,77 +14,37 @@
|
|||
# ============================================================================
|
||||
|
||||
"""Mul op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
mul_op_info = TBERegOp("Mul") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("mul.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("mul") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.input(1, "y", False, "required", "all") \
|
||||
.output(0, "output", False, "required", "all") \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
|
||||
.dtype_format(DataType.I32_FracZ, DataType.I32_FracZ, DataType.I32_FracZ) \
|
||||
.dtype_format(DataType.I32_FracNZ, DataType.I32_FracNZ, DataType.I32_FracNZ) \
|
||||
.dtype_format(DataType.I32_C1HWNCoC0, DataType.I32_C1HWNCoC0, DataType.I32_C1HWNCoC0) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ) \
|
||||
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ) \
|
||||
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \
|
||||
.dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ) \
|
||||
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "Mul",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "ELEMWISE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "mul.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "mul",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"int32", "int32", "int32", "int32", "int32",
|
||||
"float16", "float16", "float16", "float16", "float16",
|
||||
"float", "float", "float", "float", "float"
|
||||
],
|
||||
"format": [
|
||||
"FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0",
|
||||
"FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0",
|
||||
"FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"int32", "int32", "int32", "int32", "int32",
|
||||
"float16", "float16", "float16", "float16", "float16",
|
||||
"float", "float", "float", "float", "float"
|
||||
],
|
||||
"format": [
|
||||
"FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0",
|
||||
"FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0",
|
||||
"FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"int32", "int32", "int32", "int32", "int32",
|
||||
"float16", "float16", "float16", "float16", "float16",
|
||||
"float", "float", "float", "float","float"
|
||||
],
|
||||
"format": [
|
||||
"FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0",
|
||||
"FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0",
|
||||
"FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0"
|
||||
],
|
||||
"name": "output",
|
||||
"need_compile": true,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(mul_op_info)
|
||||
def _mul_tbe():
|
||||
"""Mul TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,51 +14,29 @@
|
|||
# ============================================================================
|
||||
|
||||
"""Neg op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
neg_op_info = TBERegOp("Neg") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("neg.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("neg") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.I8_5HD, DataType.I8_5HD) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I32_5HD, DataType.I32_5HD) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "Neg",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "ELEMWISE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "neg.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "neg",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float","float","float16","float16","int32","int32","int8","int8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float","float","float16","float16","int32","int32","int8","int8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
|
||||
],
|
||||
"name": "y",
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(neg_op_info)
|
||||
def _neg_tbe():
|
||||
"""Neg TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,39 +14,21 @@
|
|||
# ============================================================================
|
||||
|
||||
"""NPUAllocFloatStatus op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
npu_alloc_float_status_op_info = TBERegOp("NPUAllocFloatStatus") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("n_p_u_alloc_float_status.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("n_p_u_alloc_float_status") \
|
||||
.partial_flag(True) \
|
||||
.output(0, "data", False, "required", "all") \
|
||||
.dtype_format(DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "NPUAllocFloatStatus",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "n_p_u_alloc_float_status.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "n_p_u_alloc_float_status",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
|
||||
],
|
||||
"inputs": [
|
||||
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name": "data",
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(npu_alloc_float_status_op_info)
|
||||
def _npu_alloc_float_status_tbe():
|
||||
"""NPUAllocFloatStatus TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,52 +14,22 @@
|
|||
# ============================================================================
|
||||
|
||||
"""NPUClearFloatStatus op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
npu_clear_float_status_op_info = TBERegOp("NPUClearFloatStatus") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("n_p_u_clear_float_status.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("n_p_u_clear_float_status") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "addr", False, "required", "all") \
|
||||
.output(0, "data", False, "required", "all") \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "NPUClearFloatStatus",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "ELEMWISE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "n_p_u_clear_float_status.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "n_p_u_clear_float_status",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name": "addr",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name": "data",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(npu_clear_float_status_op_info)
|
||||
def _npu_clear_float_status_tbe():
|
||||
"""NPUClearFloatStatus TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,52 +14,22 @@
|
|||
# ============================================================================
|
||||
|
||||
"""NPUGetFloatStatus op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
npu_get_float_status_op_info = TBERegOp("NPUGetFloatStatus") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("n_p_u_get_float_status.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("n_p_u_get_float_status") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "addr", False, "required", "all") \
|
||||
.output(0, "data", False, "required", "all") \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "NPUGetFloatStatus",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "ELEMWISE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "n_p_u_get_float_status.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "n_p_u_get_float_status",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name": "addr",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name": "data",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(npu_get_float_status_op_info)
|
||||
def _npu_get_float_status_tbe():
|
||||
"""NPUGetFloatStatus TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,96 +14,35 @@
|
|||
# ============================================================================
|
||||
|
||||
"""OneHot op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
one_hot_op_info = TBERegOp("OneHot") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("one_hot.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("one_hot") \
|
||||
.partial_flag(True) \
|
||||
.attr("depth", "required", "int", "all") \
|
||||
.attr("axis", "required", "int", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.input(1, "on_value", False, "required", "all") \
|
||||
.input(2, "off_value", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.U8_Default, DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "OneHot",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "one_hot.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "one_hot",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "depth",
|
||||
"param_type": "required",
|
||||
"type": "int",
|
||||
"value": "all"
|
||||
},
|
||||
{
|
||||
"name": "axis",
|
||||
"param_type": "required",
|
||||
"type": "int",
|
||||
"value": "all"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"int32","int32","int32","int32","int32",
|
||||
"uint8","uint8","uint8","uint8","uint8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat",
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16","float32","int32","int8","uint8",
|
||||
"float16","float32","int32","int8","uint8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat",
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
|
||||
],
|
||||
"name": "on_value",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"dtype": [
|
||||
"float16","float32","int32","int8","uint8",
|
||||
"float16","float32","int32","int8","uint8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat",
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
|
||||
],
|
||||
"name": "off_value",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float32","int32","int8","uint8",
|
||||
"float16","float32","int32","int8","uint8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat",
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
|
||||
],
|
||||
"name": "y",
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(one_hot_op_info)
|
||||
def _one_hot_tbe():
|
||||
"""OneHot TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,57 +14,27 @@
|
|||
# ============================================================================
|
||||
|
||||
"""Pad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
pad_d_op_info = TBERegOp("Pad") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("pad_d.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("pad_d") \
|
||||
.partial_flag(True) \
|
||||
.attr("paddings", "optional", "listListInt", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "Pad",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "pad_d.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "pad_d",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "paddings",
|
||||
"param_type": "optional",
|
||||
"type": "listListInt",
|
||||
"value": "all"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float","int8","uint8","int32"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float","int8","uint8","int32"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(pad_d_op_info)
|
||||
def _pad_d_tbe():
|
||||
"""Pad TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,65 +14,27 @@
|
|||
# ============================================================================
|
||||
|
||||
"""Pow op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
pow_op_info = TBERegOp("Pow") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("pow.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("pow") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "x1", False, "required", "all") \
|
||||
.input(1, "x2", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "Pow",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "ELEMWISE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "pow.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "pow",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float", "int32", "int8", "uint8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "x1",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16", "float", "int32", "int8", "uint8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "x2",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float", "int32", "int8", "uint8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(pow_op_info)
|
||||
def _pow_tbe():
|
||||
"""Pow TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,64 +14,26 @@
|
|||
# ============================================================================
|
||||
|
||||
"""RealDiv op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
realdiv_op_info = TBERegOp("RealDiv") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("realdiv.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("realdiv") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.input(1, "y", False, "required", "all") \
|
||||
.output(0, "z", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "RealDiv",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "realdiv.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "realdiv",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float16", "float", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16", "float16", "float", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float16", "float", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0"
|
||||
],
|
||||
"name": "z",
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(realdiv_op_info)
|
||||
def _real_div_tbe():
|
||||
"""RealDiv TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,52 +14,27 @@
|
|||
# ============================================================================
|
||||
|
||||
"""Add op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
reciprocal_op_info = TBERegOp("Reciprocal") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("reciprocal.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("reciprocal") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F16_NHWC, DataType.F16_NHWC) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.dtype_format(DataType.F32_NHWC, DataType.F32_NHWC) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "Reciprocal",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "reciprocal.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "reciprocal",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float16", "float16", "float32", "float32", "float32"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "NC1HWC0", "NHWC", "DefaultFormat", "NC1HWC0", "NHWC"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float16", "float16", "float32", "float32", "float32"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "NC1HWC0", "NHWC", "DefaultFormat", "NC1HWC0", "NHWC"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(reciprocal_op_info)
|
||||
def _reciprocal_tbe():
|
||||
"""Add TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,63 +14,29 @@
|
|||
# ============================================================================
|
||||
|
||||
"""ReduceMax op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
reduce_max_d_op_info = TBERegOp("ReduceMax") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("reduce_max_d.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("reduce_max_d") \
|
||||
.partial_flag(True) \
|
||||
.attr("axis", "optional", "listInt", "all") \
|
||||
.attr("keep_dims", "optional", "bool", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "ReduceMax",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "reduce_max_d.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "reduce_max_d",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "axis",
|
||||
"param_type": "required",
|
||||
"type": "listInt",
|
||||
"value": "all"
|
||||
},
|
||||
{
|
||||
"name": "keep_dims",
|
||||
"param_type": "required",
|
||||
"type": "bool",
|
||||
"value": "all"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float", "int8", "uint8", "bool", "int32"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float", "int8", "uint8", "bool", "int32"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(reduce_max_d_op_info)
|
||||
def _reduce_max_tbe():
|
||||
"""ReduceMax TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,63 +14,27 @@
|
|||
# ============================================================================
|
||||
|
||||
"""ReduceMean op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
reduce_mean_op_info = TBERegOp("ReduceMean") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("reduce_mean.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("reduce_mean") \
|
||||
.partial_flag(True) \
|
||||
.attr("axis", "optional", "listInt", "all") \
|
||||
.attr("keep_dims", "optional", "bool", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "ReduceMean",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "reduce_mean.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "reduce_mean",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "axis",
|
||||
"param_type": "optional",
|
||||
"type": "listInt",
|
||||
"value": "all"
|
||||
},
|
||||
{
|
||||
"name": "keep_dims",
|
||||
"param_type": "optional",
|
||||
"type": "bool",
|
||||
"value": "all"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float","float16","int8","uint8"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float","float16","int8","uint8"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(reduce_mean_op_info)
|
||||
def _reduce_mean_tbe():
|
||||
"""ReduceMean TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,63 +14,27 @@
|
|||
# ============================================================================
|
||||
|
||||
"""ReduceMeanD op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
reduce_mean_d_op_info = TBERegOp("ReduceMeanD") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("reduce_mean_d.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("reduce_mean_d") \
|
||||
.partial_flag(True) \
|
||||
.attr("axis", "optional", "listInt", "all") \
|
||||
.attr("keep_dims", "optional", "bool", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "ReduceMeanD",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "reduce_mean_d.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "reduce_mean_d",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "axis",
|
||||
"param_type": "optional",
|
||||
"type": "listInt",
|
||||
"value": "all"
|
||||
},
|
||||
{
|
||||
"name": "keep_dims",
|
||||
"param_type": "optional",
|
||||
"type": "bool",
|
||||
"value": "all"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float","float16","int8","uint8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float","float16","int8","uint8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(reduce_mean_d_op_info)
|
||||
def _reduce_mean_d_tbe():
|
||||
"""Conv2D TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,63 +14,31 @@
|
|||
# ============================================================================
|
||||
|
||||
"""ReduceMin op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
reduce_min_op_info = TBERegOp("ReduceMin") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("reduce_min_d.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("reduce_min_d") \
|
||||
.partial_flag(True) \
|
||||
.attr("axis", "required", "listInt", "all") \
|
||||
.attr("keep_dims", "required", "bool", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.I8_FracZ, DataType.I8_FracZ) \
|
||||
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.U8_FracZ, DataType.U8_FracZ) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "ReduceMin",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "reduce_min_d.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "reduce_min_d",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "axis",
|
||||
"param_type": "required",
|
||||
"type": "listInt",
|
||||
"value": "all"
|
||||
},
|
||||
{
|
||||
"name": "keep_dims",
|
||||
"param_type": "required",
|
||||
"type": "bool",
|
||||
"value": "all"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float16", "float", "float", "int8", "int8", "uint8", "uint8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "FracZ", "DefaultFormat", "FracZ", "DefaultFormat", "FracZ", "DefaultFormat", "FracZ"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float16", "float", "float", "int8", "int8", "uint8", "uint8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "FracZ", "DefaultFormat", "FracZ", "DefaultFormat", "FracZ", "DefaultFormat", "FracZ"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(reduce_min_op_info)
|
||||
def _reduce_min_tbe():
|
||||
"""ReduceMin TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,63 +14,25 @@
|
|||
# ============================================================================
|
||||
|
||||
"""ReduceSum op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
reduce_sum_op_info = TBERegOp("ReduceSum") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("reduce_sum_d.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("reduce_sum_d") \
|
||||
.partial_flag(True) \
|
||||
.attr("axis", "optional", "listInt", "all") \
|
||||
.attr("keep_dims", "optional", "bool", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "ReduceSum",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "reduce_sum_d.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "reduce_sum_d",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "axis",
|
||||
"param_type": "optional",
|
||||
"type": "listInt",
|
||||
"value": "all"
|
||||
},
|
||||
{
|
||||
"name": "keep_dims",
|
||||
"param_type": "optional",
|
||||
"type": "bool",
|
||||
"value": "all"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(reduce_sum_op_info)
|
||||
def _reduce_sum_tbe():
|
||||
"""ReduceSum TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,54 +14,29 @@
|
|||
# ============================================================================
|
||||
|
||||
"""ReLU op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
relu_op_info = TBERegOp("ReLU") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("relu.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("relu") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.I8_5HD, DataType.I8_5HD) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I32_5HD, DataType.I32_5HD) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "ReLU",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "ELEMWISE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "relu.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "relu",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float16", "float", "float","int32", "int32", "int8", "int8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "NC1HWC0"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float16", "float", "float", "int32", "int32", "int8", "int8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "NC1HWC0"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(relu_op_info)
|
||||
def _relu_tbe():
|
||||
"""Relu TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,68 +14,32 @@
|
|||
# ============================================================================
|
||||
|
||||
"""ReluGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
relugrad_op_info = TBERegOp("ReluGrad") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("relugrad.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("relugrad") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "gradients", False, "required", "all") \
|
||||
.input(1, "features", False, "required", "all") \
|
||||
.output(0, "backprops", True, "required", "all") \
|
||||
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD) \
|
||||
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.U8_5HD) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "ReluGrad",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "ELEMWISE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "relugrad.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "relugrad",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float16", "float", "float", "int32", "int32", "int8", "int8", "uint8", "uint8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "NC1HWC0","DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0"
|
||||
],
|
||||
"name": "gradients",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16", "float16", "float", "float", "int32", "int32", "int8", "int8", "uint8", "uint8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0"
|
||||
],
|
||||
"name": "features",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float16", "float", "float", "int32", "int32", "int8", "int8", "uint8", "uint8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0"
|
||||
],
|
||||
"name": "backprops",
|
||||
"need_compile": true,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(relugrad_op_info)
|
||||
def _relu_grad_tbe():
|
||||
"""ReluGrad TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,57 +14,25 @@
|
|||
# ============================================================================
|
||||
|
||||
"""Reshape op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
reshape_op_info = TBERegOp("Reshape") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("reshape.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("reshape") \
|
||||
.partial_flag(True) \
|
||||
.attr("shape", "required", "listInt", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "Reshape",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "reshape.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "reshape",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "shape",
|
||||
"param_type": "required",
|
||||
"type": "listInt",
|
||||
"value": "all"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float32", "int32"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float32", "int32"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(reshape_op_info)
|
||||
def _reshape_tbe():
|
||||
"""Reshape TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,67 +14,33 @@
|
|||
# ============================================================================
|
||||
|
||||
"""ResizeNearestNeighbor op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
resize_nearest_neighbor_op_info = TBERegOp("ResizeNearestNeighbor") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("resize_nearest_neighbor_d.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("resize_nearest_neighbor_d") \
|
||||
.partial_flag(True) \
|
||||
.attr("size", "required", "listInt", "all") \
|
||||
.attr("align_corners", "optional", "bool", "all") \
|
||||
.input(0, "images", False, "required", "all") \
|
||||
.output(0, "y", True, "required", "all") \
|
||||
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.I8_5HD, DataType.I8_5HD) \
|
||||
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.U8_5HD, DataType.U8_5HD) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I32_5HD, DataType.I32_5HD) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "ResizeNearestNeighbor",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "resize_nearest_neighbor_d.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "resize_nearest_neighbor_d",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "size",
|
||||
"param_type": "required",
|
||||
"type": "listInt",
|
||||
"value": "all"
|
||||
},
|
||||
{
|
||||
"name": "align_corners",
|
||||
"param_type": "optional",
|
||||
"type": "bool",
|
||||
"value": "all"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float","int32","int8","uint8",
|
||||
"float16","float","int32","int8","uint8"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0","NC1HWC0","NC1HWC0","NC1HWC0","NC1HWC0",
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
|
||||
],
|
||||
"name": "images",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float","int32","int8","uint8",
|
||||
"float16","float","int32","int8","uint8"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0","NC1HWC0","NC1HWC0","NC1HWC0","NC1HWC0",
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": true,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(resize_nearest_neighbor_op_info)
|
||||
def _resize_nearest_neighbor_d_tbe():
|
||||
"""ResizeNearestNeighbor TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,63 +14,28 @@
|
|||
# ============================================================================
|
||||
|
||||
"""ResizeNearestNeighbor op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
resize_nearest_neighbor_d_op_info = TBERegOp("ResizeNearestNeighbor") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("resize_nearest_neighbor_d.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("resize_nearest_neighbor_d") \
|
||||
.partial_flag(True) \
|
||||
.attr("size", "required", "listInt", "all") \
|
||||
.attr("align_corners", "optional", "bool", "all") \
|
||||
.input(0, "images", False, "required", "all") \
|
||||
.output(0, "y", True, "required", "all") \
|
||||
.dtype_format(DataType.I8_5HD, DataType.I8_5HD) \
|
||||
.dtype_format(DataType.U8_5HD, DataType.U8_5HD) \
|
||||
.dtype_format(DataType.I32_5HD, DataType.I32_5HD) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "ResizeNearestNeighbor",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "resize_nearest_neighbor_d.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "resize_nearest_neighbor_d",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "size",
|
||||
"param_type": "required",
|
||||
"type": "listInt",
|
||||
"value": "all"
|
||||
},
|
||||
{
|
||||
"name": "align_corners",
|
||||
"param_type": "optional",
|
||||
"type": "bool",
|
||||
"value": "all"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float","int32","int8","uint8"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0","NC1HWC0","NC1HWC0","NC1HWC0","NC1HWC0"
|
||||
],
|
||||
"name": "images",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float","int32","int8","uint8"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0","NC1HWC0","NC1HWC0","NC1HWC0","NC1HWC0"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": true,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(resize_nearest_neighbor_d_op_info)
|
||||
def _resize_nearest_neighbor_d_tbe():
|
||||
"""ResizeNearestNeighbor TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,63 +14,24 @@
|
|||
# ============================================================================
|
||||
|
||||
"""ResizeNearestNeighborgrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
resize_nearest_neighbor_grad_d_op_info = TBERegOp("ResizeNearestNeighborGrad") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("resize_nearest_neighbor_grad_d.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("resize_nearest_neighbor_grad_d") \
|
||||
.partial_flag(True) \
|
||||
.attr("size", "required", "listInt", "all") \
|
||||
.attr("align_corners", "optional", "bool", "all") \
|
||||
.input(0, "grads", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "ResizeNearestNeighborGrad",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "resize_nearest_neighbor_grad_d.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "resize_nearest_neighbor_grad_d",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "size",
|
||||
"param_type": "required",
|
||||
"type": "listInt",
|
||||
"value": "all"
|
||||
},
|
||||
{
|
||||
"name": "align_corners",
|
||||
"param_type": "optional",
|
||||
"type": "bool",
|
||||
"value": "all"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0"
|
||||
],
|
||||
"name": "grads",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(resize_nearest_neighbor_grad_d_op_info)
|
||||
def _resize_nearest_neighbor_grad_d_tbe():
|
||||
"""ResizeNearestNeighborGrad TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,52 +14,27 @@
|
|||
# ============================================================================
|
||||
|
||||
"""Round op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
round_op_info = TBERegOp("Round") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("round.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("round") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "Round",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "ELEMWISE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "round.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "round",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float16", "float16", "float", "float", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "NC1HWC0", "FracZ", "DefaultFormat", "NC1HWC0", "FracZ"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float16", "float16", "float", "float", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "NC1HWC0", "FracZ", "DefaultFormat", "NC1HWC0", "FracZ"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(round_op_info)
|
||||
def _round_tbe():
|
||||
"""Round TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,94 +14,29 @@
|
|||
# ============================================================================
|
||||
|
||||
"""Rsqrt op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
rsqrt_op_info = TBERegOp("Rsqrt") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("rsqrt.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("rsqrt") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ) \
|
||||
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ) \
|
||||
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name":"Rsqrt",
|
||||
"imply_type":"TBE",
|
||||
"fusion_type":"OPAQUE",
|
||||
"async_flag":false,
|
||||
"binfile_name":"rsqrt.so",
|
||||
"compute_cost":10,
|
||||
"kernel_name":"rsqrt",
|
||||
"partial_flag":true,
|
||||
"attr":[],
|
||||
"inputs":[
|
||||
{
|
||||
"index":0,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float16",
|
||||
"float16",
|
||||
"float16",
|
||||
"float16",
|
||||
"float16",
|
||||
"float",
|
||||
"float",
|
||||
"float",
|
||||
"float",
|
||||
"float",
|
||||
"float"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"NC1HWC0",
|
||||
"DefaultFormat",
|
||||
"FracZ",
|
||||
"C1HWNCoC0",
|
||||
"DefaultFormat",
|
||||
"DefaultFormat",
|
||||
"NC1HWC0",
|
||||
"DefaultFormat",
|
||||
"DefaultFormat",
|
||||
"FracZ",
|
||||
"C1HWNCoC0"
|
||||
],
|
||||
"name":"x",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
}
|
||||
],
|
||||
"outputs":[
|
||||
{
|
||||
"index":0,
|
||||
"dtype":[
|
||||
"float16",
|
||||
"float16",
|
||||
"float16",
|
||||
"float16",
|
||||
"float16",
|
||||
"float16",
|
||||
"float",
|
||||
"float",
|
||||
"float",
|
||||
"float",
|
||||
"float",
|
||||
"float"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"NC1HWC0",
|
||||
"DefaultFormat",
|
||||
"FracZ",
|
||||
"C1HWNCoC0",
|
||||
"DefaultFormat",
|
||||
"DefaultFormat",
|
||||
"NC1HWC0",
|
||||
"DefaultFormat",
|
||||
"DefaultFormat",
|
||||
"FracZ",
|
||||
"C1HWNCoC0"
|
||||
],
|
||||
"name":"y",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(rsqrt_op_info)
|
||||
def _rsqrt_tbe():
|
||||
"""Rsqrt TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,71 +14,28 @@
|
|||
# ============================================================================
|
||||
|
||||
"""ScatterNd op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
scatter_nd_op_info = TBERegOp("ScatterNd") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("scatter_nd_d.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("scatter_nd_d") \
|
||||
.partial_flag(True) \
|
||||
.attr("shape", "optional", "listInt", "all") \
|
||||
.input(0, "indices", False, "required", "all") \
|
||||
.input(1, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
# map to tbe kernel name scatter_nd_d
|
||||
@op_info_register("""{
|
||||
"op_name": "ScatterNd",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "ELEMWISE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "scatter_nd_d.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "scatter_nd_d",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "shape",
|
||||
"param_type": "optional",
|
||||
"type": "listInt",
|
||||
"value": "all"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"int32", "int32", "int32", "int32", "int32"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
|
||||
],
|
||||
"name": "indices",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16","float","int32","int8","uint8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float","int32","int8","uint8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(scatter_nd_op_info)
|
||||
def _scatter_nd_tbe():
|
||||
"""Conv2D TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,70 +14,28 @@
|
|||
# ============================================================================
|
||||
|
||||
"""ScatterNdD op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
scatter_nd_d_op_info = TBERegOp("ScatterNdD") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("scatter_nd_d.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("scatter_nd_d") \
|
||||
.partial_flag(True) \
|
||||
.attr("shape", "optional", "listInt", "all") \
|
||||
.input(0, "indices", False, "required", "all") \
|
||||
.input(1, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "ScatterNdD",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "scatter_nd_d.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "scatter_nd_d",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "shape",
|
||||
"param_type": "optional",
|
||||
"type": "listInt",
|
||||
"value": "all"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"int32", "int32", "int32", "int32", "int32"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
|
||||
],
|
||||
"name": "indices",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16","float","int32","int8","uint8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float","int32","int8","uint8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(scatter_nd_d_op_info)
|
||||
def _scatter_nd_d_tbe():
|
||||
"""ScatterNdD TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,94 +14,33 @@
|
|||
# ============================================================================
|
||||
|
||||
"""Select op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
select_op_info = TBERegOp("Select") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("select.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("select") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "condition", False, "required", "all") \
|
||||
.input(1, "x1", False, "required", "all") \
|
||||
.input(2, "x2", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.BOOL_5HD, DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD) \
|
||||
.dtype_format(DataType.BOOL_5HD, DataType.U8_5HD, DataType.U8_5HD, DataType.U8_5HD) \
|
||||
.dtype_format(DataType.BOOL_5HD, DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
|
||||
.dtype_format(DataType.BOOL_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.BOOL_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "Select",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "select.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "select",
|
||||
"partial_flag": true,
|
||||
"attr":[
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"bool", "bool", "bool", "bool", "bool", "bool", "bool", "bool", "bool", "bool",
|
||||
"bool", "bool", "bool", "bool", "bool", "bool", "bool", "bool", "bool", "bool"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat",
|
||||
"NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat",
|
||||
"DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat",
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "condition",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16", "float16", "float16", "float16", "float", "float", "float", "float",
|
||||
"int32", "int32", "int32", "int32", "int8", "int8", "int8", "int8", "uint8",
|
||||
"uint8", "uint8", "uint8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat",
|
||||
"DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat",
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "x1",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"dtype": [
|
||||
"float16", "float16", "float16", "float16", "float", "float", "float", "float", "int32",
|
||||
"int32", "int32", "int32", "int8", "int8", "int8", "int8", "uint8", "uint8", "uint8", "uint8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat",
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "x2",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float16", "float16", "float16", "float", "float", "float", "float", "int32",
|
||||
"int32", "int32", "int32", "int8", "int8", "int8", "int8", "uint8", "uint8", "uint8", "uint8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat",
|
||||
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(select_op_info)
|
||||
def _select_tbe():
|
||||
"""Select TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,67 +14,31 @@
|
|||
# ============================================================================
|
||||
|
||||
"""Sigmoid op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
sigmoid_op_info = TBERegOp("Sigmoid") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("sigmoid.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("sigmoid") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ) \
|
||||
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ) \
|
||||
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ) \
|
||||
.dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ) \
|
||||
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "Sigmoid",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "ELEMWISE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "Sigmoid.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "sigmoid",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float",
|
||||
"float16","float",
|
||||
"float16","float",
|
||||
"float16","float",
|
||||
"float16","float"
|
||||
],
|
||||
"format": [
|
||||
"FracZ","FracZ",
|
||||
"FRACTAL_NZ","FRACTAL_NZ",
|
||||
"C1HWNCoC0","C1HWNCoC0",
|
||||
"NC1HWC0","NC1HWC0",
|
||||
"DefaultFormat","DefaultFormat"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float",
|
||||
"float16","float",
|
||||
"float16","float",
|
||||
"float16","float",
|
||||
"float16","float"
|
||||
],
|
||||
"format": [
|
||||
"FracZ","FracZ",
|
||||
"FRACTAL_NZ","FRACTAL_NZ",
|
||||
"C1HWNCoC0","C1HWNCoC0",
|
||||
"NC1HWC0","NC1HWC0",
|
||||
"DefaultFormat","DefaultFormat"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": true,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(sigmoid_op_info)
|
||||
def _sigmoid_tbe():
|
||||
"""Sigmoid TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,64 +14,26 @@
|
|||
# ============================================================================
|
||||
|
||||
"""SigmoidCrossEntropyWithLogits op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
sigmoid_cross_entropy_with_logits_op_info = TBERegOp("SigmoidCrossEntropyWithLogits") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("sigmoid_cross_entropy_with_logits.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("sigmoid_cross_entropy_with_logits") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "predict", False, "required", "all") \
|
||||
.input(1, "target", False, "required", "all") \
|
||||
.output(0, "loss", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "SigmoidCrossEntropyWithLogits",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "sigmoid_cross_entropy_with_logits.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "sigmoid_cross_entropy_with_logits",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float16", "float", "float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0", "DefaultFormat", "NC1HWC0", "DefaultFormat"
|
||||
],
|
||||
"name": "predict",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16", "float16", "float", "float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0", "DefaultFormat", "NC1HWC0", "DefaultFormat"
|
||||
],
|
||||
"name": "target",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float16", "float", "float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0", "DefaultFormat", "NC1HWC0", "DefaultFormat"
|
||||
],
|
||||
"name": "loss",
|
||||
"need_compile": true,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(sigmoid_cross_entropy_with_logits_op_info)
|
||||
def _sigmoid_cross_entropy_with_logits_tbe():
|
||||
"""SigmoidCrossEntropyWithLogits TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,77 +14,27 @@
|
|||
# ============================================================================
|
||||
|
||||
"""SigmoidCrossEntropyWithLogitsGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
sigmoid_cross_entropy_with_logits_grad_op_info = TBERegOp("SigmoidCrossEntropyWithLogitsGrad") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("sigmoid_cross_entropy_with_logits_grad.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("sigmoid_cross_entropy_with_logits_grad") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "predict", False, "required", "all") \
|
||||
.input(1, "target", False, "required", "all") \
|
||||
.input(2, "dout", False, "required", "all") \
|
||||
.output(0, "gradient", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "SigmoidCrossEntropyWithLogitsGrad",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "sigmoid_cross_entropy_with_logits_grad.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "sigmoid_cross_entropy_with_logits_grad",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float16", "float", "float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0", "DefaultFormat", "NC1HWC0", "DefaultFormat"
|
||||
],
|
||||
"name": "predict",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16", "float16", "float", "float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0", "DefaultFormat", "NC1HWC0", "DefaultFormat"
|
||||
],
|
||||
"name": "target",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"dtype": [
|
||||
"float16", "float16", "float", "float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0", "DefaultFormat", "NC1HWC0", "DefaultFormat"
|
||||
],
|
||||
"name": "dout",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float16", "float", "float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0", "DefaultFormat", "NC1HWC0", "DefaultFormat"
|
||||
],
|
||||
"name": "gradient",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(sigmoid_cross_entropy_with_logits_grad_op_info)
|
||||
def _sigmoid_cross_entropy_with_logits_grad_tbe():
|
||||
"""SigmoidCrossEntropyWithLogitsGrad TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,64 +14,26 @@
|
|||
# ============================================================================
|
||||
|
||||
"""SigmoidGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
sigmoid_cross_entropy_with_logits_op_info = TBERegOp("SigmoidGrad") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("sigmoid_grad.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("sigmoid_grad") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.input(1, "y", False, "required", "all") \
|
||||
.output(0, "z", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "SigmoidGrad",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "ELEMWISE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "sigmoid_grad.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "sigmoid_grad",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float","float16","float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0","NC1HWC0","DefaultFormat","DefaultFormat"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16","float","float16","float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0","NC1HWC0","DefaultFormat","DefaultFormat"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16","float","float16","float"
|
||||
],
|
||||
"format": [
|
||||
"NC1HWC0","NC1HWC0","DefaultFormat","DefaultFormat"
|
||||
],
|
||||
"name": "z",
|
||||
"need_compile": true,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(sigmoid_cross_entropy_with_logits_op_info)
|
||||
def _sigmoid_grad_tbe():
|
||||
"""SigmoidGrad TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,99 +14,33 @@
|
|||
# ============================================================================
|
||||
|
||||
"""Slice op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
slice_op_info = TBERegOp("Slice") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("slice_d.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("slice_d") \
|
||||
.partial_flag(True) \
|
||||
.attr("begin", "required", "listInt", "all") \
|
||||
.attr("size", "required", "listInt", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
|
||||
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
|
||||
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name":"Slice",
|
||||
"imply_type":"TBE",
|
||||
"fusion_type":"OPAQUE",
|
||||
"async_flag":false,
|
||||
"binfile_name":"slice_d.so",
|
||||
"compute_cost":10,
|
||||
"kernel_name":"slice_d",
|
||||
"partial_flag":true,
|
||||
"attr":[
|
||||
{
|
||||
"name":"begin",
|
||||
"param_type":"required",
|
||||
"type":"listInt",
|
||||
"value":"all"
|
||||
},
|
||||
{
|
||||
"name":"size",
|
||||
"param_type":"required",
|
||||
"type":"listInt",
|
||||
"value":"all"
|
||||
}
|
||||
],
|
||||
"inputs":[
|
||||
{
|
||||
"index":0,
|
||||
"dtype":[
|
||||
"float",
|
||||
"float16",
|
||||
"int8",
|
||||
"int16",
|
||||
"int32",
|
||||
"int64",
|
||||
"uint8",
|
||||
"uint16",
|
||||
"uint32",
|
||||
"uint64"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat",
|
||||
"DefaultFormat",
|
||||
"DefaultFormat",
|
||||
"DefaultFormat",
|
||||
"DefaultFormat",
|
||||
"DefaultFormat",
|
||||
"DefaultFormat",
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"x",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
}
|
||||
],
|
||||
"outputs":[
|
||||
{
|
||||
"index":0,
|
||||
"dtype":[
|
||||
"float",
|
||||
"float16",
|
||||
"int8",
|
||||
"int16",
|
||||
"int32",
|
||||
"int64",
|
||||
"uint8",
|
||||
"uint16",
|
||||
"uint32",
|
||||
"uint64"
|
||||
],
|
||||
"format":[
|
||||
"DefaultFormat",
|
||||
"DefaultFormat",
|
||||
"DefaultFormat",
|
||||
"DefaultFormat",
|
||||
"DefaultFormat",
|
||||
"DefaultFormat",
|
||||
"DefaultFormat",
|
||||
"DefaultFormat",
|
||||
"DefaultFormat",
|
||||
"DefaultFormat"
|
||||
],
|
||||
"name":"y",
|
||||
"need_compile":false,
|
||||
"param_type":"required",
|
||||
"shape":"all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(slice_op_info)
|
||||
def _slice_tbe():
|
||||
"""Slice TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,57 +14,27 @@
|
|||
# ============================================================================
|
||||
|
||||
"""Softmax op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
softmax_op_info = TBERegOp("Softmax") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("softmax.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("softmax") \
|
||||
.partial_flag(True) \
|
||||
.attr("axis", "optional", "listInt", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "Softmax",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "softmax.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "softmax",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "axis",
|
||||
"param_type": "optional",
|
||||
"type": "listInt",
|
||||
"value": "all"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float16", "float16", "float", "float"
|
||||
],
|
||||
"format": [
|
||||
"FRACTAL_NZ", "DefaultFormat", "NC1HWC0", "FRACTAL_NZ", "DefaultFormat"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float16", "float16", "float", "float"
|
||||
],
|
||||
"format": [
|
||||
"FRACTAL_NZ", "DefaultFormat", "NC1HWC0", "FRACTAL_NZ", "DefaultFormat"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(softmax_op_info)
|
||||
def _softmax_tbe():
|
||||
"""Softmax TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,78 +14,25 @@
|
|||
# ============================================================================
|
||||
|
||||
"""SoftmaxCrossEntropyWithLogits op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
softmax_cross_entropy_with_logits_op_info = TBERegOp("SoftmaxCrossEntropyWithLogits") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("softmax_cross_entropy_with_logits.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("softmax_cross_entropy_with_logits") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "input_features", False, "required", "all") \
|
||||
.input(1, "input_labels", False, "required", "all") \
|
||||
.output(0, "output_loss", True, "required", "all") \
|
||||
.output(1, "output_backprop", True, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "SoftmaxCrossEntropyWithLogits",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "softmax_cross_entropy_with_logits.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "softmax_cross_entropy_with_logits",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "input_features",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "input_labels",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "output_loss",
|
||||
"need_compile": true,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "output_backprop",
|
||||
"need_compile": true,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(softmax_cross_entropy_with_logits_op_info)
|
||||
def _softmax_cross_entropy_with_logits_tbe():
|
||||
"""SoftmaxCrossEntropyWithLogits TBE register"""
|
||||
return
|
||||
|
|
|
@ -14,71 +14,45 @@
|
|||
# ============================================================================
|
||||
|
||||
"""Add op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
split_d_op_info = TBERegOp("Split") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("split_d.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("split_d") \
|
||||
.partial_flag(True) \
|
||||
.attr("axis", "required", "int", "all") \
|
||||
.attr("output_num", "required", "int", "all") \
|
||||
.input(0, "value", False, "required", "all") \
|
||||
.output(0, "output", False, "dynamic", "all") \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.BOOL_NHWC, DataType.BOOL_NHWC) \
|
||||
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.I8_NHWC, DataType.I8_NHWC) \
|
||||
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.U8_NHWC, DataType.U8_NHWC) \
|
||||
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
|
||||
.dtype_format(DataType.I16_NHWC, DataType.I16_NHWC) \
|
||||
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
|
||||
.dtype_format(DataType.U16_NHWC, DataType.U16_NHWC) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I32_NHWC, DataType.I32_NHWC) \
|
||||
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
|
||||
.dtype_format(DataType.U32_NHWC, DataType.U32_NHWC) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.I64_NHWC, DataType.I64_NHWC) \
|
||||
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
|
||||
.dtype_format(DataType.U64_NHWC, DataType.U64_NHWC) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_NHWC, DataType.F16_NHWC) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_NHWC, DataType.F32_NHWC) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "Split",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "ELEMWISE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "split_d.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "split_d",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "axis",
|
||||
"param_type": "required",
|
||||
"type": "int",
|
||||
"value": "all"
|
||||
},
|
||||
{
|
||||
"name": "output_num",
|
||||
"param_type": "required",
|
||||
"type": "int",
|
||||
"value": "all"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float16","float32", "float32", "int32", "int32", "int8", "int8",
|
||||
"int16", "int16", "int64", "int64", "uint8", "uint8", "uint16", "uint16",
|
||||
"uint32", "uint32", "uint64", "uint64", "bool", "bool"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "NHWC", "DefaultFormat", "NHWC", "DefaultFormat", "NHWC", "DefaultFormat", "NHWC"
|
||||
, "DefaultFormat", "NHWC", "DefaultFormat", "NHWC", "DefaultFormat", "NHWC", "DefaultFormat", "NHWC"
|
||||
, "DefaultFormat", "NHWC", "DefaultFormat", "NHWC", "DefaultFormat", "NHWC"
|
||||
],
|
||||
"name": "value",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float16","float32", "float32", "int32", "int32", "int8", "int8",
|
||||
"int16", "int16", "int64", "int64", "uint8", "uint8", "uint16", "uint16",
|
||||
"uint32", "uint32", "uint64", "uint64", "bool", "bool"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "NHWC", "DefaultFormat", "NHWC", "DefaultFormat", "NHWC", "DefaultFormat", "NHWC"
|
||||
, "DefaultFormat", "NHWC", "DefaultFormat", "NHWC", "DefaultFormat", "NHWC", "DefaultFormat", "NHWC"
|
||||
, "DefaultFormat", "NHWC", "DefaultFormat", "NHWC", "DefaultFormat", "NHWC"
|
||||
],
|
||||
"name": "output",
|
||||
"need_compile": false,
|
||||
"param_type": "dynamic",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(split_d_op_info)
|
||||
def _split_d_tbe():
|
||||
"""Add TBE register"""
|
||||
return
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue