add aicpu opinfo register

zjun 2020-04-07 21:52:28 +08:00
parent 8357383111
commit 16296da5c7
14 changed files with 409 additions and 206 deletions

View File

@@ -39,45 +39,7 @@ namespace mindspore {
namespace kernel {
using FNodeAttrHandle = std::function<void(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto)>;
const std::vector<std::string> local_framework_op_vec = {kInitDataSetQueue, kGetNext, kDropoutGenMask, kPrint};
void InitDataSetQueueAttr(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(proto);
::google::protobuf::Map<::std::string, ::mindspore::AttrValue> *node_attr = proto->mutable_attrs();
MS_EXCEPTION_IF_NULL(node_attr);
std::string channel_name = AnfAlgo::GetNodeAttr<std::string>(anf_node, kQueueName);
(*node_attr)[kChannelName].set_s(channel_name);
}
void GetNextAttr(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(proto);
::google::protobuf::Map<::std::string, ::mindspore::AttrValue> *node_attr = proto->mutable_attrs();
MS_EXCEPTION_IF_NULL(node_attr);
std::string shared_name = AnfAlgo::GetNodeAttr<std::string>(anf_node, kSharedName);
(*node_attr)[kChannelName].set_s(shared_name);
}
void DropoutGenMaskAttr(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(proto);
::google::protobuf::Map<::std::string, ::mindspore::AttrValue> *node_attr = proto->mutable_attrs();
MS_EXCEPTION_IF_NULL(node_attr);
int seed = AnfAlgo::GetNodeAttr<int>(anf_node, kSeed);
int seed2 = AnfAlgo::GetNodeAttr<int>(anf_node, kSeed2);
(*node_attr)["seed"].set_i(seed);
(*node_attr)["seed2"].set_i(seed2);
}
void CreateAttrFuncMap(std::map<std::string, FNodeAttrHandle> *mOpAttrFuncMap) {
(void)mOpAttrFuncMap->emplace(std::pair<std::string, FNodeAttrHandle>(kInitDataSetQueue, InitDataSetQueueAttr));
(void)mOpAttrFuncMap->emplace(std::pair<std::string, FNodeAttrHandle>(kGetNext, GetNextAttr));
(void)mOpAttrFuncMap->emplace(std::pair<std::string, FNodeAttrHandle>(kDropoutGenMask, DropoutGenMaskAttr));
}
const std::vector<std::string> local_framework_op_vec = {kInitData, kGetNext, kDropoutGenMask, kPrint};
bool SetIOIputSize(const std::shared_ptr<AnfNode> &anf_node, const size_t &input_num,
std::vector<size_t> *input_size_list) {
@@ -147,24 +109,74 @@ bool SetIOSize(const std::shared_ptr<AnfNode> &anf_node, const std::shared_ptr<A
return true;
}
void ParseAttrValue(const std::string &type, const std::string &attr_name, const mindspore::ValuePtr &value,
::google::protobuf::Map<::std::string, ::mindspore::AttrValue> *node_attr) {
MS_EXCEPTION_IF_NULL(node_attr);
if (type == "int") {
auto attr_value = GetValue<int>(value);
(*node_attr)[attr_name].set_i(attr_value);
} else if (type == "str") {
auto attr_value = GetValue<std::string>(value);
(*node_attr)[attr_name].set_s(attr_value);
} else if (type == "bool") {
auto attr_value = GetValue<bool>(value);
(*node_attr)[attr_name].set_b(attr_value);
} else if (type == "float") {
auto attr_value = GetValue<float>(value);
(*node_attr)[attr_name].set_f(attr_value);
} else if (type == "listInt") {
std::vector<int> attr_value;
auto value_type = value->type();
MS_EXCEPTION_IF_NULL(value_type);
auto value_type_str = value_type->ToString();
if (value_type_str == "Int32") {
int data = GetValue<int>(value);
attr_value.push_back(data);
} else {
attr_value = GetValue<std::vector<int>>(value);
}
mindspore::AttrValue input_shape_attr;
mindspore::AttrValue_ArrayValue *input_shape_attr_list = input_shape_attr.mutable_array();
MS_EXCEPTION_IF_NULL(input_shape_attr_list);
for (const auto shape : attr_value) {
input_shape_attr_list->add_i(shape);
}
(*node_attr)[attr_name] = input_shape_attr;
} else {
MS_LOG(EXCEPTION) << "type: " << type << " not supported";
}
}
void SetNodeAttr(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto) {
std::string op_name = AnfAlgo::GetCNodeName(anf_node);
if (op_name == "InitDataSetQueue") {
op_name = "InitData";
if (op_name == kInitDataSetQueue) {
op_name = kInitData;
}
if (op_name == "Print") {
if (op_name == kPrint) {
return;
}
std::map<std::string, FNodeAttrHandle> mOpAttrFuncMap;
CreateAttrFuncMap(&mOpAttrFuncMap);
FNodeAttrHandle func_ptr = nullptr;
auto iter = mOpAttrFuncMap.find(op_name);
if (iter != mOpAttrFuncMap.end()) {
func_ptr = iter->second;
MS_EXCEPTION_IF_NULL(func_ptr);
func_ptr(anf_node, proto);
} else {
MS_LOG(ERROR) << "Don't support node [" << op_name << "] to set nodedef of attr";
auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAICPU);
MS_EXCEPTION_IF_NULL(op_info_ptr);
auto attrs_ptr = op_info_ptr->attrs_ptr();
auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
MS_EXCEPTION_IF_NULL(primitive);
::google::protobuf::Map<::std::string, ::mindspore::AttrValue> *node_attr = proto->mutable_attrs();
for (const auto &attr_ptr : attrs_ptr) {
std::string attr_name = attr_ptr->name();
std::string real_name;
auto value = primitive->GetAttr(attr_name);
if (value != nullptr) {
if (attr_name == kQueueName || attr_name == kSharedName) {
real_name = kChannelName;
} else if (attr_name == kSeed) {
real_name = "seed";
} else if (attr_name == kSeed2) {
real_name = "seed2";
} else {
real_name = attr_name;  // fall back to the registered attr name
}
std::string type = attr_ptr->type();
ParseAttrValue(type, real_name, value, node_attr);
}
}
MS_LOG(INFO) << "Set node attr end!";
}

View File

@@ -17,68 +17,27 @@
#include "kernel/aicpu/aicpu_kernel_metadata.h"
#include <memory>
#include <string>
#include "kernel/oplib/oplib.h"
#include "kernel/common_utils.h"
#include "kernel/aicpu/aicpu_util.h"
#include "session/anf_runtime_algorithm.h"
namespace mindspore {
namespace kernel {
constexpr auto kInitDataSetQueueOpName = "InitDataSetQueue";
constexpr auto kGetNext = "GetNext";
constexpr auto kDropoutGenMask = "DropoutGenMask";
constexpr auto kPrint = "Print";
const std::vector<std::string> AICPU_OPS = {kInitDataSetQueueOpName, kGetNext, kDropoutGenMask, kPrint};
std::shared_ptr<KernelBuildInfo> CreateKernelInfo(const std::vector<std::string> &inputs_format,
const std::vector<TypeId> &inputs_device_type,
const std::vector<std::string> &outputs_format,
const std::vector<TypeId> &outputs_device_type) {
auto builder = KernelBuildInfo::KernelBuildInfoBuilder();
builder.SetInputsFormat(inputs_format);
builder.SetInputsDeviceType(inputs_device_type);
builder.SetOutputsFormat(outputs_format);
builder.SetOutputsDeviceType(outputs_device_type);
builder.SetProcessor(AICPU);
builder.SetKernelType(AICPU_KERNEL);
builder.SetFusionType(OPAQUE);
return builder.Build();
}
bool CheckIfExistAicpuMeta(const std::string &op_name) {
if (std::find(AICPU_OPS.begin(), AICPU_OPS.end(), op_name) != AICPU_OPS.end()) {
return false;
}
return true;
}
void AicpuMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) {
MS_LOG(INFO) << "AicpuMetadataInfo.";
MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(kernel_info_list);
std::string op_name = AnfAlgo::GetCNodeName(kernel_node);
if (CheckIfExistAicpuMeta(op_name)) {
MS_LOG(DEBUG) << "Aicpu doesn't have metadata of op [" << op_name << "].";
if (op_name == kInitDataSetQueue) {
op_name = kInitData;
}
auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAICPU);
if (op_info_ptr == nullptr) {
MS_LOG(WARNING) << "Aicpu doestn't have metadata of op [" << op_name << "]";
return;
}
if (op_name == kInitDataSetQueueOpName) {
kernel_info_list->push_back(CreateKernelInfo({}, {}, {}, {}));
}
if (op_name == kGetNext) {
std::vector<std::string> outputs_format;
std::vector<TypeId> outputs_type;
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) {
outputs_format.emplace_back(kOpFormat_DEFAULT);
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
}
kernel_info_list->push_back(CreateKernelInfo({}, {}, outputs_format, outputs_type));
}
if (op_name == kDropoutGenMask) {
kernel_info_list->push_back(CreateKernelInfo({kOpFormat_NCHW, kOpFormat_NCHW},
{kInt32->type_id(), kFloat16->type_id()}, {kOpFormat_NCHW},
{kUInt8->type_id()}));
}
// For compatibility with the current framework
if (op_name == kPrint) {
std::vector<std::string> inputs_format;
std::vector<TypeId> inputs_type;
@@ -92,11 +51,20 @@ void AicpuMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<
outputs_format.emplace_back(kOpFormat_DEFAULT);
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
}
kernel_info_list->push_back(CreateKernelInfo(inputs_format, inputs_type, outputs_format, outputs_type));
auto builder = KernelBuildInfo::KernelBuildInfoBuilder();
builder.SetInputsFormat(inputs_format);
builder.SetInputsDeviceType(inputs_type);
builder.SetOutputsFormat(outputs_format);
builder.SetOutputsDeviceType(outputs_type);
builder.SetProcessor(AICPU);
builder.SetKernelType(AICPU_KERNEL);
builder.SetFusionType(OPAQUE);
kernel_info_list->push_back(builder.Build());
return;
}
if (kernel_info_list->empty()) {
MS_LOG(INFO) << "Aicpu dose not has metadata of op[ " << op_name << "].";
if (!ParseMetadata(kernel_node, op_info_ptr, AICPU, kernel_info_list)) {
MS_LOG(WARNING) << "Aicpu parsed metadata op [" << op_name << "] failed";
return;
}
}
} // namespace kernel

View File

@@ -24,7 +24,8 @@
namespace mindspore {
namespace kernel {
constexpr auto kInitDataSetQueue = "InitData";
constexpr auto kInitDataSetQueue = "InitDataSetQueue";
constexpr auto kInitData = "InitData";
constexpr auto kGetNext = "GetNext";
constexpr auto kDropoutGenMask = "DropoutGenMask";
constexpr auto kPrint = "Print";

View File

@@ -417,6 +417,8 @@ void SetKernelBuildInfo(const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBu
if (imply_type == kAKG) {
builder->SetKernelType(AUTO_DIFF_KERNEL);
} else if (imply_type == kAICPU) {
builder->SetKernelType(AICPU_KERNEL);
} else {
builder->SetKernelType(TBE_KERNEL);
}
@@ -471,6 +473,13 @@ bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpIn
return false;
}
kernel_info_list->push_back(builder->Build());
}
} else {
if (processor == AICPU) {
auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
MS_EXCEPTION_IF_NULL(builder);
SetKernelBuildInfo(builder, processor, op_info_ptr);
kernel_info_list->push_back(builder->Build());
}
}

View File

@@ -24,7 +24,7 @@
namespace mindspore {
namespace kernel {
enum OpImplyType { kAKG = 0, kTBE };
enum OpImplyType { kAKG = 0, kTBE = 1, kAICPU };
enum OpIOType { kInput = 0, kOutput };
class OpAttr {

View File

@@ -39,6 +39,7 @@ constexpr auto kDtypeFormat = "dtype_format";
constexpr auto kAttr = "attr";
constexpr auto kIputs = "inputs";
constexpr auto kOutputs = "outputs";
constexpr auto kAiCPU = "AiCPU";
constexpr auto kTbe = "TBE";
constexpr auto kAkg = "akg";
constexpr auto kAutodiff = "AutoDiff";
@@ -60,6 +61,8 @@ std::string ImplTypeToStr(OpImplyType impl_type) {
return kTbe;
case kAKG:
return kAkg;
case kAICPU:
return kAiCPU;
default:
return "unknow";
}
@@ -76,6 +79,9 @@ bool OpLib::RegOp(const std::string& json_string, const std::string& impl_path)
} else if (imply_type_string == kAutodiff) {
OpImplyType imply_type = kAKG;
ret = DecodeOpInfo(op_json, imply_type, impl_path);
} else if (imply_type_string == kAiCPU) {
OpImplyType imply_type = kAICPU;
ret = DecodeOpInfo(op_json, imply_type, impl_path);
} else {
MS_LOG(DEBUG) << "Not support imply_type";
}
@@ -154,7 +160,9 @@ bool OpLib::DecodeAttr(const nlohmann::json& obj, const OpImplyType imply_type,
std::shared_ptr<OpAttr> op_attr = std::make_shared<OpAttr>();
MS_EXCEPTION_IF_NULL(op_attr);
op_attr->set_name(obj.at(kName));
op_attr->set_param_type(obj.at(kParamType));
if (imply_type != kAICPU) {
op_attr->set_param_type(obj.at(kParamType));
}
op_attr->set_type(obj.at(kType));
if (imply_type == kTBE) {
op_attr->set_value(obj.at(kValue));
@@ -242,9 +250,10 @@ std::shared_ptr<OpInfo> OpLib::FindOp(const std::string& op_name, OpImplyType im
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool is_gpu = (context->device_target() == kGPUDevice);
if ((is_gpu && imply_type == kTBE) || (!is_gpu && imply_type != kTBE)) {
MS_LOG(DEBUG) << "FindOp failed: opname:" << op_name << "imply_type:" << ImplTypeToStr(imply_type)
<< "current op num:" << op_info_.size();
if ((is_gpu && (imply_type == kTBE || imply_type == kAICPU)) ||
(!is_gpu && (imply_type != kTBE && imply_type != kAICPU))) {
MS_LOG(ERROR) << "FindOp failed: opname:" << op_name << ", imply_type:" << ImplTypeToStr(imply_type)
<< ", current op num:" << op_info_.size();
return nullptr;
}
for (const auto& op_info : op_info_) {
@@ -253,8 +262,8 @@ std::shared_ptr<OpInfo> OpLib::FindOp(const std::string& op_name, OpImplyType im
return op_info;
}
}
MS_LOG(DEBUG) << "FindOp failed: opname:" << op_name << "imply_type:" << ImplTypeToStr(imply_type)
<< "current op num:" << op_info_.size();
MS_LOG(DEBUG) << "FindOp failed: opname:" << op_name << ", imply_type:" << ImplTypeToStr(imply_type)
<< ", current op num:" << op_info_.size();
return nullptr;
}

View File

@@ -30,7 +30,7 @@ Note:
from .primitive import Primitive, PrimitiveWithInfer, prim_attr_register
from .vm_impl_registry import get_vm_impl_fn, vm_impl_registry
from .op_info_register import op_info_register, TBERegOp, DataType
from .op_info_register import op_info_register, AiCPURegOp, TBERegOp, DataType
from .primitive import constexpr
from .._c_expression import signature_rw, signature_kind
@@ -40,6 +40,6 @@ __primitive__ = [
]
__all__ = ["get_vm_impl_fn", "vm_impl_registry",
"op_info_register", "TBERegOp", "DataType",
"op_info_register", "AiCPURegOp", "TBERegOp", "DataType",
"constexpr"]
__all__.extend(__primitive__)

View File

@@ -16,5 +16,6 @@
from .akg.gpu import *
from .tbe import *
from .aicpu import *
__all__ = []

View File

@@ -0,0 +1,19 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""aicpu ops"""
from .init_data_set_queue import _init_data_set_queue_aicpu
from .dropout_genmask import _dropout_genmask_aicpu
from .get_next import _get_next_aicpu
from .print_tensor import _print_aicpu

View File

@@ -0,0 +1,32 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""InitDataSetQueue op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
dropout_genmask_op_info = AiCPURegOp("DropoutGenMask") \
.fusion_type("OPAQUE") \
.input(0, "x1", "required") \
.input(1, "x2", "required") \
.output(0, "y", "required") \
.attr("Seed0", "int") \
.attr("Seed1", "int") \
.dtype_format(DataType.I32_NCHW, DataType.F16_NCHW, DataType.U8_NCHW) \
.get_op_info()
@op_info_register(dropout_genmask_op_info)
def _dropout_genmask_aicpu():
"""Dropout AiCPU register"""
return

View File

@@ -0,0 +1,39 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""InitDataSetQueue op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
get_next_op_info = AiCPURegOp("GetNext") \
.fusion_type("OPAQUE") \
.output(0, "y", "dynamic") \
.attr("shared_name", "str") \
.dtype_format(DataType.BOOL_Default) \
.dtype_format(DataType.I8_Default) \
.dtype_format(DataType.I16_Default) \
.dtype_format(DataType.I32_Default) \
.dtype_format(DataType.I64_Default) \
.dtype_format(DataType.F16_Default) \
.dtype_format(DataType.U8_Default) \
.dtype_format(DataType.U16_Default) \
.dtype_format(DataType.U32_Default) \
.dtype_format(DataType.U64_Default) \
.dtype_format(DataType.F32_Default) \
.get_op_info()
@op_info_register(get_next_op_info)
def _get_next_aicpu():
"""GetNext AiCPU register"""
return

View File

@@ -0,0 +1,27 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""InitDataSetQueue op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp
init_data_set_queue_op_info = AiCPURegOp("InitData") \
.fusion_type("OPAQUE") \
.attr("queue_name", "str") \
.get_op_info()
@op_info_register(init_data_set_queue_op_info)
def _init_data_set_queue_aicpu():
"""InitDataSetQueue AiCPU register"""
return

View File

@@ -0,0 +1,39 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""InitDataSetQueue op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
print_op_info = AiCPURegOp("Print") \
.fusion_type("OPAQUE") \
.input(0, "x", "dynamic") \
.output(0, "y", "required") \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register(print_op_info)
def _print_aicpu():
"""Print AiCPU register"""
return

View File

@@ -78,14 +78,15 @@ class RegOp():
self.inputs = []
self.outputs = []
self.attr_ = []
self.fusion_type_ = ''
self.dtype_format_ = []
def is_string(self, value):
def _is_string(self, value):
"""
Check if the value is a str type.
Args:
value: Parameter to to check.
value: Parameter to be checked.
Raises:
TypeError: If the type of value is not a str.
@@ -93,12 +94,12 @@
if not isinstance(value, str):
raise TypeError("%s value must be str" % str(value))
def is_int(self, value):
def _is_int(self, value):
"""
Check if the value is an int.
Args:
value: Parameter to to check.
value: Parameter to be checked.
Raises:
TypeError: If the type of value is not an int.
@@ -106,12 +107,12 @@
if not isinstance(value, int):
raise TypeError("%s value must be int" % str(value))
def is_bool(self, value):
def _is_bool(self, value):
"""
Check if the value is a bool.
Args:
value: Parameter to to check.
value: Parameter to be checked.
Raises:
TypeError: If the type of value is not a bool.
@@ -119,6 +120,51 @@
if not isinstance(value, bool):
raise TypeError("%s value must be bool" % str(value))
def _check_param(self, param_list, key_list, fn_list, kwargs):
"""
Check if the parameter type is correct.
Args:
param_list (list): Parameter list to be checked.
key_list (list): The keys of output dict.
fn_list (list): Function used for parameter checking. If the function list has only one element,
all parameters will use the same function.
kwargs (dict): Other parameter information.
Raises:
TypeError: If the type of value is not list.
ValueError: If the size of param list is not equal to the size of key list, or
the size of param list is not equal to the size of function list.
"""
for i in [param_list, key_list, fn_list]:
if not isinstance(i, list):
raise TypeError("%s value must be list type" % str(i))
if len(param_list) != len(key_list) or (len(fn_list) != 1 and len(param_list) != len(fn_list)):
raise ValueError("param_list size {}, key_list size {}, must be equal.And fn_list size {}.".
format(len(param_list), len(key_list), len(fn_list)))
out_dict = {}
for idx, element in enumerate(param_list):
if element is not None:
if len(fn_list) == 1:
fn_list[0](element)
else:
fn_list[idx](element)
out_dict[key_list[idx]] = element
if kwargs:
out_dict = dict(out_dict, **kwargs)
return out_dict
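_check_param is the heart of this refactor: each positional parameter is type-checked by the matching function in fn_list (or by the single shared function) and folded into a dict keyed by key_list, with None entries skipped. A standalone sketch of that contract, using a hypothetical free function that mirrors, but is not, the shipped method:

def check_param(param_list, key_list, fn_list, kwargs):
    """Hypothetical mirror of RegOp._check_param, for illustration only."""
    for i in [param_list, key_list, fn_list]:
        if not isinstance(i, list):
            raise TypeError("%s value must be list type" % str(i))
    if len(param_list) != len(key_list) or (len(fn_list) != 1 and len(param_list) != len(fn_list)):
        raise ValueError("list sizes do not match")
    out_dict = {}
    for idx, element in enumerate(param_list):
        if element is None:
            continue  # None parameters are skipped, not recorded
        check = fn_list[0] if len(fn_list) == 1 else fn_list[idx]
        check(element)  # raises TypeError on a bad value
        out_dict[key_list[idx]] = element
    if kwargs:
        out_dict = dict(out_dict, **kwargs)
    return out_dict

def _is_int(v):
    if not isinstance(v, int):
        raise TypeError("%s value must be int" % str(v))

def _is_string(v):
    if not isinstance(v, str):
        raise TypeError("%s value must be str" % str(v))

# The triple that AiCPURegOp.input(0, "x1", "required") passes through:
print(check_param([0, "x1", "required"],
                  ["index", "name", "param_type"],
                  [_is_int, _is_string, _is_string], {}))
# -> {'index': 0, 'name': 'x1', 'param_type': 'required'}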
def fusion_type(self, fusion_type):
"""
Register fusion type.
Args:
fusion_type (str): Value of fusion type.
"""
self._is_string(fusion_type)
self.fusion_type_ = fusion_type
return self
def dtype_format(self, *args):
"""
Register dtype and format.
@@ -136,8 +182,8 @@
for arg in args:
if not isinstance(arg, tuple) or len(arg) != 2:
raise ValueError("dtype and format value must be tuple of two elements")
self.is_string(arg[0])
self.is_string(arg[1])
self._is_string(arg[0])
self._is_string(arg[1])
dtype_format.append(arg)
self.dtype_format_.append(tuple(dtype_format))
return self
@@ -159,13 +205,71 @@
return op_info
class AiCPURegOp(RegOp):
"""Class for AiCPU op info register"""
def __init__(self, op_name):
super(AiCPURegOp, self).__init__(op_name)
self.imply_type = "AiCPU"
def input(self, index=None, name=None, param_type=None, **kwargs):
"""
Register AiCPU op input information.
Args:
index (int): Order of the input. Default: None.
name (str): Name of the input. Default: None.
param_type (str): Param type of the input. Default: None.
kwargs (dict): Other information for the input.
"""
param_list = [index, name, param_type]
key_list = ["index", "name", "param_type"]
fn_list = [self._is_int, self._is_string, self._is_string]
input_dict = self._check_param(param_list, key_list, fn_list, kwargs)
self.inputs.append(input_dict)
return self
def output(self, index=None, name=None, param_type=None, **kwargs):
"""
Register AiCPU op output information.
Args:
index (int): Order of the output. Default: None.
name (str): Name of the output. Default: None.
param_type (str): Param type of the output. Default: None.
kwargs (dict): Other information for the output.
"""
param_list = [index, name, param_type]
key_list = ["index", "name", "param_type"]
fn_list = [self._is_int, self._is_string, self._is_string]
output_dict = self._check_param(param_list, key_list, fn_list, kwargs)
self.outputs.append(output_dict)
return self
def attr(self, name=None, value_type=None, value=None, **kwargs):
"""
Register AiCPU op attribute information.
Args:
name (str): Name of the attribute. Default: None.
value_type (str): Value type of the attribute. Default: None.
value (str): Value of the attribute. Default: None.
kwargs (dict): Other information for the attribute.
"""
param_list = [name, value_type, value]
key_list = ["name", "type", "value"]
fn_list = [self._is_string]
attr_dict = self._check_param(param_list, key_list, fn_list, kwargs)
self.attr_.append(attr_dict)
return self
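For orientation, here is how the AiCPURegOp builder methods chain together; the op name "MyOp", its attr, and the dtype/format pair are invented for illustration (the real registrations are in the new aicpu op files above):

from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

my_op_info = AiCPURegOp("MyOp") \
    .fusion_type("OPAQUE") \
    .input(0, "x", "required") \
    .output(0, "y", "required") \
    .attr("flag", "bool") \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()

@op_info_register(my_op_info)
def _my_op_aicpu():
    """MyOp AiCPU register (hypothetical)"""
    return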
class TBERegOp(RegOp):
"""Class for TBE op info register."""
def __init__(self, op_name=""):
super(TBERegOp, self).__init__(op_name)
self.imply_type = "TBE"
self.fusion_type_ = ''
self.async_flag_ = False
self.binfile_name_ = ''
self.compute_cost_ = 10
@@ -175,17 +279,6 @@ class TBERegOp(RegOp):
self.dynamic_format_ = False
self.op_pattern_ = ""
def fusion_type(self, fusion_type):
"""
Register fusion type.
Args:
fusion_type (str): Value of fusion type.
"""
self.is_string(fusion_type)
self.fusion_type_ = fusion_type
return self
def async_flag(self, async_flag):
"""
Register async flag.
@@ -193,7 +286,7 @@
Args:
async_flag (bool): Value of async flag.
"""
self.is_bool(async_flag)
self._is_bool(async_flag)
self.async_flag_ = async_flag
return self
@@ -204,7 +297,7 @@
Args:
binfile_name (str): Name of op binfile.
"""
self.is_string(binfile_name)
self._is_string(binfile_name)
self.binfile_name_ = binfile_name
return self
@@ -215,7 +308,7 @@
Args:
compute_cost (int): Value of compute cost.
"""
self.is_int(compute_cost)
self._is_int(compute_cost)
self.compute_cost_ = compute_cost
return self
@@ -226,7 +319,7 @@
Args:
kernel_name (str): Name of op kernel.
"""
self.is_string(kernel_name)
self._is_string(kernel_name)
self.kernel_name_ = kernel_name
return self
@@ -237,7 +330,7 @@
Args:
partial_flag (bool): Value of partial flag.
"""
self.is_bool(partial_flag)
self._is_bool(partial_flag)
self.partial_flag_ = partial_flag
return self
@@ -248,7 +341,7 @@
Args:
reshape_type (str): Value of reshape type.
"""
self.is_string(reshape_type)
self._is_string(reshape_type)
self.reshape_type_ = reshape_type
return self
@@ -259,56 +352,43 @@
Args:
dynamic_format (bool): Value of dynamic format.
"""
self.is_bool(dynamic_format)
self._is_bool(dynamic_format)
self.dynamic_format_ = dynamic_format
return self
def op_pattern(self, pattern=None):
"""
Register op pattern information.
Register TBE op pattern information.
Args:
pattern (str): Value of op pattern.
"""
if pattern is not None and self.istring(pattern):
if pattern is not None:
self._is_string(pattern)  # _is_string returns None, so it cannot gate an `and` chain
self.op_pattern_ = pattern
return self
def attr(self, name=None, param_type=None, value_type=None, value=None, default_value=None, **kwargs):
"""
Register op attribute information.
Register TBE op attribute information.
Args:
name (str): Name of the attribute. Default: None.
param_type (str): Param type of the attribute. Default: None.
type (str): Type of the attribute. Default: None.
value_type (str): Type of the attribute. Default: None.
value (str): Value of the attribute. Default: None.
default_value (str): Default value of attribute. Default: None.
kwargs (dict): Other information for the attribute.
"""
param_list = [name, param_type, value_type, value, default_value]
attr_dict = {}
for index, element in enumerate(param_list):
if element is not None:
self.is_string(element)
if index == 0:
attr_dict["name"] = element
elif index == 1:
attr_dict["param_type"] = element
elif index == 2:
attr_dict["type"] = element
elif index == 3:
attr_dict["value"] = element
elif index == 4:
attr_dict["default_value"] = element
if kwargs:
attr_dict = dict(attr_dict, **kwargs)
key_list = ["name", "param_type", "type", "value", "default_value"]
fn_list = [self._is_string]
attr_dict = self._check_param(param_list, key_list, fn_list, kwargs)
self.attr_.append(attr_dict)
return self
def input(self, index=None, name=None, need_compile=None, param_type=None, shape=None, **kwargs):
"""
Register op input information.
Register TBE op input information.
Args:
index (int): Order of the input. Default: None.
@@ -319,32 +399,15 @@ class TBERegOp(RegOp):
kwargs (dict): Other information for the input.
"""
param_list = [index, name, need_compile, param_type, shape]
input_dict = {}
for idx, element in enumerate(param_list):
if element is not None:
if idx == 0:
self.is_int(element)
input_dict["index"] = element
elif idx == 1:
self.is_string(element)
input_dict["name"] = element
elif idx == 2:
self.is_bool(element)
input_dict["need_compile"] = element
elif idx == 3:
self.is_string(element)
input_dict["param_type"] = element
elif idx == 4:
self.is_string(element)
input_dict["shape"] = element
if kwargs:
input_dict = dict(input_dict, **kwargs)
key_list = ["index", "name", "need_compile", "param_type", "shape"]
fn_list = [self._is_int, self._is_string, self._is_bool, self._is_string, self._is_string]
input_dict = self._check_param(param_list, key_list, fn_list, kwargs)
self.inputs.append(input_dict)
return self
def output(self, index=None, name=None, need_compile=None, param_type=None, shape=None, **kwargs):
"""
Register op output information.
Register TBE op output information.
Args:
index (int): Order of the output. Default: None.
@@ -355,29 +418,13 @@ class TBERegOp(RegOp):
kwargs (dict): Other information for the output.
"""
param_list = [index, name, need_compile, param_type, shape]
output_dict = {}
for idx, element in enumerate(param_list):
if element is not None:
if idx == 0:
self.is_int(element)
output_dict["index"] = element
elif idx == 1:
self.is_string(element)
output_dict["name"] = element
elif idx == 2:
self.is_bool(element)
output_dict["need_compile"] = element
elif idx == 3:
self.is_string(element)
output_dict["param_type"] = element
elif idx == 4:
self.is_string(element)
output_dict["shape"] = element
if kwargs:
output_dict = dict(output_dict, **kwargs)
key_list = ["index", "name", "need_compile", "param_type", "shape"]
fn_list = [self._is_int, self._is_string, self._is_bool, self._is_string, self._is_string]
output_dict = self._check_param(param_list, key_list, fn_list, kwargs)
self.outputs.append(output_dict)
return self
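The TBE builder methods now route through the same _check_param helper. A hypothetical registration exercising the rewritten attr/input/output methods (every name and value below is illustrative, not a real op):

from mindspore.ops.op_info_register import TBERegOp, DataType

my_tbe_op_info = TBERegOp("MyOp") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("my_op.so") \
    .compute_cost(10) \
    .kernel_name("my_op_kernel") \
    .partial_flag(True) \
    .attr("axis", "required", "int", "all") \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()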
class DataType():
"""
Various combinations of dtype and format.