!116 Optimize operator information register
Merge pull request !116 from zjun/master
This commit is contained in:
commit
d4e51c8f6e
|
@ -61,6 +61,7 @@ class OpIOInfo {
|
|||
std::string name() const { return name_; }
|
||||
bool need_compile() const { return need_compile_; }
|
||||
std::string param_type() const { return param_type_; }
|
||||
std::string reshape_type() const { return reshape_type_; }
|
||||
std::string shape() const { return shape_; }
|
||||
std::vector<std::string> dtypes() const { return dtypes_; }
|
||||
std::vector<std::string> formats() const { return formats_; }
|
||||
|
@ -69,6 +70,7 @@ class OpIOInfo {
|
|||
void set_name(const std::string& name) { name_ = name; }
|
||||
void set_need_compile(const bool need_compile) { need_compile_ = need_compile; }
|
||||
void set_param_type(const std::string& param_type) { param_type_ = param_type; }
|
||||
void set_reshape_type(const std::string& reshape_type) { reshape_type_ = reshape_type; }
|
||||
void set_shape(const std::string& shape) { shape_ = shape; }
|
||||
void set_dtypes(const std::vector<std::string>& dtype) { dtypes_ = dtype; }
|
||||
void set_formats(const std::vector<std::string>& formats) { formats_ = formats; }
|
||||
|
@ -78,6 +80,7 @@ class OpIOInfo {
|
|||
std::string name_;
|
||||
bool need_compile_ = false;
|
||||
std::string param_type_;
|
||||
std::string reshape_type_;
|
||||
std::string shape_;
|
||||
std::vector<std::string> dtypes_;
|
||||
std::vector<std::string> formats_;
|
||||
|
@ -96,6 +99,8 @@ class OpInfo {
|
|||
int compute_cost() const { return compute_cost_; }
|
||||
std::string kernel_name() const { return kernel_name_; }
|
||||
bool partial_flag() const { return partial_flag_; }
|
||||
bool dynamic_format() const { return dynamic_format_; }
|
||||
std::string op_pattern() const { return op_pattern_; }
|
||||
std::vector<std::shared_ptr<OpAttr>> attrs_ptr() const { return attrs_ptr_; }
|
||||
std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr() const { return inputs_ptr_; }
|
||||
std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr() const { return outputs_ptr_; }
|
||||
|
@ -110,6 +115,8 @@ class OpInfo {
|
|||
void set_compute_cost(const int compute_cost) { compute_cost_ = compute_cost; }
|
||||
void set_kernel_name(const std::string& kernel_name) { kernel_name_ = kernel_name; }
|
||||
void set_partial_flag(const bool partial_flag) { partial_flag_ = partial_flag; }
|
||||
void set_dynamic_format(const bool dynamic_format) { dynamic_format_ = dynamic_format; }
|
||||
void set_op_pattern(const std::string op_pattern) { op_pattern_ = op_pattern; }
|
||||
void add_attrs_ptr(const std::shared_ptr<OpAttr>& attr) { attrs_ptr_.push_back(attr); }
|
||||
void add_inputs_ptr(const std::shared_ptr<OpIOInfo>& input) { inputs_ptr_.push_back(input); }
|
||||
void add_outputs_ptr(const std::shared_ptr<OpIOInfo>& output) { outputs_ptr_.push_back(output); }
|
||||
|
@ -129,6 +136,8 @@ class OpInfo {
|
|||
int compute_cost_ = 0;
|
||||
std::string kernel_name_;
|
||||
bool partial_flag_ = false;
|
||||
bool dynamic_format_ = false;
|
||||
std::string op_pattern_;
|
||||
std::vector<std::shared_ptr<OpAttr>> attrs_ptr_;
|
||||
std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr_;
|
||||
std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr_;
|
||||
|
|
|
@ -26,18 +26,22 @@ namespace mindspore {
|
|||
namespace kernel {
|
||||
constexpr auto kImplyType = "imply_type";
|
||||
constexpr auto kOpName = "op_name";
|
||||
constexpr auto kTbe = "TBE";
|
||||
constexpr auto kAkg = "akg";
|
||||
constexpr auto kAutodiff = "AutoDiff";
|
||||
constexpr auto kFusionType = "fusion_type";
|
||||
constexpr auto kAsyncFlag = "async_flag";
|
||||
constexpr auto kBinfileName = "binfile_name";
|
||||
constexpr auto kComputeCost = "compute_cost";
|
||||
constexpr auto kKernelName = "kernel_name";
|
||||
constexpr auto kPartialFlag = "partial_flag";
|
||||
constexpr auto kReshapeType = "reshape_type";
|
||||
constexpr auto kOpPattern = "op_pattern";
|
||||
constexpr auto kDynamicFormat = "dynamic_format";
|
||||
constexpr auto kDtypeFormat = "dtype_format";
|
||||
constexpr auto kAttr = "attr";
|
||||
constexpr auto kIputs = "inputs";
|
||||
constexpr auto kOutputs = "outputs";
|
||||
constexpr auto kTbe = "TBE";
|
||||
constexpr auto kAkg = "akg";
|
||||
constexpr auto kAutodiff = "AutoDiff";
|
||||
constexpr auto kName = "name";
|
||||
constexpr auto kParamType = "param_type";
|
||||
constexpr auto kDtype = "dtype";
|
||||
|
@ -89,8 +93,8 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI
|
|||
std::shared_ptr<OpInfo> op_info = std::make_shared<OpInfo>();
|
||||
MS_EXCEPTION_IF_NULL(op_info);
|
||||
op_info->set_op_name(obj.at(kOpName));
|
||||
op_info->set_imply_type(imply_type);
|
||||
op_info->set_impl_path(impl_path);
|
||||
op_info->set_imply_type(imply_type);
|
||||
op_info->set_fusion_type(obj.at(kFusionType));
|
||||
if (imply_type == kTBE) {
|
||||
op_info->set_async_flag(obj.at(kAsyncFlag));
|
||||
|
@ -98,6 +102,12 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI
|
|||
op_info->set_compute_cost(obj.at(kComputeCost));
|
||||
op_info->set_kernel_name(obj.at(kKernelName));
|
||||
op_info->set_partial_flag(obj.at(kPartialFlag));
|
||||
if (obj.find(kOpPattern) != obj.end()) {
|
||||
op_info->set_op_pattern(obj.at(kOpPattern));
|
||||
}
|
||||
if (obj.find(kDynamicFormat) != obj.end()) {
|
||||
op_info->set_dynamic_format(obj.at(kDynamicFormat));
|
||||
}
|
||||
}
|
||||
auto attrs = obj.at(kAttr);
|
||||
for (const auto& attr : attrs) {
|
||||
|
@ -106,16 +116,20 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI
|
|||
return false;
|
||||
}
|
||||
}
|
||||
nlohmann::json dtype_format;
|
||||
if (obj.find(kDtypeFormat) != obj.end()) {
|
||||
dtype_format = obj.at(kDtypeFormat);
|
||||
}
|
||||
auto inputs = obj.at(kIputs);
|
||||
for (const auto& input : inputs) {
|
||||
if (!DecodeInputOutput(input, imply_type, kInput, op_info)) {
|
||||
if (!DecodeInputOutput(input, imply_type, kInput, op_info, dtype_format)) {
|
||||
MS_LOG(DEBUG) << "DecodeInputOutput Failed";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
auto outputs = obj.at(kOutputs);
|
||||
for (const auto& output : outputs) {
|
||||
if (!DecodeInputOutput(output, imply_type, kOutput, op_info)) {
|
||||
if (!DecodeInputOutput(output, imply_type, kOutput, op_info, dtype_format)) {
|
||||
MS_LOG(DEBUG) << "DecodeInputOutput Failed";
|
||||
return false;
|
||||
}
|
||||
|
@ -156,16 +170,42 @@ bool OpLib::DecodeAttr(const nlohmann::json& obj, const OpImplyType imply_type,
|
|||
return ret;
|
||||
}
|
||||
|
||||
bool OpLib::DecodeDtypeFormat(const nlohmann::json& dtype_format, const std::shared_ptr<OpIOInfo>& op_io,
|
||||
size_t index) {
|
||||
bool ret = true;
|
||||
try {
|
||||
std::vector<std::string> dtype;
|
||||
std::vector<std::string> format;
|
||||
for (const auto& it : dtype_format) {
|
||||
dtype.emplace_back(it[index][0]);
|
||||
format.emplace_back(it[index][1]);
|
||||
}
|
||||
op_io->set_dtypes(dtype);
|
||||
op_io->set_formats(format);
|
||||
} catch (const std::exception& e) {
|
||||
MS_LOG(ERROR) << "DecodeDtypeFormat falied" << e.what();
|
||||
ret = false;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
bool OpLib::DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply_type, const OpIOType io_type,
|
||||
const std::shared_ptr<OpInfo>& op_info) {
|
||||
const std::shared_ptr<OpInfo>& op_info, const nlohmann::json& dtype_format) {
|
||||
bool ret = true;
|
||||
try {
|
||||
std::shared_ptr<OpIOInfo> op_io = std::make_shared<OpIOInfo>();
|
||||
MS_EXCEPTION_IF_NULL(op_io);
|
||||
op_io->set_index(obj.at(kIndex));
|
||||
op_io->set_name(obj.at(kName));
|
||||
op_io->set_dtypes(obj.at(kDtype));
|
||||
op_io->set_formats(obj.at(kFormat));
|
||||
if (!dtype_format.empty()) {
|
||||
if (!DecodeDtypeFormat(dtype_format, op_io, op_info->inputs_ptr().size() + op_info->outputs_ptr().size())) {
|
||||
MS_LOG(ERROR) << "Decode dtype format failed";
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
op_io->set_dtypes(obj.at(kDtype));
|
||||
op_io->set_formats(obj.at(kFormat));
|
||||
}
|
||||
if (op_io->dtypes().size() != op_io->formats().size()) {
|
||||
MS_LOG(DEBUG) << "op" << op_io->name() << "dtype size:" << op_io->dtypes()
|
||||
<< "is not equal to format size:" << op_io->formats();
|
||||
|
@ -181,6 +221,9 @@ bool OpLib::DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply
|
|||
if (obj.find(kShape) != obj.end()) {
|
||||
op_io->set_shape(obj.at(kShape));
|
||||
}
|
||||
if (obj.find(kReshapeType) != obj.end()) {
|
||||
op_io->set_reshape_type(obj.at(kReshapeType));
|
||||
}
|
||||
}
|
||||
|
||||
if (io_type == kInput) {
|
||||
|
|
|
@ -38,8 +38,10 @@ class OpLib {
|
|||
static bool DecodeOpInfo(const nlohmann::json& obj, const OpImplyType imply_type, const std::string& impl_path);
|
||||
static bool DecodeAttr(const nlohmann::json& obj, const OpImplyType imply_type,
|
||||
const std::shared_ptr<OpInfo>& op_info);
|
||||
static bool DecodeDtypeFormat(const nlohmann::json& dtype_format, const std::shared_ptr<OpIOInfo>& op_io,
|
||||
size_t index);
|
||||
static bool DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply_type, const OpIOType io_type,
|
||||
const std::shared_ptr<OpInfo>& op_info);
|
||||
const std::shared_ptr<OpInfo>& op_info, const nlohmann::json& dtype_format);
|
||||
static bool GetRefInfo(const std::shared_ptr<OpInfo>& op_info);
|
||||
static bool CheckRepetition(const std::shared_ptr<OpInfo>& op_info);
|
||||
};
|
||||
|
|
|
@ -30,7 +30,7 @@ Note:
|
|||
|
||||
from .primitive import Primitive, PrimitiveWithInfer, prim_attr_register
|
||||
from .vm_impl_registry import get_vm_impl_fn, vm_impl_registry
|
||||
from .op_info_register import op_info_register
|
||||
from .op_info_register import op_info_register, TBERegOp, DataType
|
||||
from .primitive import constexpr
|
||||
from .._c_expression import signature_rw, signature_kind
|
||||
|
||||
|
@ -40,6 +40,6 @@ __primitive__ = [
|
|||
]
|
||||
|
||||
__all__ = ["get_vm_impl_fn", "vm_impl_registry",
|
||||
"op_info_register",
|
||||
"op_info_register", "TBERegOp", "DataType",
|
||||
"constexpr"]
|
||||
__all__.extend(__primitive__)
|
||||
|
|
|
@ -14,208 +14,41 @@
|
|||
# ============================================================================
|
||||
|
||||
"""AdamApplyOneWithDecay op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
adam_apply_one_with_decay_op_info = TBERegOp("AdamApplyOneWithDecay") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("adam_apply_one_with_decay.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("adam_apply_one_with_decay") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "input0", False, "required", "all") \
|
||||
.input(1, "input1", False, "required", "all") \
|
||||
.input(2, "input2", False, "required", "all") \
|
||||
.input(3, "input3", False, "required", "all") \
|
||||
.input(4, "input4", False, "required", "all") \
|
||||
.input(5, "mul0_x", False, "required", "all") \
|
||||
.input(6, "mul1_x", False, "required", "all") \
|
||||
.input(7, "mul2_x", False, "required", "all") \
|
||||
.input(8, "mul3_x", False, "required", "all") \
|
||||
.input(9, "mul4_x", False, "required", "all") \
|
||||
.input(10, "add2_y", False, "required", "all") \
|
||||
.output(0, "output0", False, "required", "all") \
|
||||
.output(1, "output1", False, "required", "all") \
|
||||
.output(2, "output2", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "AdamApplyOneWithDecay",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "adam_apply_one_with_decay.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "adam_apply_one_with_decay",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "input0",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "input1",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"dtype": [
|
||||
"float16", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "input2",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 3,
|
||||
"dtype": [
|
||||
"float16", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "input3",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 4,
|
||||
"dtype": [
|
||||
"float16", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "input4",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 5,
|
||||
"dtype": [
|
||||
"float16", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "mul0_x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 6,
|
||||
"dtype": [
|
||||
"float16", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "mul1_x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 7,
|
||||
"dtype": [
|
||||
"float16", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "mul2_x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 8,
|
||||
"dtype": [
|
||||
"float16", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "mul3_x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 9,
|
||||
"dtype": [
|
||||
"float16", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "mul4_x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 10,
|
||||
"dtype": [
|
||||
"float16", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "add2_y",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "output0",
|
||||
"need_compile": true,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float16", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "output1",
|
||||
"need_compile": true,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"dtype": [
|
||||
"float16", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "output2",
|
||||
"need_compile": true,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
@op_info_register(adam_apply_one_with_decay_op_info)
|
||||
def _adam_apply_one_with_decay_tbe():
|
||||
"""AdamApplyOneWithDecay TBE register"""
|
||||
return
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
"""Operators info register."""
|
||||
|
||||
import os
|
||||
import json
|
||||
import inspect
|
||||
from mindspore._c_expression import Oplib
|
||||
from mindspore._checkparam import ParamValidator as validator
|
||||
|
@ -32,21 +33,453 @@ def op_info_register(op_info):
|
|||
'op_info' must be a str of json format represent the op info, the op info will be added into oplib.
|
||||
|
||||
Args:
|
||||
op_info (str): op info of json format.
|
||||
op_info (str or dict): op info of json format.
|
||||
|
||||
Returns:
|
||||
Function, returns a decorator for op info register.
|
||||
"""
|
||||
def register_decorator(func):
|
||||
validator.check_type("op_info", op_info, [str])
|
||||
if isinstance(op_info, dict):
|
||||
op_info_real = json.dumps(op_info)
|
||||
else:
|
||||
op_info_real = op_info
|
||||
validator.check_type("op_info", op_info_real, [str])
|
||||
op_lib = Oplib()
|
||||
file_path = os.path.realpath(inspect.getfile(func))
|
||||
# keep the path custom ops implementation.
|
||||
imply_path = "" if BUILT_IN_OPS_REGISTER_PATH in file_path else file_path
|
||||
if not op_lib.reg_op(op_info, imply_path):
|
||||
raise ValueError('Invalid op info {}:\n{}\n'.format(file_path, op_info))
|
||||
if not op_lib.reg_op(op_info_real, imply_path):
|
||||
raise ValueError('Invalid op info {}:\n{}\n'.format(file_path, op_info_real))
|
||||
|
||||
def wrapped_function(*args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
return wrapped_function
|
||||
return register_decorator
|
||||
|
||||
|
||||
class RegOp():
|
||||
"""
|
||||
Base class for op info register.
|
||||
|
||||
Args:
|
||||
op_name (str): Name of op.
|
||||
inputs (list): Inputs inoformation of the op.
|
||||
outputs (list): Outputs information of the op.
|
||||
attr_ (list): Attribute information of the op.
|
||||
dtype_format_ (list): Dtype and format information of the op.
|
||||
"""
|
||||
|
||||
def __init__(self, op_name=""):
|
||||
if not isinstance(op_name, str):
|
||||
raise ValueError("op name value must be string")
|
||||
if not op_name.strip():
|
||||
raise ValueError("op name is empty")
|
||||
self.op_name = op_name
|
||||
self.inputs = []
|
||||
self.outputs = []
|
||||
self.attr_ = []
|
||||
self.dtype_format_ = []
|
||||
|
||||
def is_string(self, value):
|
||||
"""
|
||||
Check if the value is a str type.
|
||||
|
||||
Args:
|
||||
value: Parameter to to check.
|
||||
|
||||
Raises:
|
||||
TypeError: If the type of value is not a str.
|
||||
"""
|
||||
if not isinstance(value, str):
|
||||
raise TypeError("%s value must be str" % str(value))
|
||||
|
||||
def is_int(self, value):
|
||||
"""
|
||||
Check if the value is a int.
|
||||
|
||||
Args:
|
||||
value: Parameter to to check.
|
||||
|
||||
Raises:
|
||||
TypeError: If the type of value is not a int.
|
||||
"""
|
||||
if not isinstance(value, int):
|
||||
raise TypeError("%s value must be int" % str(value))
|
||||
|
||||
def is_bool(self, value):
|
||||
"""
|
||||
Check if the value is a bool.
|
||||
|
||||
Args:
|
||||
value: Parameter to to check.
|
||||
|
||||
Raises:
|
||||
TypeError: If the type of value is not a bool.
|
||||
"""
|
||||
if not isinstance(value, bool):
|
||||
raise TypeError("%s value must be bool" % str(value))
|
||||
|
||||
def dtype_format(self, *args):
|
||||
"""
|
||||
Register dtype and format.
|
||||
|
||||
Args:
|
||||
args (tuple): Value of dtype and format.
|
||||
|
||||
Raises:
|
||||
ValueError: If the size of args not equal to input size add output size.
|
||||
TypeError: If the type of args is not tuple.
|
||||
"""
|
||||
if len(self.inputs) + len(self.outputs) != len(args):
|
||||
raise ValueError("input size add output size must be equal to detype format size")
|
||||
dtype_format = []
|
||||
for arg in args:
|
||||
if not isinstance(arg, tuple) or len(arg) != 2:
|
||||
raise ValueError("dtype and format value must be tuple of two elements")
|
||||
self.is_string(arg[0])
|
||||
self.is_string(arg[1])
|
||||
dtype_format.append(arg)
|
||||
self.dtype_format_.append(tuple(dtype_format))
|
||||
return self
|
||||
|
||||
def get_op_info(self):
|
||||
"""
|
||||
Return all registration information for this instance.
|
||||
|
||||
The '_' character ending the key is removed here for compatibility with previous version.
|
||||
|
||||
Key will be unified into an underlined form later.
|
||||
"""
|
||||
op_info = {}
|
||||
for key, value in self.__dict__.items():
|
||||
if isinstance(key, str) and key.endswith('_'):
|
||||
op_info[key.rstrip('_')] = value
|
||||
else:
|
||||
op_info[key] = value
|
||||
return op_info
|
||||
|
||||
|
||||
class TBERegOp(RegOp):
|
||||
"""Class for TBE op info register."""
|
||||
|
||||
def __init__(self, op_name=""):
|
||||
super(TBERegOp, self).__init__(op_name)
|
||||
self.imply_type = "TBE"
|
||||
self.fusion_type_ = ''
|
||||
self.async_flag_ = False
|
||||
self.binfile_name_ = ''
|
||||
self.compute_cost_ = 10
|
||||
self.kernel_name_ = ''
|
||||
self.partial_flag_ = False
|
||||
self.reshape_type_ = ''
|
||||
self.dynamic_format_ = False
|
||||
self.op_pattern_ = ""
|
||||
|
||||
def fusion_type(self, fusion_type):
|
||||
"""
|
||||
Register fusion type.
|
||||
|
||||
Args:
|
||||
fusion_type (str): Value of fusion type.
|
||||
"""
|
||||
self.is_string(fusion_type)
|
||||
self.fusion_type_ = fusion_type
|
||||
return self
|
||||
|
||||
def async_flag(self, async_flag):
|
||||
"""
|
||||
Register async flag.
|
||||
|
||||
Args:
|
||||
async_flag (bool): Value of async flag.
|
||||
"""
|
||||
self.is_bool(async_flag)
|
||||
self.async_flag_ = async_flag
|
||||
return self
|
||||
|
||||
def binfile_name(self, binfile_name):
|
||||
"""
|
||||
Register binfile name.
|
||||
|
||||
Args:
|
||||
binfile_name (str): Name of op binfile.
|
||||
"""
|
||||
self.is_string(binfile_name)
|
||||
self.binfile_name_ = binfile_name
|
||||
return self
|
||||
|
||||
def compute_cost(self, compute_cost):
|
||||
"""
|
||||
Register compute cost.
|
||||
|
||||
Args:
|
||||
compute_cost (int): Value of compute cost.
|
||||
"""
|
||||
self.is_int(compute_cost)
|
||||
self.compute_cost_ = compute_cost
|
||||
return self
|
||||
|
||||
def kernel_name(self, kernel_name):
|
||||
"""
|
||||
Register kernel name.
|
||||
|
||||
Args:
|
||||
kernel_name (str): Name of op kernel.
|
||||
"""
|
||||
self.is_string(kernel_name)
|
||||
self.kernel_name_ = kernel_name
|
||||
return self
|
||||
|
||||
def partial_flag(self, partial_flag):
|
||||
"""
|
||||
Register partial flag.
|
||||
|
||||
Args:
|
||||
partial_flag (bool): Value of partial flag.
|
||||
"""
|
||||
self.is_bool(partial_flag)
|
||||
self.partial_flag_ = partial_flag
|
||||
return self
|
||||
|
||||
def reshape_type(self, reshape_type):
|
||||
"""
|
||||
Register reshape type.
|
||||
|
||||
Args:
|
||||
reshape_type (str): Value of reshape type.
|
||||
"""
|
||||
self.is_string(reshape_type)
|
||||
self.reshape_type_ = reshape_type
|
||||
return self
|
||||
|
||||
def dynamic_format(self, dynamic_format):
|
||||
"""
|
||||
Register dynamic format.
|
||||
|
||||
Args:
|
||||
reshape_type (bool): Value of dynamic format.
|
||||
"""
|
||||
self.is_bool(dynamic_format)
|
||||
self.dynamic_format_ = dynamic_format
|
||||
return self
|
||||
|
||||
def op_pattern(self, pattern=None):
|
||||
"""
|
||||
Register op pattern information.
|
||||
|
||||
Args:
|
||||
pattern (str): Value of op pattern.
|
||||
"""
|
||||
if pattern is not None and self.istring(pattern):
|
||||
self.op_pattern_ = pattern
|
||||
return self
|
||||
|
||||
def attr(self, name=None, param_type=None, value_type=None, value=None, default_value=None, **kwargs):
|
||||
"""
|
||||
Register op attribute information.
|
||||
|
||||
Args:
|
||||
name (str): Name of the attribute. Default: None.
|
||||
param_type (str): Param type of the attribute. Default: None.
|
||||
type (str): Type of the attribute. Default: None.
|
||||
value (str): Value of the attribute. Default: None.
|
||||
default_value (str): Default value of attribute. Default: None.
|
||||
kwargs (dict): Other information for the attribute.
|
||||
"""
|
||||
param_list = [name, param_type, value_type, value, default_value]
|
||||
attr_dict = {}
|
||||
for index, element in enumerate(param_list):
|
||||
if element is not None:
|
||||
self.is_string(element)
|
||||
if index == 0:
|
||||
attr_dict["name"] = element
|
||||
elif index == 1:
|
||||
attr_dict["param_type"] = element
|
||||
elif index == 2:
|
||||
attr_dict["type"] = element
|
||||
elif index == 3:
|
||||
attr_dict["value"] = element
|
||||
elif index == 4:
|
||||
attr_dict["default_value"] = element
|
||||
if kwargs:
|
||||
attr_dict = dict(attr_dict, **kwargs)
|
||||
self.attr_.append(attr_dict)
|
||||
return self
|
||||
|
||||
def input(self, index=None, name=None, need_compile=None, param_type=None, shape=None, **kwargs):
|
||||
"""
|
||||
Register op input information.
|
||||
|
||||
Args:
|
||||
index (int): Order of the input. Default: None.
|
||||
name (str): Name of the input. Default: None.
|
||||
need_compile (bool): The input need compile whether or not. Default: None.
|
||||
param_type (str): Type of the input. Default: None.
|
||||
shape (str): Shape of the input. Default: None.
|
||||
kwargs (dict): Other information for the input.
|
||||
"""
|
||||
param_list = [index, name, need_compile, param_type, shape]
|
||||
input_dict = {}
|
||||
for idx, element in enumerate(param_list):
|
||||
if element is not None:
|
||||
if idx == 0:
|
||||
self.is_int(element)
|
||||
input_dict["index"] = element
|
||||
elif idx == 1:
|
||||
self.is_string(element)
|
||||
input_dict["name"] = element
|
||||
elif idx == 2:
|
||||
self.is_bool(element)
|
||||
input_dict["need_compile"] = element
|
||||
elif idx == 3:
|
||||
self.is_string(element)
|
||||
input_dict["param_type"] = element
|
||||
elif idx == 4:
|
||||
self.is_string(element)
|
||||
input_dict["shape"] = element
|
||||
if kwargs:
|
||||
input_dict = dict(input_dict, **kwargs)
|
||||
self.inputs.append(input_dict)
|
||||
return self
|
||||
|
||||
def output(self, index=None, name=None, need_compile=None, param_type=None, shape=None, **kwargs):
|
||||
"""
|
||||
Register op output information.
|
||||
|
||||
Args:
|
||||
index (int): Order of the output. Default: None.
|
||||
name (str): Name of the output. Default: None.
|
||||
need_compile (bool): The output need compile whether or not. Default: None.
|
||||
param_type (str): Type of the output. Default: None.
|
||||
shape (str): Shape of the output. Default: None.
|
||||
kwargs (dict): Other information for the output.
|
||||
"""
|
||||
param_list = [index, name, need_compile, param_type, shape]
|
||||
output_dict = {}
|
||||
for idx, element in enumerate(param_list):
|
||||
if element is not None:
|
||||
if idx == 0:
|
||||
self.is_int(element)
|
||||
output_dict["index"] = element
|
||||
elif idx == 1:
|
||||
self.is_string(element)
|
||||
output_dict["name"] = element
|
||||
elif idx == 2:
|
||||
self.is_bool(element)
|
||||
output_dict["need_compile"] = element
|
||||
elif idx == 3:
|
||||
self.is_string(element)
|
||||
output_dict["param_type"] = element
|
||||
elif idx == 4:
|
||||
self.is_string(element)
|
||||
output_dict["shape"] = element
|
||||
if kwargs:
|
||||
output_dict = dict(output_dict, **kwargs)
|
||||
self.outputs.append(output_dict)
|
||||
return self
|
||||
|
||||
class DataType():
|
||||
"""
|
||||
Various combinations of dtype and formatself.
|
||||
|
||||
The current list below maybe not completed. If necessary, please add it.
|
||||
"""
|
||||
|
||||
BOOL_None = ("bool", "")
|
||||
BOOL_Default = ("bool", "DefaultFormat")
|
||||
BOOL_5HD = ("bool", "NC1HWC0")
|
||||
BOOL_NCHW = ("bool", "NCHW")
|
||||
BOOL_NHWC = ("bool", "NHWC")
|
||||
BOOL_HWCN = ("bool", "HWCN")
|
||||
|
||||
I8_None = ("int8", "")
|
||||
I8_Default = ("int8", "DefaultFormat")
|
||||
I8_5HD = ("int8", "NC1HWC0")
|
||||
I8_FracZ = ("int8", "Fracz")
|
||||
I8_FracNZ = ("int8", "FRACTAL_NZ")
|
||||
I8_NCHW = ("int8", "NCHW")
|
||||
I8_NHWC = ("int8", "NHWC")
|
||||
I8_HWCN = ("int8", "HWCN")
|
||||
|
||||
U8_None = ("uint8", "")
|
||||
U8_Default = ("uint8", "DefaultFormat")
|
||||
U8_5HD = ("uint8", "NC1HWC0")
|
||||
U8_FracZ = ("uint8", "Fracz")
|
||||
U8_FracNZ = ("uint8", "FRACTAL_NZ")
|
||||
U8_NCHW = ("uint8", "NCHW")
|
||||
U8_NHWC = ("uint8", "NHWC")
|
||||
U8_HWCN = ("uint8", "HWCN")
|
||||
|
||||
I16_None = ("int16", "")
|
||||
I16_Default = ("int16", "DefaultFormat")
|
||||
I16_5HD = ("int16", "NC1HWC0")
|
||||
I16_FracZ = ("int16", "Fracz")
|
||||
I16_FracNZ = ("int16", "FRACTAL_NZ")
|
||||
I16_NCHW = ("int16", "NCHW")
|
||||
I16_NHWC = ("int16", "NHWC")
|
||||
I16_HWCN = ("int16", "HWCN")
|
||||
|
||||
U16_None = ("uint16", "")
|
||||
U16_Default = ("uint16", "DefaultFormat")
|
||||
U16_5HD = ("uint16", "NC1HWC0")
|
||||
U16_FracZ = ("uint16", "Fracz")
|
||||
U16_FracNZ = ("uint16", "FRACTAL_NZ")
|
||||
U16_NCHW = ("uint16", "NCHW")
|
||||
U16_NHWC = ("uint16", "NHWC")
|
||||
U16_HWCN = ("uint16", "HWCN")
|
||||
|
||||
I32_None = ("int32", "")
|
||||
I32_Default = ("int32", "DefaultFormat")
|
||||
I32_5HD = ("int32", "NC1HWC0")
|
||||
I32_FracZ = ("int32", "Fracz")
|
||||
I32_FracNZ = ("int32", "FRACTAL_NZ")
|
||||
I32_NCHW = ("int32", "NCHW")
|
||||
I32_NHWC = ("int32", "NHWC")
|
||||
I32_HWCN = ("int32", "HWCN")
|
||||
|
||||
U32_None = ("uint32", "")
|
||||
U32_Default = ("uint32", "DefaultFormat")
|
||||
U32_5HD = ("uint32", "NC1HWC0")
|
||||
U32_FracZ = ("uint32", "Fracz")
|
||||
U32_FracNZ = ("uint32", "FRACTAL_NZ")
|
||||
U32_NCHW = ("uint32", "NCHW")
|
||||
U32_NHWC = ("uint32", "NHWC")
|
||||
U32_HWCN = ("uint32", "HWCN")
|
||||
|
||||
I64_None = ("int64", "")
|
||||
I64_Default = ("int64", "DefaultFormat")
|
||||
I64_5HD = ("int64", "NC1HWC0")
|
||||
I64_FracZ = ("int64", "Fracz")
|
||||
I64_FracNZ = ("int64", "FRACTAL_NZ")
|
||||
I64_NCHW = ("int64", "NCHW")
|
||||
I64_NHWC = ("int64", "NHWC")
|
||||
I64_HWCN = ("int64", "HWCN")
|
||||
|
||||
U64_None = ("uint64", "")
|
||||
U64_Default = ("uint64", "DefaultFormat")
|
||||
U64_5HD = ("uint64", "NC1HWC0")
|
||||
U64_FracZ = ("uint64", "Fracz")
|
||||
U64_FracNZ = ("uint64", "FRACTAL_NZ")
|
||||
U64_NCHW = ("uint64", "NCHW")
|
||||
U64_NHWC = ("uint64", "NHWC")
|
||||
U64_HWCN = ("uint64", "HWCN")
|
||||
|
||||
F16_None = ("float16", "")
|
||||
F16_Default = ("float16", "DefaultFormat")
|
||||
F16_5HD = ("float16", "NC1HWC0")
|
||||
F16_FracZ = ("float16", "Fracz")
|
||||
F16_FracNZ = ("float16", "FRACTAL_NZ")
|
||||
F16_C1HWNCoC0 = ("float16", "C1HWNCoC0")
|
||||
F16_NCHW = ("float16", "NCHW")
|
||||
F16_NHWC = ("float16", "NHWC")
|
||||
F16_HWCN = ("float16", "HWCN")
|
||||
|
||||
F32_None = ("float32", "")
|
||||
F32_Default = ("float32", "DefaultFormat")
|
||||
F32_5HD = ("float32", "NC1HWC0")
|
||||
F32_FracZ = ("float32", "Fracz")
|
||||
F32_FracNZ = ("float32", "FRACTAL_NZ")
|
||||
F32_C1HWNCoC0 = ("float32", "C1HWNCoC0")
|
||||
F32_NCHW = ("float32", "NCHW")
|
||||
F32_NHWC = ("float32", "NHWC")
|
||||
F32_HWCN = ("float32", "HWCN")
|
||||
|
|
Loading…
Reference in New Issue