forked from OSSInnovation/mindspore
!2344 clean pclint warning of the function is out of 50 lines
Merge pull request !2344 from lianliguang/master
This commit is contained in:
commit
a58b1a1435
|
@ -37,14 +37,12 @@ void StridedReadConvStridedWriteFusionPass::MatchStridedReadConvStridedWrite(con
|
|||
MS_EXCEPTION_IF_NULL(manager);
|
||||
std::unordered_set<AnfNodePtr> record{cnode};
|
||||
auto write_input = cnode->input(1);
|
||||
|
||||
if (CheckEltWiseNode(manager.get(), write_input)) {
|
||||
(void)record.insert(write_input);
|
||||
auto input_cnode = write_input->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(input_cnode);
|
||||
write_input = input_cnode->input(1);
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(write_input);
|
||||
if (!write_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(write_input) ||
|
||||
fusion_id_allocator->HasFusionIdAttr(write_input)) {
|
||||
|
@ -63,7 +61,6 @@ void StridedReadConvStridedWriteFusionPass::MatchStridedReadConvStridedWrite(con
|
|||
fusion_id_allocator->HasFusionIdAttr(conv_input)) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (AnfAlgo::GetCNodeName(conv_input) == kStridedReadOpName) {
|
||||
(void)record.insert(conv_input);
|
||||
candidate_fusion->push_back(record);
|
||||
|
|
|
@ -44,18 +44,7 @@ const AnfNodePtr RectifyDoMaskKernelInfo::Process(const FuncGraphPtr &graph, con
|
|||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
if (ms_context->execution_mode() == kPynativeMode) {
|
||||
if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDropoutDoMask->name()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto do_mask_input_format = AnfAlgo::GetInputFormat(node, 0);
|
||||
if (do_mask_input_format != kOpFormat_DEFAULT) {
|
||||
auto builder =
|
||||
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node));
|
||||
builder->SetInputFormat(kOpFormat_DEFAULT, 0);
|
||||
builder->SetOutputFormat(kOpFormat_DEFAULT, 0);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
|
||||
}
|
||||
return nullptr;
|
||||
return RectifyKernelInfoInPynativeProcess(node);
|
||||
}
|
||||
if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDropoutGenMask->name()) {
|
||||
return nullptr;
|
||||
|
@ -139,6 +128,7 @@ std::string RectifyDoMaskKernelInfo::GetConvertFormat(const std::map<std::string
|
|||
}
|
||||
return convert_format;
|
||||
}
|
||||
|
||||
void RectifyDoMaskKernelInfo::RectifyDropOutDoMaskKernelInfo(const std::vector<CNodePtr> &do_mask_node_list,
|
||||
const std::string &format) const {
|
||||
for (const auto &do_mask : do_mask_node_list) {
|
||||
|
@ -150,5 +140,24 @@ void RectifyDoMaskKernelInfo::RectifyDropOutDoMaskKernelInfo(const std::vector<C
|
|||
}
|
||||
}
|
||||
|
||||
AnfNodePtr RectifyDoMaskKernelInfo::RectifyKernelInfoInPynativeProcess(const AnfNodePtr &node) const {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (cnode == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDropoutDoMask->name()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto do_mask_input_format = AnfAlgo::GetInputFormat(node, 0);
|
||||
if (do_mask_input_format != kOpFormat_DEFAULT) {
|
||||
auto builder =
|
||||
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node));
|
||||
builder->SetInputFormat(kOpFormat_DEFAULT, 0);
|
||||
builder->SetOutputFormat(kOpFormat_DEFAULT, 0);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -33,6 +33,7 @@ class RectifyDoMaskKernelInfo : public PatternProcessPass {
|
|||
|
||||
private:
|
||||
void RectifyKernelInfo(const std::vector<CNodePtr> &do_mask_node_list) const;
|
||||
AnfNodePtr RectifyKernelInfoInPynativeProcess(const AnfNodePtr &node) const;
|
||||
std::string GetConvertFormat(const std::map<std::string, size_t> &format_counter) const;
|
||||
void RectifyDropOutDoMaskKernelInfo(const std::vector<CNodePtr> &do_mask_node_list, const std::string &format) const;
|
||||
};
|
||||
|
|
|
@ -112,32 +112,13 @@ const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, con
|
|||
}
|
||||
auto input_num = AnfAlgo::GetInputTensorNum(depend_cnode);
|
||||
while (index < input_num) {
|
||||
auto replacing_node = AnfAlgo::GetInputNode(depend_cnode, index);
|
||||
++index;
|
||||
MS_EXCEPTION_IF_NULL(replacing_node);
|
||||
if (!replacing_node->isa<CNode>()) {
|
||||
new_depend_inputs.push_back(replacing_node);
|
||||
continue;
|
||||
}
|
||||
auto replacing_cnode = replacing_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(replacing_cnode);
|
||||
// Deal with the make_tuple with TransData or Cast inputs.
|
||||
auto make_tuple_replace_node = ReplaceMakeTuple(func_graph, replacing_cnode);
|
||||
if (make_tuple_replace_node != nullptr) {
|
||||
new_depend_inputs.push_back(make_tuple_replace_node);
|
||||
continue;
|
||||
}
|
||||
AnfNodePtr replace_node = GetReplaceNode(replacing_cnode);
|
||||
if (replace_node == nullptr) {
|
||||
new_depend_inputs.push_back(replacing_node);
|
||||
MS_LOG(DEBUG) << "Can not find the TransData or Cast with single output node. Depend node: "
|
||||
<< node->DebugString();
|
||||
continue;
|
||||
}
|
||||
auto replace_node = GetConvertNode(func_graph, node, index);
|
||||
MS_EXCEPTION_IF_NULL(replace_node);
|
||||
new_depend_inputs.push_back(replace_node);
|
||||
++index;
|
||||
}
|
||||
auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
|
||||
CNodePtr new_depend;
|
||||
CNodePtr new_depend = nullptr;
|
||||
if (kernel_graph == nullptr) {
|
||||
new_depend = func_graph->NewCNode(new_depend_inputs);
|
||||
MS_EXCEPTION_IF_NULL(new_depend);
|
||||
|
@ -150,5 +131,31 @@ const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, con
|
|||
}
|
||||
return new_depend;
|
||||
}
|
||||
|
||||
const AnfNodePtr OptimizeDependence::GetConvertNode(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||
const size_t index) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto depend_cnode = node->cast<CNodePtr>();
|
||||
auto replacing_node = AnfAlgo::GetInputNode(depend_cnode, index);
|
||||
MS_EXCEPTION_IF_NULL(replacing_node);
|
||||
if (!replacing_node->isa<CNode>()) {
|
||||
return replacing_node;
|
||||
}
|
||||
auto replacing_cnode = replacing_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(replacing_cnode);
|
||||
// Deal with the make_tuple with TransData or Cast inputs.
|
||||
auto make_tuple_replace_node = ReplaceMakeTuple(graph, replacing_cnode);
|
||||
if (make_tuple_replace_node != nullptr) {
|
||||
return make_tuple_replace_node;
|
||||
}
|
||||
AnfNodePtr replace_node = GetReplaceNode(replacing_cnode);
|
||||
if (replace_node == nullptr) {
|
||||
MS_LOG(DEBUG) << "Can not find the TransData or Cast with single output node. Depend node: " << node->DebugString();
|
||||
return replacing_node;
|
||||
}
|
||||
return replace_node;
|
||||
}
|
||||
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -27,6 +27,7 @@ class OptimizeDependence : public PatternProcessPass {
|
|||
~OptimizeDependence() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
const AnfNodePtr GetConvertNode(const FuncGraphPtr &graph, const AnfNodePtr &node, const size_t index) const;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue