forked from OSSInnovation/mindspore
!395 Fix oplib coddex
Merge pull request !395 from zjun/fix_oplib_codexx
This commit is contained in:
commit
3c317c71f4
|
@ -39,8 +39,6 @@ namespace mindspore {
|
|||
namespace kernel {
|
||||
using FNodeAttrHandle = std::function<void(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto)>;
|
||||
|
||||
const std::vector<std::string> local_framework_op_vec = {kInitData, kGetNext, kDropoutGenMask, kPrint};
|
||||
|
||||
bool SetIOIputSize(const std::shared_ptr<AnfNode> &anf_node, const size_t &input_num,
|
||||
std::vector<size_t> *input_size_list) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
|
@ -298,19 +296,12 @@ KernelModPtr AicpuOpBuild(const std::shared_ptr<AnfNode> &anf_node) {
|
|||
MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
|
||||
kernel_mod_ptr->SetAnfNode(anf_node);
|
||||
kernel_mod_ptr->SetNodeName(op_name);
|
||||
auto iter = std::find(local_framework_op_vec.begin(), local_framework_op_vec.end(), op_name);
|
||||
if (iter != local_framework_op_vec.end()) {
|
||||
if (!CreateNodeDefBytes(anf_node, kernel_mod_ptr)) {
|
||||
MS_LOG(EXCEPTION) << "Create nodeDefBytes faild!";
|
||||
}
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Aicpu don't support node [" << op_name << "]";
|
||||
if (!CreateNodeDefBytes(anf_node, kernel_mod_ptr)) {
|
||||
MS_LOG(EXCEPTION) << "Create nodeDefBytes faild!";
|
||||
}
|
||||
|
||||
if (!SetIOSize(anf_node, kernel_mod_ptr)) {
|
||||
MS_LOG(EXCEPTION) << "Set input output size list failed.";
|
||||
}
|
||||
|
||||
return kernel_mod_ptr;
|
||||
}
|
||||
} // namespace kernel
|
||||
|
|
|
@ -94,6 +94,20 @@ bool OpLib::RegOp(const std::string& json_string, const std::string& impl_path)
|
|||
return ret;
|
||||
}
|
||||
|
||||
void OpLib::DecodeTBESpecificInfo(const nlohmann::json& obj, const std::shared_ptr<OpInfo>& op_info) {
|
||||
op_info->set_async_flag(obj.at(kAsyncFlag));
|
||||
op_info->set_binfile_name(obj.at(kBinfileName));
|
||||
op_info->set_compute_cost(obj.at(kComputeCost));
|
||||
op_info->set_kernel_name(obj.at(kKernelName));
|
||||
op_info->set_partial_flag(obj.at(kPartialFlag));
|
||||
if (obj.find(kOpPattern) != obj.end()) {
|
||||
op_info->set_op_pattern(obj.at(kOpPattern));
|
||||
}
|
||||
if (obj.find(kDynamicFormat) != obj.end()) {
|
||||
op_info->set_dynamic_format(obj.at(kDynamicFormat));
|
||||
}
|
||||
}
|
||||
|
||||
bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpImplyType imply_type,
|
||||
const std::string& impl_path) {
|
||||
std::shared_ptr<OpInfo> op_info = std::make_shared<OpInfo>();
|
||||
|
@ -103,17 +117,7 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI
|
|||
op_info->set_imply_type(imply_type);
|
||||
op_info->set_fusion_type(obj.at(kFusionType));
|
||||
if (imply_type == kTBE) {
|
||||
op_info->set_async_flag(obj.at(kAsyncFlag));
|
||||
op_info->set_binfile_name(obj.at(kBinfileName));
|
||||
op_info->set_compute_cost(obj.at(kComputeCost));
|
||||
op_info->set_kernel_name(obj.at(kKernelName));
|
||||
op_info->set_partial_flag(obj.at(kPartialFlag));
|
||||
if (obj.find(kOpPattern) != obj.end()) {
|
||||
op_info->set_op_pattern(obj.at(kOpPattern));
|
||||
}
|
||||
if (obj.find(kDynamicFormat) != obj.end()) {
|
||||
op_info->set_dynamic_format(obj.at(kDynamicFormat));
|
||||
}
|
||||
DecodeTBESpecificInfo(obj, op_info);
|
||||
}
|
||||
auto attrs = obj.at(kAttr);
|
||||
for (const auto& attr : attrs) {
|
||||
|
|
|
@ -40,6 +40,7 @@ class OpLib {
|
|||
const std::shared_ptr<OpInfo>& op_info);
|
||||
static bool DecodeDtypeFormat(const nlohmann::json& dtype_format, const std::shared_ptr<OpIOInfo>& op_io,
|
||||
size_t index);
|
||||
static void DecodeTBESpecificInfo(const nlohmann::json& obj, const std::shared_ptr<OpInfo>& op_info);
|
||||
static bool DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply_type, const OpIOType io_type,
|
||||
const std::shared_ptr<OpInfo>& op_info, const nlohmann::json& dtype_format);
|
||||
static bool GetRefInfo(const std::shared_ptr<OpInfo>& op_info);
|
||||
|
|
|
@ -57,7 +57,7 @@ def op_info_register(op_info):
|
|||
return register_decorator
|
||||
|
||||
|
||||
class RegOp():
|
||||
class RegOp:
|
||||
"""
|
||||
Base class for op info register.
|
||||
|
||||
|
@ -483,9 +483,9 @@ class TBERegOp(RegOp):
|
|||
return self
|
||||
|
||||
|
||||
class DataType():
|
||||
class DataType:
|
||||
"""
|
||||
Various combinations of dtype and formatself.
|
||||
Various combinations of dtype and format.
|
||||
|
||||
The current list below maybe not completed. If necessary, please add it.
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue