!116 Optimize operator information register

Merge pull request !116 from zjun/master
mindspore-ci-bot 2020-04-03 16:43:47 +08:00 committed by Gitee
commit d4e51c8f6e
6 changed files with 536 additions and 216 deletions

View File

@@ -61,6 +61,7 @@ class OpIOInfo {
std::string name() const { return name_; }
bool need_compile() const { return need_compile_; }
std::string param_type() const { return param_type_; }
std::string reshape_type() const { return reshape_type_; }
std::string shape() const { return shape_; }
std::vector<std::string> dtypes() const { return dtypes_; }
std::vector<std::string> formats() const { return formats_; }
@@ -69,6 +70,7 @@ class OpIOInfo {
void set_name(const std::string& name) { name_ = name; }
void set_need_compile(const bool need_compile) { need_compile_ = need_compile; }
void set_param_type(const std::string& param_type) { param_type_ = param_type; }
void set_reshape_type(const std::string& reshape_type) { reshape_type_ = reshape_type; }
void set_shape(const std::string& shape) { shape_ = shape; }
void set_dtypes(const std::vector<std::string>& dtype) { dtypes_ = dtype; }
void set_formats(const std::vector<std::string>& formats) { formats_ = formats; }
@@ -78,6 +80,7 @@ class OpIOInfo {
std::string name_;
bool need_compile_ = false;
std::string param_type_;
std::string reshape_type_;
std::string shape_;
std::vector<std::string> dtypes_;
std::vector<std::string> formats_;
@@ -96,6 +99,8 @@ class OpInfo {
int compute_cost() const { return compute_cost_; }
std::string kernel_name() const { return kernel_name_; }
bool partial_flag() const { return partial_flag_; }
bool dynamic_format() const { return dynamic_format_; }
std::string op_pattern() const { return op_pattern_; }
std::vector<std::shared_ptr<OpAttr>> attrs_ptr() const { return attrs_ptr_; }
std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr() const { return inputs_ptr_; }
std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr() const { return outputs_ptr_; }
@@ -110,6 +115,8 @@ class OpInfo {
void set_compute_cost(const int compute_cost) { compute_cost_ = compute_cost; }
void set_kernel_name(const std::string& kernel_name) { kernel_name_ = kernel_name; }
void set_partial_flag(const bool partial_flag) { partial_flag_ = partial_flag; }
void set_dynamic_format(const bool dynamic_format) { dynamic_format_ = dynamic_format; }
void set_op_pattern(const std::string op_pattern) { op_pattern_ = op_pattern; }
void add_attrs_ptr(const std::shared_ptr<OpAttr>& attr) { attrs_ptr_.push_back(attr); }
void add_inputs_ptr(const std::shared_ptr<OpIOInfo>& input) { inputs_ptr_.push_back(input); }
void add_outputs_ptr(const std::shared_ptr<OpIOInfo>& output) { outputs_ptr_.push_back(output); }
@@ -129,6 +136,8 @@ class OpInfo {
int compute_cost_ = 0;
std::string kernel_name_;
bool partial_flag_ = false;
bool dynamic_format_ = false;
std::string op_pattern_;
std::vector<std::shared_ptr<OpAttr>> attrs_ptr_;
std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr_;
std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr_;

View File

@@ -26,18 +26,22 @@ namespace mindspore {
namespace kernel {
constexpr auto kImplyType = "imply_type";
constexpr auto kOpName = "op_name";
constexpr auto kTbe = "TBE";
constexpr auto kAkg = "akg";
constexpr auto kAutodiff = "AutoDiff";
constexpr auto kFusionType = "fusion_type";
constexpr auto kAsyncFlag = "async_flag";
constexpr auto kBinfileName = "binfile_name";
constexpr auto kComputeCost = "compute_cost";
constexpr auto kKernelName = "kernel_name";
constexpr auto kPartialFlag = "partial_flag";
constexpr auto kReshapeType = "reshape_type";
constexpr auto kOpPattern = "op_pattern";
constexpr auto kDynamicFormat = "dynamic_format";
constexpr auto kDtypeFormat = "dtype_format";
constexpr auto kAttr = "attr";
constexpr auto kIputs = "inputs";
constexpr auto kOutputs = "outputs";
constexpr auto kTbe = "TBE";
constexpr auto kAkg = "akg";
constexpr auto kAutodiff = "AutoDiff";
constexpr auto kName = "name";
constexpr auto kParamType = "param_type";
constexpr auto kDtype = "dtype";
@@ -89,8 +93,8 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI
std::shared_ptr<OpInfo> op_info = std::make_shared<OpInfo>();
MS_EXCEPTION_IF_NULL(op_info);
op_info->set_op_name(obj.at(kOpName));
op_info->set_imply_type(imply_type);
op_info->set_impl_path(impl_path);
op_info->set_imply_type(imply_type);
op_info->set_fusion_type(obj.at(kFusionType));
if (imply_type == kTBE) {
op_info->set_async_flag(obj.at(kAsyncFlag));
@@ -98,6 +102,12 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI
op_info->set_compute_cost(obj.at(kComputeCost));
op_info->set_kernel_name(obj.at(kKernelName));
op_info->set_partial_flag(obj.at(kPartialFlag));
if (obj.find(kOpPattern) != obj.end()) {
op_info->set_op_pattern(obj.at(kOpPattern));
}
if (obj.find(kDynamicFormat) != obj.end()) {
op_info->set_dynamic_format(obj.at(kDynamicFormat));
}
}
auto attrs = obj.at(kAttr);
for (const auto& attr : attrs) {
@@ -106,16 +116,20 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI
return false;
}
}
nlohmann::json dtype_format;
if (obj.find(kDtypeFormat) != obj.end()) {
dtype_format = obj.at(kDtypeFormat);
}
auto inputs = obj.at(kIputs);
for (const auto& input : inputs) {
if (!DecodeInputOutput(input, imply_type, kInput, op_info)) {
if (!DecodeInputOutput(input, imply_type, kInput, op_info, dtype_format)) {
MS_LOG(DEBUG) << "DecodeInputOutput Failed";
return false;
}
}
auto outputs = obj.at(kOutputs);
for (const auto& output : outputs) {
if (!DecodeInputOutput(output, imply_type, kOutput, op_info)) {
if (!DecodeInputOutput(output, imply_type, kOutput, op_info, dtype_format)) {
MS_LOG(DEBUG) << "DecodeInputOutput Failed";
return false;
}
@@ -156,16 +170,42 @@ bool OpLib::DecodeAttr(const nlohmann::json& obj, const OpImplyType imply_type,
return ret;
}
bool OpLib::DecodeDtypeFormat(const nlohmann::json& dtype_format, const std::shared_ptr<OpIOInfo>& op_io,
size_t index) {
bool ret = true;
try {
std::vector<std::string> dtype;
std::vector<std::string> format;
for (const auto& it : dtype_format) {
dtype.emplace_back(it[index][0]);
format.emplace_back(it[index][1]);
}
op_io->set_dtypes(dtype);
op_io->set_formats(format);
} catch (const std::exception& e) {
MS_LOG(ERROR) << "DecodeDtypeFormat falied" << e.what();
ret = false;
}
return ret;
}
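For orientation, a rough sketch of the "dtype_format" layout this decoder walks, written as Python data (the op, io names and values below are illustrative, not taken from this commit): each row is one supported dtype/format combination, and within a row position `index` holds the [dtype, format] pair of the io at that index, counting inputs first and then outputs.

# Hypothetical "dtype_format" entry for an op with two inputs and one output.
dtype_format = [
    [["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]],
    [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]],
]
# DecodeDtypeFormat(dtype_format, op_io, index) effectively collects column `index`:
index = 1  # the second input
dtypes = [row[index][0] for row in dtype_format]   # ["float16", "float32"]
formats = [row[index][1] for row in dtype_format]  # ["DefaultFormat", "NC1HWC0"]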
bool OpLib::DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply_type, const OpIOType io_type,
const std::shared_ptr<OpInfo>& op_info) {
const std::shared_ptr<OpInfo>& op_info, const nlohmann::json& dtype_format) {
bool ret = true;
try {
std::shared_ptr<OpIOInfo> op_io = std::make_shared<OpIOInfo>();
MS_EXCEPTION_IF_NULL(op_io);
op_io->set_index(obj.at(kIndex));
op_io->set_name(obj.at(kName));
op_io->set_dtypes(obj.at(kDtype));
op_io->set_formats(obj.at(kFormat));
if (!dtype_format.empty()) {
if (!DecodeDtypeFormat(dtype_format, op_io, op_info->inputs_ptr().size() + op_info->outputs_ptr().size())) {
MS_LOG(ERROR) << "Decode dtype format failed";
return false;
}
} else {
op_io->set_dtypes(obj.at(kDtype));
op_io->set_formats(obj.at(kFormat));
}
if (op_io->dtypes().size() != op_io->formats().size()) {
MS_LOG(DEBUG) << "op" << op_io->name() << "dtype size:" << op_io->dtypes()
<< "is not equal to format size:" << op_io->formats();
@@ -181,6 +221,9 @@ bool OpLib::DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply
if (obj.find(kShape) != obj.end()) {
op_io->set_shape(obj.at(kShape));
}
if (obj.find(kReshapeType) != obj.end()) {
op_io->set_reshape_type(obj.at(kReshapeType));
}
}
if (io_type == kInput) {

View File

@@ -38,8 +38,10 @@ class OpLib {
static bool DecodeOpInfo(const nlohmann::json& obj, const OpImplyType imply_type, const std::string& impl_path);
static bool DecodeAttr(const nlohmann::json& obj, const OpImplyType imply_type,
const std::shared_ptr<OpInfo>& op_info);
static bool DecodeDtypeFormat(const nlohmann::json& dtype_format, const std::shared_ptr<OpIOInfo>& op_io,
size_t index);
static bool DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply_type, const OpIOType io_type,
const std::shared_ptr<OpInfo>& op_info);
const std::shared_ptr<OpInfo>& op_info, const nlohmann::json& dtype_format);
static bool GetRefInfo(const std::shared_ptr<OpInfo>& op_info);
static bool CheckRepetition(const std::shared_ptr<OpInfo>& op_info);
};

View File

@@ -30,7 +30,7 @@ Note:
from .primitive import Primitive, PrimitiveWithInfer, prim_attr_register
from .vm_impl_registry import get_vm_impl_fn, vm_impl_registry
from .op_info_register import op_info_register
from .op_info_register import op_info_register, TBERegOp, DataType
from .primitive import constexpr
from .._c_expression import signature_rw, signature_kind
@@ -40,6 +40,6 @@ __primitive__ = [
]
__all__ = ["get_vm_impl_fn", "vm_impl_registry",
"op_info_register",
"op_info_register", "TBERegOp", "DataType",
"constexpr"]
__all__.extend(__primitive__)

View File

@@ -14,208 +14,41 @@
# ============================================================================
"""AdamApplyOneWithDecay op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
adam_apply_one_with_decay_op_info = TBERegOp("AdamApplyOneWithDecay") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("adam_apply_one_with_decay.so") \
.compute_cost(10) \
.kernel_name("adam_apply_one_with_decay") \
.partial_flag(True) \
.input(0, "input0", False, "required", "all") \
.input(1, "input1", False, "required", "all") \
.input(2, "input2", False, "required", "all") \
.input(3, "input3", False, "required", "all") \
.input(4, "input4", False, "required", "all") \
.input(5, "mul0_x", False, "required", "all") \
.input(6, "mul1_x", False, "required", "all") \
.input(7, "mul2_x", False, "required", "all") \
.input(8, "mul3_x", False, "required", "all") \
.input(9, "mul4_x", False, "required", "all") \
.input(10, "add2_y", False, "required", "all") \
.output(0, "output0", False, "required", "all") \
.output(1, "output1", False, "required", "all") \
.output(2, "output2", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register("""{ @op_info_register(adam_apply_one_with_decay_op_info)
"op_name": "AdamApplyOneWithDecay",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "adam_apply_one_with_decay.so",
"compute_cost": 10,
"kernel_name": "adam_apply_one_with_decay",
"partial_flag": true,
"attr": [
],
"inputs": [
{
"index": 0,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "input0",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "input1",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 2,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "input2",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 3,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "input3",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 4,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "input4",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 5,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "mul0_x",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 6,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "mul1_x",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 7,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "mul2_x",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 8,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "mul3_x",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 9,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "mul4_x",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 10,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "add2_y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "output0",
"need_compile": true,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "output1",
"need_compile": true,
"param_type": "required",
"shape": "all"
},
{
"index": 2,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "output2",
"need_compile": true,
"param_type": "required",
"shape": "all"
}
]
}""")
def _adam_apply_one_with_decay_tbe():
"""AdamApplyOneWithDecay TBE register"""
return

View File

@@ -16,6 +16,7 @@
"""Operators info register."""
import os
import json
import inspect
from mindspore._c_expression import Oplib
from mindspore._checkparam import ParamValidator as validator
@@ -32,21 +33,453 @@ def op_info_register(op_info):
'op_info' must be a str of json format represent the op info, the op info will be added into oplib.
Args:
op_info (str): op info of json format.
op_info (str or dict): op info of json format.
Returns:
Function, returns a decorator for op info register.
"""
def register_decorator(func):
validator.check_type("op_info", op_info, [str])
if isinstance(op_info, dict):
op_info_real = json.dumps(op_info)
else:
op_info_real = op_info
validator.check_type("op_info", op_info_real, [str])
op_lib = Oplib()
file_path = os.path.realpath(inspect.getfile(func))
# keep the path custom ops implementation.
imply_path = "" if BUILT_IN_OPS_REGISTER_PATH in file_path else file_path
if not op_lib.reg_op(op_info, imply_path):
raise ValueError('Invalid op info {}:\n{}\n'.format(file_path, op_info))
if not op_lib.reg_op(op_info_real, imply_path):
raise ValueError('Invalid op info {}:\n{}\n'.format(file_path, op_info_real))
def wrapped_function(*args, **kwargs):
return func(*args, **kwargs)
return wrapped_function
return register_decorator
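With the decorator now accepting either a JSON string or a dict, the typical pattern is to pass the dict produced by TBERegOp(...).get_op_info() (TBERegOp and DataType are defined below). A minimal sketch for a hypothetical single-input, single-output op; the op name, binfile and dtype/format choices are illustrative only:

my_op_info = TBERegOp("MyOp") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("my_op.so") \
    .compute_cost(10) \
    .kernel_name("my_op") \
    .partial_flag(True) \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()

@op_info_register(my_op_info)
def _my_op_tbe():
    """MyOp TBE register (hypothetical example)."""
    return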
class RegOp():
"""
Base class for op info register.
Args:
op_name (str): Name of op.
inputs (list): Inputs information of the op.
outputs (list): Outputs information of the op.
attr_ (list): Attribute information of the op.
dtype_format_ (list): Dtype and format information of the op.
"""
def __init__(self, op_name=""):
if not isinstance(op_name, str):
raise ValueError("op name value must be string")
if not op_name.strip():
raise ValueError("op name is empty")
self.op_name = op_name
self.inputs = []
self.outputs = []
self.attr_ = []
self.dtype_format_ = []
def is_string(self, value):
"""
Check if the value is a str type.
Args:
value: Parameter to check.
Raises:
TypeError: If the type of value is not a str.
"""
if not isinstance(value, str):
raise TypeError("%s value must be str" % str(value))
def is_int(self, value):
"""
Check if the value is an int.
Args:
value: Parameter to check.
Raises:
TypeError: If the type of value is not a int.
"""
if not isinstance(value, int):
raise TypeError("%s value must be int" % str(value))
def is_bool(self, value):
"""
Check if the value is a bool.
Args:
value: Parameter to check.
Raises:
TypeError: If the type of value is not a bool.
"""
if not isinstance(value, bool):
raise TypeError("%s value must be bool" % str(value))
def dtype_format(self, *args):
"""
Register dtype and format.
Args:
args (tuple): Value of dtype and format.
Raises:
ValueError: If the size of args is not equal to the input size plus the output size.
TypeError: If the type of args is not tuple.
"""
if len(self.inputs) + len(self.outputs) != len(args):
raise ValueError("input size add output size must be equal to detype format size")
dtype_format = []
for arg in args:
if not isinstance(arg, tuple) or len(arg) != 2:
raise ValueError("dtype and format value must be tuple of two elements")
self.is_string(arg[0])
self.is_string(arg[1])
dtype_format.append(arg)
self.dtype_format_.append(tuple(dtype_format))
return self
def get_op_info(self):
"""
Return all registration information for this instance.
The '_' character ending a key is removed here for compatibility with the previous version.
Keys will be unified into an underscored form later.
"""
op_info = {}
for key, value in self.__dict__.items():
if isinstance(key, str) and key.endswith('_'):
op_info[key.rstrip('_')] = value
else:
op_info[key] = value
return op_info
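A small illustration of the key handling above, assuming the TBERegOp subclass defined next (the op name is hypothetical): fields stored with a trailing underscore come out without it, while fields that never had one pass through unchanged.

info = TBERegOp("MyOp").kernel_name("my_op").get_op_info()
assert "kernel_name" in info and "kernel_name_" not in info   # trailing '_' stripped
assert info["op_name"] == "MyOp"                              # stored without '_', passed through
assert info["attr"] == [] and info["dtype_format"] == []      # attr_ / dtype_format_ renamed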
class TBERegOp(RegOp):
"""Class for TBE op info register."""
def __init__(self, op_name=""):
super(TBERegOp, self).__init__(op_name)
self.imply_type = "TBE"
self.fusion_type_ = ''
self.async_flag_ = False
self.binfile_name_ = ''
self.compute_cost_ = 10
self.kernel_name_ = ''
self.partial_flag_ = False
self.reshape_type_ = ''
self.dynamic_format_ = False
self.op_pattern_ = ""
def fusion_type(self, fusion_type):
"""
Register fusion type.
Args:
fusion_type (str): Value of fusion type.
"""
self.is_string(fusion_type)
self.fusion_type_ = fusion_type
return self
def async_flag(self, async_flag):
"""
Register async flag.
Args:
async_flag (bool): Value of async flag.
"""
self.is_bool(async_flag)
self.async_flag_ = async_flag
return self
def binfile_name(self, binfile_name):
"""
Register binfile name.
Args:
binfile_name (str): Name of op binfile.
"""
self.is_string(binfile_name)
self.binfile_name_ = binfile_name
return self
def compute_cost(self, compute_cost):
"""
Register compute cost.
Args:
compute_cost (int): Value of compute cost.
"""
self.is_int(compute_cost)
self.compute_cost_ = compute_cost
return self
def kernel_name(self, kernel_name):
"""
Register kernel name.
Args:
kernel_name (str): Name of op kernel.
"""
self.is_string(kernel_name)
self.kernel_name_ = kernel_name
return self
def partial_flag(self, partial_flag):
"""
Register partial flag.
Args:
partial_flag (bool): Value of partial flag.
"""
self.is_bool(partial_flag)
self.partial_flag_ = partial_flag
return self
def reshape_type(self, reshape_type):
"""
Register reshape type.
Args:
reshape_type (str): Value of reshape type.
"""
self.is_string(reshape_type)
self.reshape_type_ = reshape_type
return self
def dynamic_format(self, dynamic_format):
"""
Register dynamic format.
Args:
dynamic_format (bool): Value of dynamic format.
"""
self.is_bool(dynamic_format)
self.dynamic_format_ = dynamic_format
return self
def op_pattern(self, pattern=None):
"""
Register op pattern information.
Args:
pattern (str): Value of op pattern.
"""
if pattern is not None:
self.is_string(pattern)
self.op_pattern_ = pattern
return self
def attr(self, name=None, param_type=None, value_type=None, value=None, default_value=None, **kwargs):
"""
Register op attribute information.
Args:
name (str): Name of the attribute. Default: None.
param_type (str): Param type of the attribute. Default: None.
value_type (str): Type of the attribute. Default: None.
value (str): Value of the attribute. Default: None.
default_value (str): Default value of attribute. Default: None.
kwargs (dict): Other information for the attribute.
"""
param_list = [name, param_type, value_type, value, default_value]
attr_dict = {}
for index, element in enumerate(param_list):
if element is not None:
self.is_string(element)
if index == 0:
attr_dict["name"] = element
elif index == 1:
attr_dict["param_type"] = element
elif index == 2:
attr_dict["type"] = element
elif index == 3:
attr_dict["value"] = element
elif index == 4:
attr_dict["default_value"] = element
if kwargs:
attr_dict = dict(attr_dict, **kwargs)
self.attr_.append(attr_dict)
return self
def input(self, index=None, name=None, need_compile=None, param_type=None, shape=None, **kwargs):
"""
Register op input information.
Args:
index (int): Order of the input. Default: None.
name (str): Name of the input. Default: None.
need_compile (bool): Whether the input needs to be compiled or not. Default: None.
param_type (str): Type of the input. Default: None.
shape (str): Shape of the input. Default: None.
kwargs (dict): Other information for the input.
"""
param_list = [index, name, need_compile, param_type, shape]
input_dict = {}
for idx, element in enumerate(param_list):
if element is not None:
if idx == 0:
self.is_int(element)
input_dict["index"] = element
elif idx == 1:
self.is_string(element)
input_dict["name"] = element
elif idx == 2:
self.is_bool(element)
input_dict["need_compile"] = element
elif idx == 3:
self.is_string(element)
input_dict["param_type"] = element
elif idx == 4:
self.is_string(element)
input_dict["shape"] = element
if kwargs:
input_dict = dict(input_dict, **kwargs)
self.inputs.append(input_dict)
return self
def output(self, index=None, name=None, need_compile=None, param_type=None, shape=None, **kwargs):
"""
Register op output information.
Args:
index (int): Order of the output. Default: None.
name (str): Name of the output. Default: None.
need_compile (bool): Whether the output needs to be compiled or not. Default: None.
param_type (str): Type of the output. Default: None.
shape (str): Shape of the output. Default: None.
kwargs (dict): Other information for the output.
"""
param_list = [index, name, need_compile, param_type, shape]
output_dict = {}
for idx, element in enumerate(param_list):
if element is not None:
if idx == 0:
self.is_int(element)
output_dict["index"] = element
elif idx == 1:
self.is_string(element)
output_dict["name"] = element
elif idx == 2:
self.is_bool(element)
output_dict["need_compile"] = element
elif idx == 3:
self.is_string(element)
output_dict["param_type"] = element
elif idx == 4:
self.is_string(element)
output_dict["shape"] = element
if kwargs:
output_dict = dict(output_dict, **kwargs)
self.outputs.append(output_dict)
return self
class DataType():
"""
Various combinations of dtype and format.
The current list below may not be complete. If necessary, please add to it.
"""
BOOL_None = ("bool", "")
BOOL_Default = ("bool", "DefaultFormat")
BOOL_5HD = ("bool", "NC1HWC0")
BOOL_NCHW = ("bool", "NCHW")
BOOL_NHWC = ("bool", "NHWC")
BOOL_HWCN = ("bool", "HWCN")
I8_None = ("int8", "")
I8_Default = ("int8", "DefaultFormat")
I8_5HD = ("int8", "NC1HWC0")
I8_FracZ = ("int8", "Fracz")
I8_FracNZ = ("int8", "FRACTAL_NZ")
I8_NCHW = ("int8", "NCHW")
I8_NHWC = ("int8", "NHWC")
I8_HWCN = ("int8", "HWCN")
U8_None = ("uint8", "")
U8_Default = ("uint8", "DefaultFormat")
U8_5HD = ("uint8", "NC1HWC0")
U8_FracZ = ("uint8", "Fracz")
U8_FracNZ = ("uint8", "FRACTAL_NZ")
U8_NCHW = ("uint8", "NCHW")
U8_NHWC = ("uint8", "NHWC")
U8_HWCN = ("uint8", "HWCN")
I16_None = ("int16", "")
I16_Default = ("int16", "DefaultFormat")
I16_5HD = ("int16", "NC1HWC0")
I16_FracZ = ("int16", "Fracz")
I16_FracNZ = ("int16", "FRACTAL_NZ")
I16_NCHW = ("int16", "NCHW")
I16_NHWC = ("int16", "NHWC")
I16_HWCN = ("int16", "HWCN")
U16_None = ("uint16", "")
U16_Default = ("uint16", "DefaultFormat")
U16_5HD = ("uint16", "NC1HWC0")
U16_FracZ = ("uint16", "Fracz")
U16_FracNZ = ("uint16", "FRACTAL_NZ")
U16_NCHW = ("uint16", "NCHW")
U16_NHWC = ("uint16", "NHWC")
U16_HWCN = ("uint16", "HWCN")
I32_None = ("int32", "")
I32_Default = ("int32", "DefaultFormat")
I32_5HD = ("int32", "NC1HWC0")
I32_FracZ = ("int32", "Fracz")
I32_FracNZ = ("int32", "FRACTAL_NZ")
I32_NCHW = ("int32", "NCHW")
I32_NHWC = ("int32", "NHWC")
I32_HWCN = ("int32", "HWCN")
U32_None = ("uint32", "")
U32_Default = ("uint32", "DefaultFormat")
U32_5HD = ("uint32", "NC1HWC0")
U32_FracZ = ("uint32", "Fracz")
U32_FracNZ = ("uint32", "FRACTAL_NZ")
U32_NCHW = ("uint32", "NCHW")
U32_NHWC = ("uint32", "NHWC")
U32_HWCN = ("uint32", "HWCN")
I64_None = ("int64", "")
I64_Default = ("int64", "DefaultFormat")
I64_5HD = ("int64", "NC1HWC0")
I64_FracZ = ("int64", "Fracz")
I64_FracNZ = ("int64", "FRACTAL_NZ")
I64_NCHW = ("int64", "NCHW")
I64_NHWC = ("int64", "NHWC")
I64_HWCN = ("int64", "HWCN")
U64_None = ("uint64", "")
U64_Default = ("uint64", "DefaultFormat")
U64_5HD = ("uint64", "NC1HWC0")
U64_FracZ = ("uint64", "Fracz")
U64_FracNZ = ("uint64", "FRACTAL_NZ")
U64_NCHW = ("uint64", "NCHW")
U64_NHWC = ("uint64", "NHWC")
U64_HWCN = ("uint64", "HWCN")
F16_None = ("float16", "")
F16_Default = ("float16", "DefaultFormat")
F16_5HD = ("float16", "NC1HWC0")
F16_FracZ = ("float16", "Fracz")
F16_FracNZ = ("float16", "FRACTAL_NZ")
F16_C1HWNCoC0 = ("float16", "C1HWNCoC0")
F16_NCHW = ("float16", "NCHW")
F16_NHWC = ("float16", "NHWC")
F16_HWCN = ("float16", "HWCN")
F32_None = ("float32", "")
F32_Default = ("float32", "DefaultFormat")
F32_5HD = ("float32", "NC1HWC0")
F32_FracZ = ("float32", "Fracz")
F32_FracNZ = ("float32", "FRACTAL_NZ")
F32_C1HWNCoC0 = ("float32", "C1HWNCoC0")
F32_NCHW = ("float32", "NCHW")
F32_NHWC = ("float32", "NHWC")
F32_HWCN = ("float32", "HWCN")