add example for reg ops

This commit is contained in:
lianliguang 2021-07-07 11:41:14 +08:00
parent 7ff4909f61
commit 2c836f0dec
1 changed files with 199 additions and 5 deletions

View File

@ -15,10 +15,12 @@
"""Operators info register."""
import os
import json
import inspect
import json
import os
from mindspore._c_expression import Oplib
from mindspore._checkparam import Validator as validator
# path of built-in op info register.
@ -37,9 +39,32 @@ def op_info_register(op_info):
Args:
op_info (str or dict): operator information in json format.
Examples:
>>> from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
>>> abs_op_info = TBERegOp("Abs") \
... .fusion_type("ELEMWISE") \
... .async_flag(False) \
... .binfile_name("abs.so") \
... .compute_cost(10) \
... .kernel_name("abs") \
... .partial_flag(True) \
... .op_pattern("formatAgnostic") \
... .input(0, "x", None, "required", None) \
... .output(0, "y", True, "required", "all") \
... .dtype_format(DataType.F16_None, DataType.F16_None) \
... .dtype_format(DataType.F32_None, DataType.F32_None) \
... .dtype_format(DataType.I32_None, DataType.I32_None) \
... .get_op_info()
>>>
>>> @op_info_register(abs_op_info)
... def _abs_tbe():
... return
...
Returns:
Function, returns a decorator for op info register.
"""
def register_decorator(func):
if isinstance(op_info, dict):
op_info_real = json.dumps(op_info)
@ -58,7 +83,9 @@ def op_info_register(op_info):
def wrapped_function(*args, **kwargs):
return func(*args, **kwargs)
return wrapped_function
return register_decorator
@ -329,12 +356,14 @@ class AkgRegOp(RegOp):
class AkgGpuRegOp(AkgRegOp):
"""Class for AkgGpu op info register"""
def __init__(self, op_name):
super(AkgGpuRegOp, self).__init__(op_name, "CUDA")
class AkgAscendRegOp(AkgRegOp):
"""Class for AkgAscend op info register"""
def __init__(self, op_name):
super(AkgAscendRegOp, self).__init__(op_name, "AiCore")
@ -348,7 +377,12 @@ class AiCPURegOp(CpuRegOp):
class TBERegOp(RegOp):
"""Class for TBE operator information register."""
"""
Class for TBE operator information register.
Args:
op_name (string):kernel name.
"""
def __init__(self, op_name):
super(TBERegOp, self).__init__(op_name)
@ -535,9 +569,169 @@ class TBERegOp(RegOp):
class DataType:
"""
Various combinations of dtype and format.
Ascend ops various combinations of dtype and format.
The current list below may be incomplete. Please add it if necessary.
The current list below may be incomplete.
Please add it if necessary.
current support:
None_None = ("", "")
None_Default = ("", "DefaultFormat")
BOOL_None = ("bool", "")
BOOL_Default = ("bool", "DefaultFormat")
BOOL_5HD = ("bool", "NC1HWC0")
BOOL_FracZ = ("bool", "FracZ")
BOOL_FracNZ = ("bool", "FRACTAL_NZ")
BOOL_C1HWNCoC0 = ("bool", "C1HWNCoC0")
BOOL_NCHW = ("bool", "NCHW")
BOOL_NHWC = ("bool", "NHWC")
BOOL_HWCN = ("bool", "HWCN")
BOOL_NDHWC = ("bool", "NDHWC")
BOOL_ChannelLast = ("bool", "ChannelLast")
I8_None = ("int8", "")
I8_Default = ("int8", "DefaultFormat")
I8_5HD = ("int8", "NC1HWC0")
I8_FracZ = ("int8", "FracZ")
I8_FracNZ = ("int8", "FRACTAL_NZ")
I8_C1HWNCoC0 = ("int8", "C1HWNCoC0")
I8_NCHW = ("int8", "NCHW")
I8_NHWC = ("int8", "NHWC")
I8_HWCN = ("int8", "HWCN")
I8_NDHWC = ("int8", "NDHWC")
I8_ChannelLast = ("int8", "ChannelLast")
U8_None = ("uint8", "")
U8_Default = ("uint8", "DefaultFormat")
U8_5HD = ("uint8", "NC1HWC0")
U8_FracZ = ("uint8", "FracZ")
U8_FracNZ = ("uint8", "FRACTAL_NZ")
U8_C1HWNCoC0 = ("uint8", "C1HWNCoC0")
U8_NCHW = ("uint8", "NCHW")
U8_NHWC = ("uint8", "NHWC")
U8_HWCN = ("uint8", "HWCN")
U8_NDHWC = ("uint8", "NDHWC")
U8_ChannelLast = ("uint8", "ChannelLast")
I16_None = ("int16", "")
I16_Default = ("int16", "DefaultFormat")
I16_5HD = ("int16", "NC1HWC0")
I16_FracZ = ("int16", "FracZ")
I16_FracNZ = ("int16", "FRACTAL_NZ")
I16_C1HWNCoC0 = ("int16", "C1HWNCoC0")
I16_NCHW = ("int16", "NCHW")
I16_NHWC = ("int16", "NHWC")
I16_HWCN = ("int16", "HWCN")
I16_NDHWC = ("int16", "NDHWC")
I16_ChannelLast = ("int16", "ChannelLast")
U16_None = ("uint16", "")
U16_Default = ("uint16", "DefaultFormat")
U16_5HD = ("uint16", "NC1HWC0")
U16_FracZ = ("uint16", "FracZ")
U16_FracNZ = ("uint16", "FRACTAL_NZ")
U16_C1HWNCoC0 = ("uint16", "C1HWNCoC0")
U16_NCHW = ("uint16", "NCHW")
U16_NHWC = ("uint16", "NHWC")
U16_HWCN = ("uint16", "HWCN")
U16_NDHWC = ("uint16", "NDHWC")
U16_ChannelLast = ("uint16", "ChannelLast")
I32_None = ("int32", "")
I32_Default = ("int32", "DefaultFormat")
I32_5HD = ("int32", "NC1HWC0")
I32_FracZ = ("int32", "FracZ")
I32_FracNZ = ("int32", "FRACTAL_NZ")
I32_C1HWNCoC0 = ("int32", "C1HWNCoC0")
I32_NCHW = ("int32", "NCHW")
I32_NHWC = ("int32", "NHWC")
I32_HWCN = ("int32", "HWCN")
I32_NDHWC = ("int32", "NDHWC")
I32_ChannelLast = ("int32", "ChannelLast")
U32_None = ("uint32", "")
U32_Default = ("uint32", "DefaultFormat")
U32_5HD = ("uint32", "NC1HWC0")
U32_FracZ = ("uint32", "FracZ")
U32_FracNZ = ("uint32", "FRACTAL_NZ")
U32_C1HWNCoC0 = ("uint32", "C1HWNCoC0")
U32_NCHW = ("uint32", "NCHW")
U32_NHWC = ("uint32", "NHWC")
U32_HWCN = ("uint32", "HWCN")
U32_NDHWC = ("uint32", "NDHWC")
U32_ChannelLast = ("uint32", "ChannelLast")
I64_None = ("int64", "")
I64_Default = ("int64", "DefaultFormat")
I64_5HD = ("int64", "NC1HWC0")
I64_FracZ = ("int64", "FracZ")
I64_FracNZ = ("int64", "FRACTAL_NZ")
I64_C1HWNCoC0 = ("int64", "C1HWNCoC0")
I64_NCHW = ("int64", "NCHW")
I64_NHWC = ("int64", "NHWC")
I64_HWCN = ("int64", "HWCN")
I64_NDHWC = ("int64", "NDHWC")
I64_ChannelLast = ("int64", "ChannelLast")
U64_None = ("uint64", "")
U64_Default = ("uint64", "DefaultFormat")
U64_5HD = ("uint64", "NC1HWC0")
U64_FracZ = ("uint64", "FracZ")
U64_FracNZ = ("uint64", "FRACTAL_NZ")
U64_C1HWNCoC0 = ("uint64", "C1HWNCoC0")
U64_NCHW = ("uint64", "NCHW")
U64_NHWC = ("uint64", "NHWC")
U64_HWCN = ("uint64", "HWCN")
U64_NDHWC = ("uint64", "NDHWC")
U64_ChannelLast = ("uint64", "ChannelLast")
F16_None = ("float16", "")
F16_Default = ("float16", "DefaultFormat")
F16_5HD = ("float16", "NC1HWC0")
F16_FracZ = ("float16", "FracZ")
F16_FracNZ = ("float16", "FRACTAL_NZ")
F16_C1HWNCoC0 = ("float16", "C1HWNCoC0")
F16_NCHW = ("float16", "NCHW")
F16_NHWC = ("float16", "NHWC")
F16_HWCN = ("float16", "HWCN")
F16_NDHWC = ("float16", "NDHWC")
F16_NCDHW = ("float16", "NCDHW")
F16_DHWCN = ("float16", "DHWCN")
F16_NDC1HWC0 = ("float16", "NDC1HWC0")
F16_FRACTAL_Z_3D = ("float16", "FRACTAL_Z_3D")
F16_FracZNLSTM = ("float16", "FRACTAL_ZN_LSTM")
F16_ChannelLast = ("float16", "ChannelLast")
F32_None = ("float32", "")
F32_Default = ("float32", "DefaultFormat")
F32_5HD = ("float32", "NC1HWC0")
F32_FracZ = ("float32", "FracZ")
F32_FracNZ = ("float32", "FRACTAL_NZ")
F32_C1HWNCoC0 = ("float32", "C1HWNCoC0")
F32_NCHW = ("float32", "NCHW")
F32_NHWC = ("float32", "NHWC")
F32_HWCN = ("float32", "HWCN")
F32_NDHWC = ("float32", "NDHWC")
F32_NCDHW = ("float32", "NCDHW")
F32_DHWCN = ("float32", "DHWCN")
F32_NDC1HWC0 = ("float32", "NDC1HWC0")
F32_FRACTAL_Z_3D = ("float32", "FRACTAL_Z_3D")
F32_FracZNLSTM = ("float32", "FRACTAL_ZN_LSTM")
F32_ChannelLast = ("float32", "ChannelLast")
F64_None = ("float64", "")
F64_Default = ("float64", "DefaultFormat")
F64_5HD = ("float64", "NC1HWC0")
F64_FracZ = ("float64", "FracZ")
F64_FracNZ = ("float64", "FRACTAL_NZ")
F64_C1HWNCoC0 = ("float64", "C1HWNCoC0")
F64_NCHW = ("float64", "NCHW")
F64_NHWC = ("float64", "NHWC")
F64_HWCN = ("float64", "HWCN")
F64_NDHWC = ("float64", "NDHWC")
F64_ChannelLast = ("float64", "ChannelLast")
"""
None_None = ("", "")