add example for reg ops

This commit is contained in:
lianliguang 2021-07-07 11:41:14 +08:00
parent 7ff4909f61
commit 2c836f0dec
1 changed files with 199 additions and 5 deletions

View File

@ -15,10 +15,12 @@
"""Operators info register.""" """Operators info register."""
import os
import json
import inspect import inspect
import json
import os
from mindspore._c_expression import Oplib from mindspore._c_expression import Oplib
from mindspore._checkparam import Validator as validator from mindspore._checkparam import Validator as validator
# path of built-in op info register. # path of built-in op info register.
@ -37,9 +39,32 @@ def op_info_register(op_info):
Args: Args:
op_info (str or dict): operator information in json format. op_info (str or dict): operator information in json format.
Examples:
>>> from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
>>> abs_op_info = TBERegOp("Abs") \
... .fusion_type("ELEMWISE") \
... .async_flag(False) \
... .binfile_name("abs.so") \
... .compute_cost(10) \
... .kernel_name("abs") \
... .partial_flag(True) \
... .op_pattern("formatAgnostic") \
... .input(0, "x", None, "required", None) \
... .output(0, "y", True, "required", "all") \
... .dtype_format(DataType.F16_None, DataType.F16_None) \
... .dtype_format(DataType.F32_None, DataType.F32_None) \
... .dtype_format(DataType.I32_None, DataType.I32_None) \
... .get_op_info()
>>>
>>> @op_info_register(abs_op_info)
... def _abs_tbe():
... return
...
Returns: Returns:
Function, returns a decorator for op info register. Function, returns a decorator for op info register.
""" """
def register_decorator(func): def register_decorator(func):
if isinstance(op_info, dict): if isinstance(op_info, dict):
op_info_real = json.dumps(op_info) op_info_real = json.dumps(op_info)
@ -58,7 +83,9 @@ def op_info_register(op_info):
def wrapped_function(*args, **kwargs): def wrapped_function(*args, **kwargs):
return func(*args, **kwargs) return func(*args, **kwargs)
return wrapped_function return wrapped_function
return register_decorator return register_decorator
@ -329,12 +356,14 @@ class AkgRegOp(RegOp):
class AkgGpuRegOp(AkgRegOp): class AkgGpuRegOp(AkgRegOp):
"""Class for AkgGpu op info register""" """Class for AkgGpu op info register"""
def __init__(self, op_name): def __init__(self, op_name):
super(AkgGpuRegOp, self).__init__(op_name, "CUDA") super(AkgGpuRegOp, self).__init__(op_name, "CUDA")
class AkgAscendRegOp(AkgRegOp): class AkgAscendRegOp(AkgRegOp):
"""Class for AkgAscend op info register""" """Class for AkgAscend op info register"""
def __init__(self, op_name): def __init__(self, op_name):
super(AkgAscendRegOp, self).__init__(op_name, "AiCore") super(AkgAscendRegOp, self).__init__(op_name, "AiCore")
@ -348,7 +377,12 @@ class AiCPURegOp(CpuRegOp):
class TBERegOp(RegOp): class TBERegOp(RegOp):
"""Class for TBE operator information register.""" """
Class for TBE operator information register.
Args:
op_name (string):kernel name.
"""
def __init__(self, op_name): def __init__(self, op_name):
super(TBERegOp, self).__init__(op_name) super(TBERegOp, self).__init__(op_name)
@ -535,9 +569,169 @@ class TBERegOp(RegOp):
class DataType: class DataType:
""" """
Various combinations of dtype and format. Ascend ops various combinations of dtype and format.
The current list below may be incomplete. Please add it if necessary. The current list below may be incomplete.
Please add it if necessary.
current support:
None_None = ("", "")
None_Default = ("", "DefaultFormat")
BOOL_None = ("bool", "")
BOOL_Default = ("bool", "DefaultFormat")
BOOL_5HD = ("bool", "NC1HWC0")
BOOL_FracZ = ("bool", "FracZ")
BOOL_FracNZ = ("bool", "FRACTAL_NZ")
BOOL_C1HWNCoC0 = ("bool", "C1HWNCoC0")
BOOL_NCHW = ("bool", "NCHW")
BOOL_NHWC = ("bool", "NHWC")
BOOL_HWCN = ("bool", "HWCN")
BOOL_NDHWC = ("bool", "NDHWC")
BOOL_ChannelLast = ("bool", "ChannelLast")
I8_None = ("int8", "")
I8_Default = ("int8", "DefaultFormat")
I8_5HD = ("int8", "NC1HWC0")
I8_FracZ = ("int8", "FracZ")
I8_FracNZ = ("int8", "FRACTAL_NZ")
I8_C1HWNCoC0 = ("int8", "C1HWNCoC0")
I8_NCHW = ("int8", "NCHW")
I8_NHWC = ("int8", "NHWC")
I8_HWCN = ("int8", "HWCN")
I8_NDHWC = ("int8", "NDHWC")
I8_ChannelLast = ("int8", "ChannelLast")
U8_None = ("uint8", "")
U8_Default = ("uint8", "DefaultFormat")
U8_5HD = ("uint8", "NC1HWC0")
U8_FracZ = ("uint8", "FracZ")
U8_FracNZ = ("uint8", "FRACTAL_NZ")
U8_C1HWNCoC0 = ("uint8", "C1HWNCoC0")
U8_NCHW = ("uint8", "NCHW")
U8_NHWC = ("uint8", "NHWC")
U8_HWCN = ("uint8", "HWCN")
U8_NDHWC = ("uint8", "NDHWC")
U8_ChannelLast = ("uint8", "ChannelLast")
I16_None = ("int16", "")
I16_Default = ("int16", "DefaultFormat")
I16_5HD = ("int16", "NC1HWC0")
I16_FracZ = ("int16", "FracZ")
I16_FracNZ = ("int16", "FRACTAL_NZ")
I16_C1HWNCoC0 = ("int16", "C1HWNCoC0")
I16_NCHW = ("int16", "NCHW")
I16_NHWC = ("int16", "NHWC")
I16_HWCN = ("int16", "HWCN")
I16_NDHWC = ("int16", "NDHWC")
I16_ChannelLast = ("int16", "ChannelLast")
U16_None = ("uint16", "")
U16_Default = ("uint16", "DefaultFormat")
U16_5HD = ("uint16", "NC1HWC0")
U16_FracZ = ("uint16", "FracZ")
U16_FracNZ = ("uint16", "FRACTAL_NZ")
U16_C1HWNCoC0 = ("uint16", "C1HWNCoC0")
U16_NCHW = ("uint16", "NCHW")
U16_NHWC = ("uint16", "NHWC")
U16_HWCN = ("uint16", "HWCN")
U16_NDHWC = ("uint16", "NDHWC")
U16_ChannelLast = ("uint16", "ChannelLast")
I32_None = ("int32", "")
I32_Default = ("int32", "DefaultFormat")
I32_5HD = ("int32", "NC1HWC0")
I32_FracZ = ("int32", "FracZ")
I32_FracNZ = ("int32", "FRACTAL_NZ")
I32_C1HWNCoC0 = ("int32", "C1HWNCoC0")
I32_NCHW = ("int32", "NCHW")
I32_NHWC = ("int32", "NHWC")
I32_HWCN = ("int32", "HWCN")
I32_NDHWC = ("int32", "NDHWC")
I32_ChannelLast = ("int32", "ChannelLast")
U32_None = ("uint32", "")
U32_Default = ("uint32", "DefaultFormat")
U32_5HD = ("uint32", "NC1HWC0")
U32_FracZ = ("uint32", "FracZ")
U32_FracNZ = ("uint32", "FRACTAL_NZ")
U32_C1HWNCoC0 = ("uint32", "C1HWNCoC0")
U32_NCHW = ("uint32", "NCHW")
U32_NHWC = ("uint32", "NHWC")
U32_HWCN = ("uint32", "HWCN")
U32_NDHWC = ("uint32", "NDHWC")
U32_ChannelLast = ("uint32", "ChannelLast")
I64_None = ("int64", "")
I64_Default = ("int64", "DefaultFormat")
I64_5HD = ("int64", "NC1HWC0")
I64_FracZ = ("int64", "FracZ")
I64_FracNZ = ("int64", "FRACTAL_NZ")
I64_C1HWNCoC0 = ("int64", "C1HWNCoC0")
I64_NCHW = ("int64", "NCHW")
I64_NHWC = ("int64", "NHWC")
I64_HWCN = ("int64", "HWCN")
I64_NDHWC = ("int64", "NDHWC")
I64_ChannelLast = ("int64", "ChannelLast")
U64_None = ("uint64", "")
U64_Default = ("uint64", "DefaultFormat")
U64_5HD = ("uint64", "NC1HWC0")
U64_FracZ = ("uint64", "FracZ")
U64_FracNZ = ("uint64", "FRACTAL_NZ")
U64_C1HWNCoC0 = ("uint64", "C1HWNCoC0")
U64_NCHW = ("uint64", "NCHW")
U64_NHWC = ("uint64", "NHWC")
U64_HWCN = ("uint64", "HWCN")
U64_NDHWC = ("uint64", "NDHWC")
U64_ChannelLast = ("uint64", "ChannelLast")
F16_None = ("float16", "")
F16_Default = ("float16", "DefaultFormat")
F16_5HD = ("float16", "NC1HWC0")
F16_FracZ = ("float16", "FracZ")
F16_FracNZ = ("float16", "FRACTAL_NZ")
F16_C1HWNCoC0 = ("float16", "C1HWNCoC0")
F16_NCHW = ("float16", "NCHW")
F16_NHWC = ("float16", "NHWC")
F16_HWCN = ("float16", "HWCN")
F16_NDHWC = ("float16", "NDHWC")
F16_NCDHW = ("float16", "NCDHW")
F16_DHWCN = ("float16", "DHWCN")
F16_NDC1HWC0 = ("float16", "NDC1HWC0")
F16_FRACTAL_Z_3D = ("float16", "FRACTAL_Z_3D")
F16_FracZNLSTM = ("float16", "FRACTAL_ZN_LSTM")
F16_ChannelLast = ("float16", "ChannelLast")
F32_None = ("float32", "")
F32_Default = ("float32", "DefaultFormat")
F32_5HD = ("float32", "NC1HWC0")
F32_FracZ = ("float32", "FracZ")
F32_FracNZ = ("float32", "FRACTAL_NZ")
F32_C1HWNCoC0 = ("float32", "C1HWNCoC0")
F32_NCHW = ("float32", "NCHW")
F32_NHWC = ("float32", "NHWC")
F32_HWCN = ("float32", "HWCN")
F32_NDHWC = ("float32", "NDHWC")
F32_NCDHW = ("float32", "NCDHW")
F32_DHWCN = ("float32", "DHWCN")
F32_NDC1HWC0 = ("float32", "NDC1HWC0")
F32_FRACTAL_Z_3D = ("float32", "FRACTAL_Z_3D")
F32_FracZNLSTM = ("float32", "FRACTAL_ZN_LSTM")
F32_ChannelLast = ("float32", "ChannelLast")
F64_None = ("float64", "")
F64_Default = ("float64", "DefaultFormat")
F64_5HD = ("float64", "NC1HWC0")
F64_FracZ = ("float64", "FracZ")
F64_FracNZ = ("float64", "FRACTAL_NZ")
F64_C1HWNCoC0 = ("float64", "C1HWNCoC0")
F64_NCHW = ("float64", "NCHW")
F64_NHWC = ("float64", "NHWC")
F64_HWCN = ("float64", "HWCN")
F64_NDHWC = ("float64", "NDHWC")
F64_ChannelLast = ("float64", "ChannelLast")
""" """
None_None = ("", "") None_None = ("", "")