forked from mindspore-Ecosystem/mindspore
add aicpu opinfo register
This commit is contained in:
parent
8357383111
commit
16296da5c7
|
@ -39,45 +39,7 @@ namespace mindspore {
|
|||
namespace kernel {
|
||||
using FNodeAttrHandle = std::function<void(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto)>;
|
||||
|
||||
const std::vector<std::string> local_framework_op_vec = {kInitDataSetQueue, kGetNext, kDropoutGenMask, kPrint};
|
||||
|
||||
void InitDataSetQueueAttr(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
MS_EXCEPTION_IF_NULL(proto);
|
||||
|
||||
::google::protobuf::Map<::std::string, ::mindspore::AttrValue> *node_attr = proto->mutable_attrs();
|
||||
MS_EXCEPTION_IF_NULL(node_attr);
|
||||
std::string channel_name = AnfAlgo::GetNodeAttr<std::string>(anf_node, kQueueName);
|
||||
(*node_attr)[kChannelName].set_s(channel_name);
|
||||
}
|
||||
|
||||
void GetNextAttr(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
MS_EXCEPTION_IF_NULL(proto);
|
||||
|
||||
::google::protobuf::Map<::std::string, ::mindspore::AttrValue> *node_attr = proto->mutable_attrs();
|
||||
MS_EXCEPTION_IF_NULL(node_attr);
|
||||
std::string shared_name = AnfAlgo::GetNodeAttr<std::string>(anf_node, kSharedName);
|
||||
(*node_attr)[kChannelName].set_s(shared_name);
|
||||
}
|
||||
|
||||
void DropoutGenMaskAttr(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
MS_EXCEPTION_IF_NULL(proto);
|
||||
|
||||
::google::protobuf::Map<::std::string, ::mindspore::AttrValue> *node_attr = proto->mutable_attrs();
|
||||
MS_EXCEPTION_IF_NULL(node_attr);
|
||||
int seed = AnfAlgo::GetNodeAttr<int>(anf_node, kSeed);
|
||||
int seed2 = AnfAlgo::GetNodeAttr<int>(anf_node, kSeed2);
|
||||
(*node_attr)["seed"].set_i(seed);
|
||||
(*node_attr)["seed2"].set_i(seed2);
|
||||
}
|
||||
|
||||
void CreateAttrFuncMap(std::map<std::string, FNodeAttrHandle> *mOpAttrFuncMap) {
|
||||
(void)mOpAttrFuncMap->emplace(std::pair<std::string, FNodeAttrHandle>(kInitDataSetQueue, InitDataSetQueueAttr));
|
||||
(void)mOpAttrFuncMap->emplace(std::pair<std::string, FNodeAttrHandle>(kGetNext, GetNextAttr));
|
||||
(void)mOpAttrFuncMap->emplace(std::pair<std::string, FNodeAttrHandle>(kDropoutGenMask, DropoutGenMaskAttr));
|
||||
}
|
||||
const std::vector<std::string> local_framework_op_vec = {kInitData, kGetNext, kDropoutGenMask, kPrint};
|
||||
|
||||
bool SetIOIputSize(const std::shared_ptr<AnfNode> &anf_node, const size_t &input_num,
|
||||
std::vector<size_t> *input_size_list) {
|
||||
|
@ -147,24 +109,74 @@ bool SetIOSize(const std::shared_ptr<AnfNode> &anf_node, const std::shared_ptr<A
|
|||
return true;
|
||||
}
|
||||
|
||||
void ParseAttrValue(const std::string &type, const std::string &attr_name, const mindspore::ValuePtr &value,
|
||||
::google::protobuf::Map<::std::string, ::mindspore::AttrValue> *node_attr) {
|
||||
MS_EXCEPTION_IF_NULL(node_attr);
|
||||
if (type == "int") {
|
||||
auto attr_value = GetValue<int>(value);
|
||||
(*node_attr)[attr_name].set_i(attr_value);
|
||||
} else if (type == "str") {
|
||||
auto attr_value = GetValue<std::string>(value);
|
||||
(*node_attr)[attr_name].set_s(attr_value);
|
||||
} else if (type == "bool") {
|
||||
auto attr_value = GetValue<bool>(value);
|
||||
(*node_attr)[attr_name].set_b(attr_value);
|
||||
} else if (type == "float") {
|
||||
auto attr_value = GetValue<float>(value);
|
||||
(*node_attr)[attr_name].set_f(attr_value);
|
||||
} else if (type == "listInt") {
|
||||
std::vector<int> attr_value;
|
||||
auto value_type = value->type();
|
||||
MS_EXCEPTION_IF_NULL(value_type);
|
||||
auto value_type_str = value_type->ToString();
|
||||
if (value_type_str == "Int32") {
|
||||
int data = GetValue<int>(value);
|
||||
attr_value.push_back(data);
|
||||
} else {
|
||||
attr_value = GetValue<std::vector<int>>(value);
|
||||
}
|
||||
mindspore::AttrValue input_shape_attr;
|
||||
mindspore::AttrValue_ArrayValue *input_shape_attr_list = input_shape_attr.mutable_array();
|
||||
MS_EXCEPTION_IF_NULL(input_shape_attr_list);
|
||||
for (const auto shape : attr_value) {
|
||||
input_shape_attr_list->add_i(shape);
|
||||
}
|
||||
(*node_attr)[attr_name] = input_shape_attr;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "type: " << type << "not support";
|
||||
}
|
||||
}
|
||||
|
||||
void SetNodeAttr(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto) {
|
||||
std::string op_name = AnfAlgo::GetCNodeName(anf_node);
|
||||
if (op_name == "InitDataSetQueue") {
|
||||
op_name = "InitData";
|
||||
if (op_name == kInitDataSetQueue) {
|
||||
op_name = kInitData;
|
||||
}
|
||||
if (op_name == "Print") {
|
||||
if (op_name == kPrint) {
|
||||
return;
|
||||
}
|
||||
std::map<std::string, FNodeAttrHandle> mOpAttrFuncMap;
|
||||
CreateAttrFuncMap(&mOpAttrFuncMap);
|
||||
FNodeAttrHandle func_ptr = nullptr;
|
||||
auto iter = mOpAttrFuncMap.find(op_name);
|
||||
if (iter != mOpAttrFuncMap.end()) {
|
||||
func_ptr = iter->second;
|
||||
MS_EXCEPTION_IF_NULL(func_ptr);
|
||||
func_ptr(anf_node, proto);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Don't support node [" << op_name << "] to set nodedef of attr";
|
||||
|
||||
auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAICPU);
|
||||
MS_EXCEPTION_IF_NULL(op_info_ptr);
|
||||
auto attrs_ptr = op_info_ptr->attrs_ptr();
|
||||
auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
::google::protobuf::Map<::std::string, ::mindspore::AttrValue> *node_attr = proto->mutable_attrs();
|
||||
for (const auto &attr_ptr : attrs_ptr) {
|
||||
std::string attr_name = attr_ptr->name();
|
||||
std::string real_name;
|
||||
auto value = primitive->GetAttr(attr_name);
|
||||
if (value != nullptr) {
|
||||
if (attr_name == kQueueName || attr_name == kSharedName) {
|
||||
real_name = kChannelName;
|
||||
} else if (attr_name == kSeed) {
|
||||
real_name = "seed";
|
||||
} else if (attr_name == kSeed2) {
|
||||
real_name = "seed2";
|
||||
}
|
||||
std::string type = attr_ptr->type();
|
||||
ParseAttrValue(type, real_name, value, node_attr);
|
||||
}
|
||||
}
|
||||
MS_LOG(INFO) << "Set node attr end!";
|
||||
}
|
||||
|
|
|
@ -17,68 +17,27 @@
|
|||
#include "kernel/aicpu/aicpu_kernel_metadata.h"
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "kernel/oplib/oplib.h"
|
||||
#include "kernel/common_utils.h"
|
||||
#include "kernel/aicpu/aicpu_util.h"
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
constexpr auto kInitDataSetQueueOpName = "InitDataSetQueue";
|
||||
constexpr auto kGetNext = "GetNext";
|
||||
constexpr auto kDropoutGenMask = "DropoutGenMask";
|
||||
constexpr auto kPrint = "Print";
|
||||
const std::vector<std::string> AICPU_OPS = {kInitDataSetQueueOpName, kGetNext, kDropoutGenMask, kPrint};
|
||||
|
||||
std::shared_ptr<KernelBuildInfo> CreateKernelInfo(const std::vector<std::string> &inputs_format,
|
||||
const std::vector<TypeId> &inputs_device_type,
|
||||
const std::vector<std::string> &outputs_format,
|
||||
const std::vector<TypeId> &outputs_device_type) {
|
||||
auto builder = KernelBuildInfo::KernelBuildInfoBuilder();
|
||||
builder.SetInputsFormat(inputs_format);
|
||||
builder.SetInputsDeviceType(inputs_device_type);
|
||||
builder.SetOutputsFormat(outputs_format);
|
||||
builder.SetOutputsDeviceType(outputs_device_type);
|
||||
builder.SetProcessor(AICPU);
|
||||
builder.SetKernelType(AICPU_KERNEL);
|
||||
builder.SetFusionType(OPAQUE);
|
||||
return builder.Build();
|
||||
}
|
||||
|
||||
bool CheckIfExistAicpuMeta(const std::string &op_name) {
|
||||
if (std::find(AICPU_OPS.begin(), AICPU_OPS.end(), op_name) != AICPU_OPS.end()) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void AicpuMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) {
|
||||
MS_LOG(INFO) << "AicpuMetadataInfo.";
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
MS_EXCEPTION_IF_NULL(kernel_info_list);
|
||||
std::string op_name = AnfAlgo::GetCNodeName(kernel_node);
|
||||
if (CheckIfExistAicpuMeta(op_name)) {
|
||||
MS_LOG(DEBUG) << "Aicpu doesn't have metadata of op [" << op_name << "].";
|
||||
if (op_name == kInitDataSetQueue) {
|
||||
op_name = kInitData;
|
||||
}
|
||||
auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAICPU);
|
||||
if (op_info_ptr == nullptr) {
|
||||
MS_LOG(WARNING) << "Aicpu doestn't have metadata of op [" << op_name << "]";
|
||||
return;
|
||||
}
|
||||
|
||||
if (op_name == kInitDataSetQueueOpName) {
|
||||
kernel_info_list->push_back(CreateKernelInfo({}, {}, {}, {}));
|
||||
}
|
||||
|
||||
if (op_name == kGetNext) {
|
||||
std::vector<std::string> outputs_format;
|
||||
std::vector<TypeId> outputs_type;
|
||||
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) {
|
||||
outputs_format.emplace_back(kOpFormat_DEFAULT);
|
||||
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
|
||||
}
|
||||
kernel_info_list->push_back(CreateKernelInfo({}, {}, outputs_format, outputs_type));
|
||||
}
|
||||
|
||||
if (op_name == kDropoutGenMask) {
|
||||
kernel_info_list->push_back(CreateKernelInfo({kOpFormat_NCHW, kOpFormat_NCHW},
|
||||
{kInt32->type_id(), kFloat16->type_id()}, {kOpFormat_NCHW},
|
||||
{kUInt8->type_id()}));
|
||||
}
|
||||
|
||||
// For compatibility with the current framework
|
||||
if (op_name == kPrint) {
|
||||
std::vector<std::string> inputs_format;
|
||||
std::vector<TypeId> inputs_type;
|
||||
|
@ -92,11 +51,20 @@ void AicpuMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<
|
|||
outputs_format.emplace_back(kOpFormat_DEFAULT);
|
||||
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
|
||||
}
|
||||
kernel_info_list->push_back(CreateKernelInfo(inputs_format, inputs_type, outputs_format, outputs_type));
|
||||
auto builder = KernelBuildInfo::KernelBuildInfoBuilder();
|
||||
builder.SetInputsFormat(inputs_format);
|
||||
builder.SetInputsDeviceType(inputs_type);
|
||||
builder.SetOutputsFormat(outputs_format);
|
||||
builder.SetOutputsDeviceType(outputs_type);
|
||||
builder.SetProcessor(AICPU);
|
||||
builder.SetKernelType(AICPU_KERNEL);
|
||||
builder.SetFusionType(OPAQUE);
|
||||
kernel_info_list->push_back(builder.Build());
|
||||
return;
|
||||
}
|
||||
|
||||
if (kernel_info_list->empty()) {
|
||||
MS_LOG(INFO) << "Aicpu dose not has metadata of op[ " << op_name << "].";
|
||||
if (!ParseMetadata(kernel_node, op_info_ptr, AICPU, kernel_info_list)) {
|
||||
MS_LOG(WARNING) << "Aicpu parsed metadata op [" << op_name << "] failed";
|
||||
return;
|
||||
}
|
||||
}
|
||||
} // namespace kernel
|
||||
|
|
|
@ -24,7 +24,8 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
constexpr auto kInitDataSetQueue = "InitData";
|
||||
constexpr auto kInitDataSetQueue = "InitDataSetQueue";
|
||||
constexpr auto kInitData = "InitData";
|
||||
constexpr auto kGetNext = "GetNext";
|
||||
constexpr auto kDropoutGenMask = "DropoutGenMask";
|
||||
constexpr auto kPrint = "Print";
|
||||
|
|
|
@ -417,6 +417,8 @@ void SetKernelBuildInfo(const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBu
|
|||
|
||||
if (imply_type == kAKG) {
|
||||
builder->SetKernelType(AUTO_DIFF_KERNEL);
|
||||
} else if (imply_type == kAICPU) {
|
||||
builder->SetKernelType(AICPU_KERNEL);
|
||||
} else {
|
||||
builder->SetKernelType(TBE_KERNEL);
|
||||
}
|
||||
|
@ -471,6 +473,13 @@ bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpIn
|
|||
return false;
|
||||
}
|
||||
|
||||
kernel_info_list->push_back(builder->Build());
|
||||
}
|
||||
} else {
|
||||
if (processor == AICPU) {
|
||||
auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||
MS_EXCEPTION_IF_NULL(builder);
|
||||
SetKernelBuildInfo(builder, processor, op_info_ptr);
|
||||
kernel_info_list->push_back(builder->Build());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -24,7 +24,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
enum OpImplyType { kAKG = 0, kTBE };
|
||||
enum OpImplyType { kAKG = 0, kTBE = 1, kAICPU };
|
||||
enum OpIOType { kInput = 0, kOutput };
|
||||
|
||||
class OpAttr {
|
||||
|
|
|
@ -39,6 +39,7 @@ constexpr auto kDtypeFormat = "dtype_format";
|
|||
constexpr auto kAttr = "attr";
|
||||
constexpr auto kIputs = "inputs";
|
||||
constexpr auto kOutputs = "outputs";
|
||||
constexpr auto kAiCPU = "AiCPU";
|
||||
constexpr auto kTbe = "TBE";
|
||||
constexpr auto kAkg = "akg";
|
||||
constexpr auto kAutodiff = "AutoDiff";
|
||||
|
@ -60,6 +61,8 @@ std::string ImplTypeToStr(OpImplyType impl_type) {
|
|||
return kTbe;
|
||||
case kAKG:
|
||||
return kAkg;
|
||||
case kAICPU:
|
||||
return kAiCPU;
|
||||
default:
|
||||
return "unknow";
|
||||
}
|
||||
|
@ -76,6 +79,9 @@ bool OpLib::RegOp(const std::string& json_string, const std::string& impl_path)
|
|||
} else if (imply_type_string == kAutodiff) {
|
||||
OpImplyType imply_type = kAKG;
|
||||
ret = DecodeOpInfo(op_json, imply_type, impl_path);
|
||||
} else if (imply_type_string == kAiCPU) {
|
||||
OpImplyType imply_type = kAICPU;
|
||||
ret = DecodeOpInfo(op_json, imply_type, impl_path);
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "Not support imply_type";
|
||||
}
|
||||
|
@ -154,7 +160,9 @@ bool OpLib::DecodeAttr(const nlohmann::json& obj, const OpImplyType imply_type,
|
|||
std::shared_ptr<OpAttr> op_attr = std::make_shared<OpAttr>();
|
||||
MS_EXCEPTION_IF_NULL(op_attr);
|
||||
op_attr->set_name(obj.at(kName));
|
||||
op_attr->set_param_type(obj.at(kParamType));
|
||||
if (imply_type != kAICPU) {
|
||||
op_attr->set_param_type(obj.at(kParamType));
|
||||
}
|
||||
op_attr->set_type(obj.at(kType));
|
||||
if (imply_type == kTBE) {
|
||||
op_attr->set_value(obj.at(kValue));
|
||||
|
@ -242,9 +250,10 @@ std::shared_ptr<OpInfo> OpLib::FindOp(const std::string& op_name, OpImplyType im
|
|||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
bool is_gpu = (context->device_target() == kGPUDevice);
|
||||
if ((is_gpu && imply_type == kTBE) || (!is_gpu && imply_type != kTBE)) {
|
||||
MS_LOG(DEBUG) << "FindOp failed: opname:" << op_name << "imply_type:" << ImplTypeToStr(imply_type)
|
||||
<< "current op num:" << op_info_.size();
|
||||
if ((is_gpu && (imply_type == kTBE || imply_type == kAICPU)) ||
|
||||
(!is_gpu && (imply_type != kTBE && imply_type != kAICPU))) {
|
||||
MS_LOG(ERROR) << "FindOp failed: opname:" << op_name << ", imply_type:" << ImplTypeToStr(imply_type)
|
||||
<< ", current op num:" << op_info_.size();
|
||||
return nullptr;
|
||||
}
|
||||
for (const auto& op_info : op_info_) {
|
||||
|
@ -253,8 +262,8 @@ std::shared_ptr<OpInfo> OpLib::FindOp(const std::string& op_name, OpImplyType im
|
|||
return op_info;
|
||||
}
|
||||
}
|
||||
MS_LOG(DEBUG) << "FindOp failed: opname:" << op_name << "imply_type:" << ImplTypeToStr(imply_type)
|
||||
<< "current op num:" << op_info_.size();
|
||||
MS_LOG(DEBUG) << "FindOp failed: opname:" << op_name << ", imply_type:" << ImplTypeToStr(imply_type)
|
||||
<< ", current op num:" << op_info_.size();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,7 @@ Note:
|
|||
|
||||
from .primitive import Primitive, PrimitiveWithInfer, prim_attr_register
|
||||
from .vm_impl_registry import get_vm_impl_fn, vm_impl_registry
|
||||
from .op_info_register import op_info_register, TBERegOp, DataType
|
||||
from .op_info_register import op_info_register, AiCPURegOp, TBERegOp, DataType
|
||||
from .primitive import constexpr
|
||||
from .._c_expression import signature_rw, signature_kind
|
||||
|
||||
|
@ -40,6 +40,6 @@ __primitive__ = [
|
|||
]
|
||||
|
||||
__all__ = ["get_vm_impl_fn", "vm_impl_registry",
|
||||
"op_info_register", "TBERegOp", "DataType",
|
||||
"op_info_register", "AiCPURegOp", "TBERegOp", "DataType",
|
||||
"constexpr"]
|
||||
__all__.extend(__primitive__)
|
||||
|
|
|
@ -16,5 +16,6 @@
|
|||
|
||||
from .akg.gpu import *
|
||||
from .tbe import *
|
||||
from .aicpu import *
|
||||
|
||||
__all__ = []
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""aicpu ops"""
|
||||
from .init_data_set_queue import _init_data_set_queue_aicpu
|
||||
from .dropout_genmask import _dropout_genmask_aicpu
|
||||
from .get_next import _get_next_aicpu
|
||||
from .print_tensor import _print_aicpu
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""InitDataSetQueue op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||
|
||||
dropout_genmask_op_info = AiCPURegOp("DropoutGenMask") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.input(0, "x1", "required") \
|
||||
.input(1, "x2", "required") \
|
||||
.output(0, "y", "required") \
|
||||
.attr("Seed0", "int") \
|
||||
.attr("Seed1", "int") \
|
||||
.dtype_format(DataType.I32_NCHW, DataType.F16_NCHW, DataType.U8_NCHW) \
|
||||
.get_op_info()
|
||||
|
||||
@op_info_register(dropout_genmask_op_info)
|
||||
def _dropout_genmask_aicpu():
|
||||
"""Dropout AiCPU register"""
|
||||
return
|
|
@ -0,0 +1,39 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""InitDataSetQueue op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||
|
||||
get_next_op_info = AiCPURegOp("GetNext") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.output(0, "y", "dynamic") \
|
||||
.attr("shared_name", "str") \
|
||||
.dtype_format(DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.I8_Default) \
|
||||
.dtype_format(DataType.I16_Default) \
|
||||
.dtype_format(DataType.I32_Default) \
|
||||
.dtype_format(DataType.I64_Default) \
|
||||
.dtype_format(DataType.F16_Default) \
|
||||
.dtype_format(DataType.U8_Default) \
|
||||
.dtype_format(DataType.U16_Default) \
|
||||
.dtype_format(DataType.U32_Default) \
|
||||
.dtype_format(DataType.U64_Default) \
|
||||
.dtype_format(DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
@op_info_register(get_next_op_info)
|
||||
def _get_next_aicpu():
|
||||
"""GetNext AiCPU register"""
|
||||
return
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""InitDataSetQueue op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp
|
||||
|
||||
init_data_set_queue_op_info = AiCPURegOp("InitData") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.attr("queue_name", "str") \
|
||||
.get_op_info()
|
||||
|
||||
@op_info_register(init_data_set_queue_op_info)
|
||||
def _init_data_set_queue_aicpu():
|
||||
"""InitDataSetQueue AiCPU register"""
|
||||
return
|
|
@ -0,0 +1,39 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""InitDataSetQueue op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||
|
||||
print_op_info = AiCPURegOp("Print") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.input(0, "x", "dynamic") \
|
||||
.output(0, "y", "required") \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
|
||||
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
|
||||
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
@op_info_register(print_op_info)
|
||||
def _print_aicpu():
|
||||
"""Print AiCPU register"""
|
||||
return
|
|
@ -78,14 +78,15 @@ class RegOp():
|
|||
self.inputs = []
|
||||
self.outputs = []
|
||||
self.attr_ = []
|
||||
self.fusion_type_ = ''
|
||||
self.dtype_format_ = []
|
||||
|
||||
def is_string(self, value):
|
||||
def _is_string(self, value):
|
||||
"""
|
||||
Check if the value is a str type.
|
||||
|
||||
Args:
|
||||
value: Parameter to to check.
|
||||
value: Parameter to be checked.
|
||||
|
||||
Raises:
|
||||
TypeError: If the type of value is not a str.
|
||||
|
@ -93,12 +94,12 @@ class RegOp():
|
|||
if not isinstance(value, str):
|
||||
raise TypeError("%s value must be str" % str(value))
|
||||
|
||||
def is_int(self, value):
|
||||
def _is_int(self, value):
|
||||
"""
|
||||
Check if the value is a int.
|
||||
|
||||
Args:
|
||||
value: Parameter to to check.
|
||||
value: Parameter to be checked.
|
||||
|
||||
Raises:
|
||||
TypeError: If the type of value is not a int.
|
||||
|
@ -106,12 +107,12 @@ class RegOp():
|
|||
if not isinstance(value, int):
|
||||
raise TypeError("%s value must be int" % str(value))
|
||||
|
||||
def is_bool(self, value):
|
||||
def _is_bool(self, value):
|
||||
"""
|
||||
Check if the value is a bool.
|
||||
|
||||
Args:
|
||||
value: Parameter to to check.
|
||||
value: Parameter to be checked.
|
||||
|
||||
Raises:
|
||||
TypeError: If the type of value is not a bool.
|
||||
|
@ -119,6 +120,51 @@ class RegOp():
|
|||
if not isinstance(value, bool):
|
||||
raise TypeError("%s value must be bool" % str(value))
|
||||
|
||||
def _check_param(self, param_list, key_list, fn_list, kwargs):
|
||||
"""
|
||||
Check if the parameter type is correct.
|
||||
|
||||
Args:
|
||||
param_list (list): Parameter list to be checked.
|
||||
key_list (list): The keys of output dict.
|
||||
fn_list (list): Function used for parameter checking. If the function list has only one element,
|
||||
all parameters will use the same function.
|
||||
kwargs (dict): Other parameter information.
|
||||
|
||||
Raises:
|
||||
TypeError: If the type of value is not list.
|
||||
ValueError: If the size of param list is not equal to the size of key list, or
|
||||
the size of param list is not equal to the size of funtion list.
|
||||
"""
|
||||
for i in [param_list, key_list, fn_list]:
|
||||
if not isinstance(i, list):
|
||||
raise TypeError("%s value must be list type" % str(i))
|
||||
if len(param_list) != len(key_list) or (len(fn_list) != 1 and len(param_list) != len(fn_list)):
|
||||
raise ValueError("param_list size {}, key_list size {}, must be equal.And fn_list size {}.".
|
||||
format(len(param_list), len(key_list), len(fn_list)))
|
||||
out_dict = {}
|
||||
for idx, element in enumerate(param_list):
|
||||
if element is not None:
|
||||
if len(fn_list) == 1:
|
||||
fn_list[0](element)
|
||||
else:
|
||||
fn_list[idx](element)
|
||||
out_dict[key_list[idx]] = element
|
||||
if kwargs:
|
||||
out_dict = dict(out_dict, kwargs)
|
||||
return out_dict
|
||||
|
||||
def fusion_type(self, fusion_type):
|
||||
"""
|
||||
Register fusion type.
|
||||
|
||||
Args:
|
||||
fusion_type (str): Value of fusion type.
|
||||
"""
|
||||
self._is_string(fusion_type)
|
||||
self.fusion_type_ = fusion_type
|
||||
return self
|
||||
|
||||
def dtype_format(self, *args):
|
||||
"""
|
||||
Register dtype and format.
|
||||
|
@ -136,8 +182,8 @@ class RegOp():
|
|||
for arg in args:
|
||||
if not isinstance(arg, tuple) or len(arg) != 2:
|
||||
raise ValueError("dtype and format value must be tuple of two elements")
|
||||
self.is_string(arg[0])
|
||||
self.is_string(arg[1])
|
||||
self._is_string(arg[0])
|
||||
self._is_string(arg[1])
|
||||
dtype_format.append(arg)
|
||||
self.dtype_format_.append(tuple(dtype_format))
|
||||
return self
|
||||
|
@ -159,13 +205,71 @@ class RegOp():
|
|||
return op_info
|
||||
|
||||
|
||||
class AiCPURegOp(RegOp):
|
||||
"""Class for AiCPU op info register"""
|
||||
|
||||
def __init__(self, op_name):
|
||||
super(AiCPURegOp, self).__init__(op_name)
|
||||
self.imply_type = "AiCPU"
|
||||
|
||||
def input(self, index=None, name=None, param_type=None, **kwargs):
|
||||
"""
|
||||
Register AiCPU op input information.
|
||||
|
||||
Args:
|
||||
index (int): Order of the input. Default: None.
|
||||
name (str): Name of the input. Default: None.
|
||||
param_type (str): Param type of the input. Default: None.
|
||||
kwargs (dict): Other information for the input.
|
||||
"""
|
||||
param_list = [index, name, param_type]
|
||||
key_list = ["index", "name", "param_type"]
|
||||
fn_list = [self._is_int, self._is_string, self._is_string]
|
||||
input_dict = self._check_param(param_list, key_list, fn_list, kwargs)
|
||||
self.inputs.append(input_dict)
|
||||
return self
|
||||
|
||||
def output(self, index=None, name=None, param_type=None, **kwargs):
|
||||
"""
|
||||
Register AiCPU op output information.
|
||||
|
||||
Args:
|
||||
index (int): Order of the output. Default: None.
|
||||
name (str): Name of the output. Default: None.
|
||||
param_type (str): Param type of the output. Default: None.
|
||||
kwargs (dict): Other information for the output.
|
||||
"""
|
||||
param_list = [index, name, param_type]
|
||||
key_list = ["index", "name", "param_type"]
|
||||
fn_list = [self._is_int, self._is_string, self._is_string]
|
||||
output_dict = self._check_param(param_list, key_list, fn_list, kwargs)
|
||||
self.outputs.append(output_dict)
|
||||
return self
|
||||
|
||||
def attr(self, name=None, value_type=None, value=None, **kwargs):
|
||||
"""
|
||||
Register AiCPU op attribute information.
|
||||
|
||||
Args:
|
||||
name (str): Name of the attribute. Default: None.
|
||||
value_type (str): Value type of the attribute. Default: None.
|
||||
value (str): Value type of the attribute. Default: None.
|
||||
kwargs (dict): Other information for the attribute.
|
||||
"""
|
||||
param_list = [name, value_type, value]
|
||||
key_list = ["name", "type", "value"]
|
||||
fn_list = [self._is_string]
|
||||
attr_dict = self._check_param(param_list, key_list, fn_list, kwargs)
|
||||
self.attr_.append(attr_dict)
|
||||
return self
|
||||
|
||||
|
||||
class TBERegOp(RegOp):
|
||||
"""Class for TBE op info register."""
|
||||
|
||||
def __init__(self, op_name=""):
|
||||
super(TBERegOp, self).__init__(op_name)
|
||||
self.imply_type = "TBE"
|
||||
self.fusion_type_ = ''
|
||||
self.async_flag_ = False
|
||||
self.binfile_name_ = ''
|
||||
self.compute_cost_ = 10
|
||||
|
@ -175,17 +279,6 @@ class TBERegOp(RegOp):
|
|||
self.dynamic_format_ = False
|
||||
self.op_pattern_ = ""
|
||||
|
||||
def fusion_type(self, fusion_type):
|
||||
"""
|
||||
Register fusion type.
|
||||
|
||||
Args:
|
||||
fusion_type (str): Value of fusion type.
|
||||
"""
|
||||
self.is_string(fusion_type)
|
||||
self.fusion_type_ = fusion_type
|
||||
return self
|
||||
|
||||
def async_flag(self, async_flag):
|
||||
"""
|
||||
Register async flag.
|
||||
|
@ -193,7 +286,7 @@ class TBERegOp(RegOp):
|
|||
Args:
|
||||
async_flag (bool): Value of async flag.
|
||||
"""
|
||||
self.is_bool(async_flag)
|
||||
self._is_bool(async_flag)
|
||||
self.async_flag_ = async_flag
|
||||
return self
|
||||
|
||||
|
@ -204,7 +297,7 @@ class TBERegOp(RegOp):
|
|||
Args:
|
||||
binfile_name (str): Name of op binfile.
|
||||
"""
|
||||
self.is_string(binfile_name)
|
||||
self._is_string(binfile_name)
|
||||
self.binfile_name_ = binfile_name
|
||||
return self
|
||||
|
||||
|
@ -215,7 +308,7 @@ class TBERegOp(RegOp):
|
|||
Args:
|
||||
compute_cost (int): Value of compute cost.
|
||||
"""
|
||||
self.is_int(compute_cost)
|
||||
self._is_int(compute_cost)
|
||||
self.compute_cost_ = compute_cost
|
||||
return self
|
||||
|
||||
|
@ -226,7 +319,7 @@ class TBERegOp(RegOp):
|
|||
Args:
|
||||
kernel_name (str): Name of op kernel.
|
||||
"""
|
||||
self.is_string(kernel_name)
|
||||
self._is_string(kernel_name)
|
||||
self.kernel_name_ = kernel_name
|
||||
return self
|
||||
|
||||
|
@ -237,7 +330,7 @@ class TBERegOp(RegOp):
|
|||
Args:
|
||||
partial_flag (bool): Value of partial flag.
|
||||
"""
|
||||
self.is_bool(partial_flag)
|
||||
self._is_bool(partial_flag)
|
||||
self.partial_flag_ = partial_flag
|
||||
return self
|
||||
|
||||
|
@ -248,7 +341,7 @@ class TBERegOp(RegOp):
|
|||
Args:
|
||||
reshape_type (str): Value of reshape type.
|
||||
"""
|
||||
self.is_string(reshape_type)
|
||||
self._is_string(reshape_type)
|
||||
self.reshape_type_ = reshape_type
|
||||
return self
|
||||
|
||||
|
@ -259,56 +352,43 @@ class TBERegOp(RegOp):
|
|||
Args:
|
||||
reshape_type (bool): Value of dynamic format.
|
||||
"""
|
||||
self.is_bool(dynamic_format)
|
||||
self._is_bool(dynamic_format)
|
||||
self.dynamic_format_ = dynamic_format
|
||||
return self
|
||||
|
||||
def op_pattern(self, pattern=None):
|
||||
"""
|
||||
Register op pattern information.
|
||||
Register TBE op pattern information.
|
||||
|
||||
Args:
|
||||
pattern (str): Value of op pattern.
|
||||
"""
|
||||
if pattern is not None and self.istring(pattern):
|
||||
if pattern is not None and self._is_string(pattern):
|
||||
self.op_pattern_ = pattern
|
||||
return self
|
||||
|
||||
def attr(self, name=None, param_type=None, value_type=None, value=None, default_value=None, **kwargs):
|
||||
"""
|
||||
Register op attribute information.
|
||||
Register TBE op attribute information.
|
||||
|
||||
Args:
|
||||
name (str): Name of the attribute. Default: None.
|
||||
param_type (str): Param type of the attribute. Default: None.
|
||||
type (str): Type of the attribute. Default: None.
|
||||
value_type (str): Type of the attribute. Default: None.
|
||||
value (str): Value of the attribute. Default: None.
|
||||
default_value (str): Default value of attribute. Default: None.
|
||||
kwargs (dict): Other information for the attribute.
|
||||
"""
|
||||
param_list = [name, param_type, value_type, value, default_value]
|
||||
attr_dict = {}
|
||||
for index, element in enumerate(param_list):
|
||||
if element is not None:
|
||||
self.is_string(element)
|
||||
if index == 0:
|
||||
attr_dict["name"] = element
|
||||
elif index == 1:
|
||||
attr_dict["param_type"] = element
|
||||
elif index == 2:
|
||||
attr_dict["type"] = element
|
||||
elif index == 3:
|
||||
attr_dict["value"] = element
|
||||
elif index == 4:
|
||||
attr_dict["default_value"] = element
|
||||
if kwargs:
|
||||
attr_dict = dict(attr_dict, **kwargs)
|
||||
key_list = ["name", "param_type", "type", "value", "default_value"]
|
||||
fn_list = [self._is_string]
|
||||
attr_dict = self._check_param(param_list, key_list, fn_list, kwargs)
|
||||
self.attr_.append(attr_dict)
|
||||
return self
|
||||
|
||||
def input(self, index=None, name=None, need_compile=None, param_type=None, shape=None, **kwargs):
|
||||
"""
|
||||
Register op input information.
|
||||
Register TBE op input information.
|
||||
|
||||
Args:
|
||||
index (int): Order of the input. Default: None.
|
||||
|
@ -319,32 +399,15 @@ class TBERegOp(RegOp):
|
|||
kwargs (dict): Other information for the input.
|
||||
"""
|
||||
param_list = [index, name, need_compile, param_type, shape]
|
||||
input_dict = {}
|
||||
for idx, element in enumerate(param_list):
|
||||
if element is not None:
|
||||
if idx == 0:
|
||||
self.is_int(element)
|
||||
input_dict["index"] = element
|
||||
elif idx == 1:
|
||||
self.is_string(element)
|
||||
input_dict["name"] = element
|
||||
elif idx == 2:
|
||||
self.is_bool(element)
|
||||
input_dict["need_compile"] = element
|
||||
elif idx == 3:
|
||||
self.is_string(element)
|
||||
input_dict["param_type"] = element
|
||||
elif idx == 4:
|
||||
self.is_string(element)
|
||||
input_dict["shape"] = element
|
||||
if kwargs:
|
||||
input_dict = dict(input_dict, **kwargs)
|
||||
key_list = ["index", "name", "need_compile", "param_type", "shape"]
|
||||
fn_list = [self._is_int, self._is_string, self._is_bool, self._is_string, self._is_string]
|
||||
input_dict = self._check_param(param_list, key_list, fn_list, kwargs)
|
||||
self.inputs.append(input_dict)
|
||||
return self
|
||||
|
||||
def output(self, index=None, name=None, need_compile=None, param_type=None, shape=None, **kwargs):
|
||||
"""
|
||||
Register op output information.
|
||||
Register TBE op output information.
|
||||
|
||||
Args:
|
||||
index (int): Order of the output. Default: None.
|
||||
|
@ -355,29 +418,13 @@ class TBERegOp(RegOp):
|
|||
kwargs (dict): Other information for the output.
|
||||
"""
|
||||
param_list = [index, name, need_compile, param_type, shape]
|
||||
output_dict = {}
|
||||
for idx, element in enumerate(param_list):
|
||||
if element is not None:
|
||||
if idx == 0:
|
||||
self.is_int(element)
|
||||
output_dict["index"] = element
|
||||
elif idx == 1:
|
||||
self.is_string(element)
|
||||
output_dict["name"] = element
|
||||
elif idx == 2:
|
||||
self.is_bool(element)
|
||||
output_dict["need_compile"] = element
|
||||
elif idx == 3:
|
||||
self.is_string(element)
|
||||
output_dict["param_type"] = element
|
||||
elif idx == 4:
|
||||
self.is_string(element)
|
||||
output_dict["shape"] = element
|
||||
if kwargs:
|
||||
output_dict = dict(output_dict, **kwargs)
|
||||
key_list = ["index", "name", "need_compile", "param_type", "shape"]
|
||||
fn_list = [self._is_int, self._is_string, self._is_bool, self._is_string, self._is_string]
|
||||
output_dict = self._check_param(param_list, key_list, fn_list, kwargs)
|
||||
self.outputs.append(output_dict)
|
||||
return self
|
||||
|
||||
|
||||
class DataType():
|
||||
"""
|
||||
Various combinations of dtype and formatself.
|
||||
|
|
Loading…
Reference in New Issue