forked from mindspore-Ecosystem/mindspore
add example for reg ops
This commit is contained in:
parent
7ff4909f61
commit
2c836f0dec
|
@ -15,10 +15,12 @@
|
||||||
|
|
||||||
"""Operators info register."""
|
"""Operators info register."""
|
||||||
|
|
||||||
import os
|
|
||||||
import json
|
|
||||||
import inspect
|
import inspect
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
from mindspore._c_expression import Oplib
|
from mindspore._c_expression import Oplib
|
||||||
|
|
||||||
from mindspore._checkparam import Validator as validator
|
from mindspore._checkparam import Validator as validator
|
||||||
|
|
||||||
# path of built-in op info register.
|
# path of built-in op info register.
|
||||||
|
@ -37,9 +39,32 @@ def op_info_register(op_info):
|
||||||
Args:
|
Args:
|
||||||
op_info (str or dict): operator information in json format.
|
op_info (str or dict): operator information in json format.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||||
|
>>> abs_op_info = TBERegOp("Abs") \
|
||||||
|
... .fusion_type("ELEMWISE") \
|
||||||
|
... .async_flag(False) \
|
||||||
|
... .binfile_name("abs.so") \
|
||||||
|
... .compute_cost(10) \
|
||||||
|
... .kernel_name("abs") \
|
||||||
|
... .partial_flag(True) \
|
||||||
|
... .op_pattern("formatAgnostic") \
|
||||||
|
... .input(0, "x", None, "required", None) \
|
||||||
|
... .output(0, "y", True, "required", "all") \
|
||||||
|
... .dtype_format(DataType.F16_None, DataType.F16_None) \
|
||||||
|
... .dtype_format(DataType.F32_None, DataType.F32_None) \
|
||||||
|
... .dtype_format(DataType.I32_None, DataType.I32_None) \
|
||||||
|
... .get_op_info()
|
||||||
|
>>>
|
||||||
|
>>> @op_info_register(abs_op_info)
|
||||||
|
... def _abs_tbe():
|
||||||
|
... return
|
||||||
|
...
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Function, returns a decorator for op info register.
|
Function, returns a decorator for op info register.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def register_decorator(func):
|
def register_decorator(func):
|
||||||
if isinstance(op_info, dict):
|
if isinstance(op_info, dict):
|
||||||
op_info_real = json.dumps(op_info)
|
op_info_real = json.dumps(op_info)
|
||||||
|
@ -58,7 +83,9 @@ def op_info_register(op_info):
|
||||||
|
|
||||||
def wrapped_function(*args, **kwargs):
|
def wrapped_function(*args, **kwargs):
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
return wrapped_function
|
return wrapped_function
|
||||||
|
|
||||||
return register_decorator
|
return register_decorator
|
||||||
|
|
||||||
|
|
||||||
|
@ -329,12 +356,14 @@ class AkgRegOp(RegOp):
|
||||||
|
|
||||||
class AkgGpuRegOp(AkgRegOp):
|
class AkgGpuRegOp(AkgRegOp):
|
||||||
"""Class for AkgGpu op info register"""
|
"""Class for AkgGpu op info register"""
|
||||||
|
|
||||||
def __init__(self, op_name):
|
def __init__(self, op_name):
|
||||||
super(AkgGpuRegOp, self).__init__(op_name, "CUDA")
|
super(AkgGpuRegOp, self).__init__(op_name, "CUDA")
|
||||||
|
|
||||||
|
|
||||||
class AkgAscendRegOp(AkgRegOp):
|
class AkgAscendRegOp(AkgRegOp):
|
||||||
"""Class for AkgAscend op info register"""
|
"""Class for AkgAscend op info register"""
|
||||||
|
|
||||||
def __init__(self, op_name):
|
def __init__(self, op_name):
|
||||||
super(AkgAscendRegOp, self).__init__(op_name, "AiCore")
|
super(AkgAscendRegOp, self).__init__(op_name, "AiCore")
|
||||||
|
|
||||||
|
@ -348,7 +377,12 @@ class AiCPURegOp(CpuRegOp):
|
||||||
|
|
||||||
|
|
||||||
class TBERegOp(RegOp):
|
class TBERegOp(RegOp):
|
||||||
"""Class for TBE operator information register."""
|
"""
|
||||||
|
Class for TBE operator information register.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
op_name (string):kernel name.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, op_name):
|
def __init__(self, op_name):
|
||||||
super(TBERegOp, self).__init__(op_name)
|
super(TBERegOp, self).__init__(op_name)
|
||||||
|
@ -535,9 +569,169 @@ class TBERegOp(RegOp):
|
||||||
|
|
||||||
class DataType:
|
class DataType:
|
||||||
"""
|
"""
|
||||||
Various combinations of dtype and format.
|
Ascend ops various combinations of dtype and format.
|
||||||
|
|
||||||
The current list below may be incomplete. Please add it if necessary.
|
The current list below may be incomplete.
|
||||||
|
|
||||||
|
Please add it if necessary.
|
||||||
|
|
||||||
|
current support:
|
||||||
|
|
||||||
|
None_None = ("", "")
|
||||||
|
None_Default = ("", "DefaultFormat")
|
||||||
|
BOOL_None = ("bool", "")
|
||||||
|
BOOL_Default = ("bool", "DefaultFormat")
|
||||||
|
BOOL_5HD = ("bool", "NC1HWC0")
|
||||||
|
BOOL_FracZ = ("bool", "FracZ")
|
||||||
|
BOOL_FracNZ = ("bool", "FRACTAL_NZ")
|
||||||
|
BOOL_C1HWNCoC0 = ("bool", "C1HWNCoC0")
|
||||||
|
BOOL_NCHW = ("bool", "NCHW")
|
||||||
|
BOOL_NHWC = ("bool", "NHWC")
|
||||||
|
BOOL_HWCN = ("bool", "HWCN")
|
||||||
|
BOOL_NDHWC = ("bool", "NDHWC")
|
||||||
|
BOOL_ChannelLast = ("bool", "ChannelLast")
|
||||||
|
|
||||||
|
I8_None = ("int8", "")
|
||||||
|
I8_Default = ("int8", "DefaultFormat")
|
||||||
|
I8_5HD = ("int8", "NC1HWC0")
|
||||||
|
I8_FracZ = ("int8", "FracZ")
|
||||||
|
I8_FracNZ = ("int8", "FRACTAL_NZ")
|
||||||
|
I8_C1HWNCoC0 = ("int8", "C1HWNCoC0")
|
||||||
|
I8_NCHW = ("int8", "NCHW")
|
||||||
|
I8_NHWC = ("int8", "NHWC")
|
||||||
|
I8_HWCN = ("int8", "HWCN")
|
||||||
|
I8_NDHWC = ("int8", "NDHWC")
|
||||||
|
I8_ChannelLast = ("int8", "ChannelLast")
|
||||||
|
|
||||||
|
U8_None = ("uint8", "")
|
||||||
|
U8_Default = ("uint8", "DefaultFormat")
|
||||||
|
U8_5HD = ("uint8", "NC1HWC0")
|
||||||
|
U8_FracZ = ("uint8", "FracZ")
|
||||||
|
U8_FracNZ = ("uint8", "FRACTAL_NZ")
|
||||||
|
U8_C1HWNCoC0 = ("uint8", "C1HWNCoC0")
|
||||||
|
U8_NCHW = ("uint8", "NCHW")
|
||||||
|
U8_NHWC = ("uint8", "NHWC")
|
||||||
|
U8_HWCN = ("uint8", "HWCN")
|
||||||
|
U8_NDHWC = ("uint8", "NDHWC")
|
||||||
|
U8_ChannelLast = ("uint8", "ChannelLast")
|
||||||
|
|
||||||
|
I16_None = ("int16", "")
|
||||||
|
I16_Default = ("int16", "DefaultFormat")
|
||||||
|
I16_5HD = ("int16", "NC1HWC0")
|
||||||
|
I16_FracZ = ("int16", "FracZ")
|
||||||
|
I16_FracNZ = ("int16", "FRACTAL_NZ")
|
||||||
|
I16_C1HWNCoC0 = ("int16", "C1HWNCoC0")
|
||||||
|
I16_NCHW = ("int16", "NCHW")
|
||||||
|
I16_NHWC = ("int16", "NHWC")
|
||||||
|
I16_HWCN = ("int16", "HWCN")
|
||||||
|
I16_NDHWC = ("int16", "NDHWC")
|
||||||
|
I16_ChannelLast = ("int16", "ChannelLast")
|
||||||
|
|
||||||
|
U16_None = ("uint16", "")
|
||||||
|
U16_Default = ("uint16", "DefaultFormat")
|
||||||
|
U16_5HD = ("uint16", "NC1HWC0")
|
||||||
|
U16_FracZ = ("uint16", "FracZ")
|
||||||
|
U16_FracNZ = ("uint16", "FRACTAL_NZ")
|
||||||
|
U16_C1HWNCoC0 = ("uint16", "C1HWNCoC0")
|
||||||
|
U16_NCHW = ("uint16", "NCHW")
|
||||||
|
U16_NHWC = ("uint16", "NHWC")
|
||||||
|
U16_HWCN = ("uint16", "HWCN")
|
||||||
|
U16_NDHWC = ("uint16", "NDHWC")
|
||||||
|
U16_ChannelLast = ("uint16", "ChannelLast")
|
||||||
|
|
||||||
|
I32_None = ("int32", "")
|
||||||
|
I32_Default = ("int32", "DefaultFormat")
|
||||||
|
I32_5HD = ("int32", "NC1HWC0")
|
||||||
|
I32_FracZ = ("int32", "FracZ")
|
||||||
|
I32_FracNZ = ("int32", "FRACTAL_NZ")
|
||||||
|
I32_C1HWNCoC0 = ("int32", "C1HWNCoC0")
|
||||||
|
I32_NCHW = ("int32", "NCHW")
|
||||||
|
I32_NHWC = ("int32", "NHWC")
|
||||||
|
I32_HWCN = ("int32", "HWCN")
|
||||||
|
I32_NDHWC = ("int32", "NDHWC")
|
||||||
|
I32_ChannelLast = ("int32", "ChannelLast")
|
||||||
|
|
||||||
|
U32_None = ("uint32", "")
|
||||||
|
U32_Default = ("uint32", "DefaultFormat")
|
||||||
|
U32_5HD = ("uint32", "NC1HWC0")
|
||||||
|
U32_FracZ = ("uint32", "FracZ")
|
||||||
|
U32_FracNZ = ("uint32", "FRACTAL_NZ")
|
||||||
|
U32_C1HWNCoC0 = ("uint32", "C1HWNCoC0")
|
||||||
|
U32_NCHW = ("uint32", "NCHW")
|
||||||
|
U32_NHWC = ("uint32", "NHWC")
|
||||||
|
U32_HWCN = ("uint32", "HWCN")
|
||||||
|
U32_NDHWC = ("uint32", "NDHWC")
|
||||||
|
U32_ChannelLast = ("uint32", "ChannelLast")
|
||||||
|
|
||||||
|
I64_None = ("int64", "")
|
||||||
|
I64_Default = ("int64", "DefaultFormat")
|
||||||
|
I64_5HD = ("int64", "NC1HWC0")
|
||||||
|
I64_FracZ = ("int64", "FracZ")
|
||||||
|
I64_FracNZ = ("int64", "FRACTAL_NZ")
|
||||||
|
I64_C1HWNCoC0 = ("int64", "C1HWNCoC0")
|
||||||
|
I64_NCHW = ("int64", "NCHW")
|
||||||
|
I64_NHWC = ("int64", "NHWC")
|
||||||
|
I64_HWCN = ("int64", "HWCN")
|
||||||
|
I64_NDHWC = ("int64", "NDHWC")
|
||||||
|
I64_ChannelLast = ("int64", "ChannelLast")
|
||||||
|
|
||||||
|
U64_None = ("uint64", "")
|
||||||
|
U64_Default = ("uint64", "DefaultFormat")
|
||||||
|
U64_5HD = ("uint64", "NC1HWC0")
|
||||||
|
U64_FracZ = ("uint64", "FracZ")
|
||||||
|
U64_FracNZ = ("uint64", "FRACTAL_NZ")
|
||||||
|
U64_C1HWNCoC0 = ("uint64", "C1HWNCoC0")
|
||||||
|
U64_NCHW = ("uint64", "NCHW")
|
||||||
|
U64_NHWC = ("uint64", "NHWC")
|
||||||
|
U64_HWCN = ("uint64", "HWCN")
|
||||||
|
U64_NDHWC = ("uint64", "NDHWC")
|
||||||
|
U64_ChannelLast = ("uint64", "ChannelLast")
|
||||||
|
|
||||||
|
F16_None = ("float16", "")
|
||||||
|
F16_Default = ("float16", "DefaultFormat")
|
||||||
|
F16_5HD = ("float16", "NC1HWC0")
|
||||||
|
F16_FracZ = ("float16", "FracZ")
|
||||||
|
F16_FracNZ = ("float16", "FRACTAL_NZ")
|
||||||
|
F16_C1HWNCoC0 = ("float16", "C1HWNCoC0")
|
||||||
|
F16_NCHW = ("float16", "NCHW")
|
||||||
|
F16_NHWC = ("float16", "NHWC")
|
||||||
|
F16_HWCN = ("float16", "HWCN")
|
||||||
|
F16_NDHWC = ("float16", "NDHWC")
|
||||||
|
F16_NCDHW = ("float16", "NCDHW")
|
||||||
|
F16_DHWCN = ("float16", "DHWCN")
|
||||||
|
F16_NDC1HWC0 = ("float16", "NDC1HWC0")
|
||||||
|
F16_FRACTAL_Z_3D = ("float16", "FRACTAL_Z_3D")
|
||||||
|
F16_FracZNLSTM = ("float16", "FRACTAL_ZN_LSTM")
|
||||||
|
F16_ChannelLast = ("float16", "ChannelLast")
|
||||||
|
|
||||||
|
F32_None = ("float32", "")
|
||||||
|
F32_Default = ("float32", "DefaultFormat")
|
||||||
|
F32_5HD = ("float32", "NC1HWC0")
|
||||||
|
F32_FracZ = ("float32", "FracZ")
|
||||||
|
F32_FracNZ = ("float32", "FRACTAL_NZ")
|
||||||
|
F32_C1HWNCoC0 = ("float32", "C1HWNCoC0")
|
||||||
|
F32_NCHW = ("float32", "NCHW")
|
||||||
|
F32_NHWC = ("float32", "NHWC")
|
||||||
|
F32_HWCN = ("float32", "HWCN")
|
||||||
|
F32_NDHWC = ("float32", "NDHWC")
|
||||||
|
F32_NCDHW = ("float32", "NCDHW")
|
||||||
|
F32_DHWCN = ("float32", "DHWCN")
|
||||||
|
F32_NDC1HWC0 = ("float32", "NDC1HWC0")
|
||||||
|
F32_FRACTAL_Z_3D = ("float32", "FRACTAL_Z_3D")
|
||||||
|
F32_FracZNLSTM = ("float32", "FRACTAL_ZN_LSTM")
|
||||||
|
F32_ChannelLast = ("float32", "ChannelLast")
|
||||||
|
|
||||||
|
F64_None = ("float64", "")
|
||||||
|
F64_Default = ("float64", "DefaultFormat")
|
||||||
|
F64_5HD = ("float64", "NC1HWC0")
|
||||||
|
F64_FracZ = ("float64", "FracZ")
|
||||||
|
F64_FracNZ = ("float64", "FRACTAL_NZ")
|
||||||
|
F64_C1HWNCoC0 = ("float64", "C1HWNCoC0")
|
||||||
|
F64_NCHW = ("float64", "NCHW")
|
||||||
|
F64_NHWC = ("float64", "NHWC")
|
||||||
|
F64_HWCN = ("float64", "HWCN")
|
||||||
|
F64_NDHWC = ("float64", "NDHWC")
|
||||||
|
F64_ChannelLast = ("float64", "ChannelLast")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
None_None = ("", "")
|
None_None = ("", "")
|
||||||
|
|
Loading…
Reference in New Issue