Fix oplib codexx

This commit is contained in:
zjun 2020-04-16 17:27:33 +08:00
parent 6e183fcc0f
commit 6cc5e87126
4 changed files with 21 additions and 25 deletions

View File

@ -39,8 +39,6 @@ namespace mindspore {
namespace kernel {
using FNodeAttrHandle = std::function<void(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto)>;
const std::vector<std::string> local_framework_op_vec = {kInitData, kGetNext, kDropoutGenMask, kPrint};
bool SetIOIputSize(const std::shared_ptr<AnfNode> &anf_node, const size_t &input_num,
std::vector<size_t> *input_size_list) {
MS_EXCEPTION_IF_NULL(anf_node);
@ -298,19 +296,12 @@ KernelModPtr AicpuOpBuild(const std::shared_ptr<AnfNode> &anf_node) {
MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
kernel_mod_ptr->SetAnfNode(anf_node);
kernel_mod_ptr->SetNodeName(op_name);
auto iter = std::find(local_framework_op_vec.begin(), local_framework_op_vec.end(), op_name);
if (iter != local_framework_op_vec.end()) {
if (!CreateNodeDefBytes(anf_node, kernel_mod_ptr)) {
MS_LOG(EXCEPTION) << "Create nodeDefBytes faild!";
}
} else {
MS_LOG(EXCEPTION) << "Aicpu don't support node [" << op_name << "]";
if (!CreateNodeDefBytes(anf_node, kernel_mod_ptr)) {
MS_LOG(EXCEPTION) << "Create nodeDefBytes faild!";
}
if (!SetIOSize(anf_node, kernel_mod_ptr)) {
MS_LOG(EXCEPTION) << "Set input output size list failed.";
}
return kernel_mod_ptr;
}
} // namespace kernel

View File

@ -94,6 +94,20 @@ bool OpLib::RegOp(const std::string& json_string, const std::string& impl_path)
return ret;
}
void OpLib::DecodeTBESpecificInfo(const nlohmann::json& obj, const std::shared_ptr<OpInfo>& op_info) {
op_info->set_async_flag(obj.at(kAsyncFlag));
op_info->set_binfile_name(obj.at(kBinfileName));
op_info->set_compute_cost(obj.at(kComputeCost));
op_info->set_kernel_name(obj.at(kKernelName));
op_info->set_partial_flag(obj.at(kPartialFlag));
if (obj.find(kOpPattern) != obj.end()) {
op_info->set_op_pattern(obj.at(kOpPattern));
}
if (obj.find(kDynamicFormat) != obj.end()) {
op_info->set_dynamic_format(obj.at(kDynamicFormat));
}
}
bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpImplyType imply_type,
const std::string& impl_path) {
std::shared_ptr<OpInfo> op_info = std::make_shared<OpInfo>();
@ -103,17 +117,7 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI
op_info->set_imply_type(imply_type);
op_info->set_fusion_type(obj.at(kFusionType));
if (imply_type == kTBE) {
op_info->set_async_flag(obj.at(kAsyncFlag));
op_info->set_binfile_name(obj.at(kBinfileName));
op_info->set_compute_cost(obj.at(kComputeCost));
op_info->set_kernel_name(obj.at(kKernelName));
op_info->set_partial_flag(obj.at(kPartialFlag));
if (obj.find(kOpPattern) != obj.end()) {
op_info->set_op_pattern(obj.at(kOpPattern));
}
if (obj.find(kDynamicFormat) != obj.end()) {
op_info->set_dynamic_format(obj.at(kDynamicFormat));
}
DecodeTBESpecificInfo(obj, op_info);
}
auto attrs = obj.at(kAttr);
for (const auto& attr : attrs) {

View File

@ -40,6 +40,7 @@ class OpLib {
const std::shared_ptr<OpInfo>& op_info);
static bool DecodeDtypeFormat(const nlohmann::json& dtype_format, const std::shared_ptr<OpIOInfo>& op_io,
size_t index);
static void DecodeTBESpecificInfo(const nlohmann::json& obj, const std::shared_ptr<OpInfo>& op_info);
static bool DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply_type, const OpIOType io_type,
const std::shared_ptr<OpInfo>& op_info, const nlohmann::json& dtype_format);
static bool GetRefInfo(const std::shared_ptr<OpInfo>& op_info);

View File

@ -57,7 +57,7 @@ def op_info_register(op_info):
return register_decorator
class RegOp():
class RegOp:
"""
Base class for op info register.
@ -483,9 +483,9 @@ class TBERegOp(RegOp):
return self
class DataType():
class DataType:
"""
Various combinations of dtype and formatself.
Various combinations of dtype and format.
The current list below maybe not completed. If necessary, please add it.
"""