!18136 clean code
Merge pull request !18136 from hwjiaorui/new-clean-code
This commit is contained in:
commit
558760cc20
|
@ -101,6 +101,7 @@ constexpr auto kSOC_VERSION = "SOC_VERSION";
|
||||||
constexpr auto kJIsDynamicShape = "is_dynamic_shape";
|
constexpr auto kJIsDynamicShape = "is_dynamic_shape";
|
||||||
constexpr auto kJDynamicIndex = "dynamic_index";
|
constexpr auto kJDynamicIndex = "dynamic_index";
|
||||||
constexpr auto kJSocInfo = "SocInfo";
|
constexpr auto kJSocInfo = "SocInfo";
|
||||||
|
constexpr auto kNCHWShapeSize = 4;
|
||||||
|
|
||||||
const auto kPyPath = "/usr/local/Ascend/opp/op_impl/built-in/ai_core/tbe";
|
const auto kPyPath = "/usr/local/Ascend/opp/op_impl/built-in/ai_core/tbe";
|
||||||
|
|
||||||
|
@ -144,7 +145,7 @@ bool TbeKernelJsonCreator::GenTbeSingleKernelJson(const std::shared_ptr<mindspor
|
||||||
op_info_json[kJOutputs] = outputs_json;
|
op_info_json[kJOutputs] = outputs_json;
|
||||||
// generate attrs json
|
// generate attrs json
|
||||||
nlohmann::json attrs_json;
|
nlohmann::json attrs_json;
|
||||||
(void)GenTbeAttrJson(anf_node, op_info_ptr, &attrs_json);
|
GenTbeAttrJson(anf_node, op_info_ptr, &attrs_json);
|
||||||
op_info_json[kJAttrs] = attrs_json;
|
op_info_json[kJAttrs] = attrs_json;
|
||||||
auto soc_version = TbeKernelJsonCreator::GetSocVersion();
|
auto soc_version = TbeKernelJsonCreator::GetSocVersion();
|
||||||
op_info_json[kJSocVersion] = soc_version;
|
op_info_json[kJSocVersion] = soc_version;
|
||||||
|
@ -959,7 +960,7 @@ void TbeKernelBuild::GenDescJson(const std::shared_ptr<mindspore::AnfNode> &anf_
|
||||||
// !! Note: format: only data node's output use it
|
// !! Note: format: only data node's output use it
|
||||||
auto format = AnfAlgo::GetOutputFormat(anf_node, node_out_idx);
|
auto format = AnfAlgo::GetOutputFormat(anf_node, node_out_idx);
|
||||||
if (format == kOpFormat_DEFAULT) {
|
if (format == kOpFormat_DEFAULT) {
|
||||||
format = ori_shape.size() == 4 ? kOpFormat_NCHW : kOpFormat_ND;
|
format = ori_shape.size() == kNCHWShapeSize ? kOpFormat_NCHW : kOpFormat_ND;
|
||||||
} else if (format == kOpFormat_FRAC_Z) {
|
} else if (format == kOpFormat_FRAC_Z) {
|
||||||
format = kOpFormat_FRACTAL_Z;
|
format = kOpFormat_FRACTAL_Z;
|
||||||
}
|
}
|
||||||
|
@ -1365,23 +1366,24 @@ bool TbeKernelBuild::CalOutputSize(const nlohmann::json &fusion_op_list,
|
||||||
size_t real_idx = kernel_idx.second;
|
size_t real_idx = kernel_idx.second;
|
||||||
auto full_name = real_node->fullname_with_scope();
|
auto full_name = real_node->fullname_with_scope();
|
||||||
for (const auto &op : fusion_op_list) {
|
for (const auto &op : fusion_op_list) {
|
||||||
if (op[kJName] == full_name) {
|
if (op[kJName] != full_name) {
|
||||||
auto op_output_desces = op[kJOutputDesc];
|
continue;
|
||||||
if (output_node != real_node) {
|
}
|
||||||
// tuple_get item
|
auto op_output_desces = op[kJOutputDesc];
|
||||||
auto output_desc = op_output_desces[real_idx];
|
if (output_node != real_node) {
|
||||||
|
// tuple_get item
|
||||||
|
auto output_desc = op_output_desces[real_idx];
|
||||||
|
if (output_desc[kJShape].empty()) {
|
||||||
|
MS_LOG(INFO) << "Fusion error: output_desc's shape is empty. real_index " << real_idx;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
output_size_list->push_back(GetIOSizeImpl(output_desc));
|
||||||
|
} else {
|
||||||
|
for (const auto &output_desc : op_output_desces) {
|
||||||
if (output_desc[kJShape].empty()) {
|
if (output_desc[kJShape].empty()) {
|
||||||
MS_LOG(INFO) << "Fusion error: output_desc's shape is empty. real_index " << real_idx;
|
continue;
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
output_size_list->push_back(GetIOSizeImpl(output_desc));
|
output_size_list->push_back(GetIOSizeImpl(output_desc));
|
||||||
} else {
|
|
||||||
for (const auto &output_desc : op_output_desces) {
|
|
||||||
if (output_desc[kJShape].empty()) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
output_size_list->push_back(GetIOSizeImpl(output_desc));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,6 +27,7 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
|
constexpr size_t INPUT2 = 2;
|
||||||
void BnupdateEltwiseFusionPass::MatchBnupdateDoubleOutputEltwise(const CNodePtr &cnode, const AnfNodePtr &eltwise_input,
|
void BnupdateEltwiseFusionPass::MatchBnupdateDoubleOutputEltwise(const CNodePtr &cnode, const AnfNodePtr &eltwise_input,
|
||||||
const session::KernelGraph &kernel_graph,
|
const session::KernelGraph &kernel_graph,
|
||||||
FusedNodeRecord *candidate_fusion) {
|
FusedNodeRecord *candidate_fusion) {
|
||||||
|
@ -48,7 +49,7 @@ void BnupdateEltwiseFusionPass::MatchBnupdateDoubleOutputEltwise(const CNodePtr
|
||||||
}
|
}
|
||||||
auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>();
|
auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>();
|
||||||
MS_EXCEPTION_IF_NULL(out_getitem_ptr);
|
MS_EXCEPTION_IF_NULL(out_getitem_ptr);
|
||||||
auto input2 = out_getitem_ptr->input(2);
|
auto input2 = out_getitem_ptr->input(INPUT2);
|
||||||
auto output_idx = GetValue<int64_t>(GetValueNode(input2));
|
auto output_idx = GetValue<int64_t>(GetValueNode(input2));
|
||||||
output_used_num[output_idx] = SizeToLong(manager->node_users()[out_getitem.first].size());
|
output_used_num[output_idx] = SizeToLong(manager->node_users()[out_getitem.first].size());
|
||||||
}
|
}
|
||||||
|
|
|
@ -76,7 +76,7 @@ const AnfNodePtr RemoveInternalOutput::Process(const FuncGraphPtr &func_graph, c
|
||||||
} else {
|
} else {
|
||||||
auto tuple_getitem = input_node->cast<CNodePtr>();
|
auto tuple_getitem = input_node->cast<CNodePtr>();
|
||||||
MS_EXCEPTION_IF_NULL(tuple_getitem);
|
MS_EXCEPTION_IF_NULL(tuple_getitem);
|
||||||
int64_t idx = SizeToLong(AnfAlgo::GetTupleGetItemOutIndex(tuple_getitem));
|
size_t idx = AnfAlgo::GetTupleGetItemOutIndex(tuple_getitem);
|
||||||
AnfNodePtr real_input_node = AnfAlgo::GetTupleGetItemRealInput(tuple_getitem);
|
AnfNodePtr real_input_node = AnfAlgo::GetTupleGetItemRealInput(tuple_getitem);
|
||||||
kernel_graph->ReplaceInternalOutput(node, real_input_node, 0, idx);
|
kernel_graph->ReplaceInternalOutput(node, real_input_node, 0, idx);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue