Merge pull request !31647 from baihuawei/cleancode1.7
This commit is contained in:
i-robot 2022-03-22 02:44:17 +00:00 committed by Gitee
commit 9c6fb2a67e
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
9 changed files with 10 additions and 20 deletions

View File

@@ -185,6 +185,8 @@ bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communic
if (parallel_mode == parallel::kDataParallel && op_name_ == kAllReduceOpName) { if (parallel_mode == parallel::kDataParallel && op_name_ == kAllReduceOpName) {
auto threshold = parallel_context->dp_fusion_threshold_mb(); auto threshold = parallel_context->dp_fusion_threshold_mb();
GetAllReduceSplitSegment(communication_op_info.communication_op_nodes, threshold, segment_index); GetAllReduceSplitSegment(communication_op_info.communication_op_nodes, threshold, segment_index);
MS_LOG(INFO) << "The split threshold for AllReduce is " << threshold << ", the segment num is "
<< segment_index->size();
} }
return CheckSegments(communication_op_node_size, segment_index); return CheckSegments(communication_op_node_size, segment_index);
} }

View File

@@ -1050,7 +1050,7 @@ void AnfRuntimeAlgorithm::InferShape(const CNodePtr &node, std::map<uint32_t, te
} }
} }
} }
common::AnfAlgo::AddArgList(&args_spec_list, cnode_input, real_input, i); common::AnfAlgo::AddArgList(&args_spec_list, cnode_input, real_input);
} }
auto eval_result = opt::CppInferShape(primitive, args_spec_list); auto eval_result = opt::CppInferShape(primitive, args_spec_list);
node->set_abstract(eval_result); node->set_abstract(eval_result);

View File

@@ -170,7 +170,7 @@ class COMMON_EXPORT AnfAlgo {
static bool IsHostKernel(const CNodePtr &node); static bool IsHostKernel(const CNodePtr &node);
// return true if use cnode_input's abstract, false if use real_input's abstract // return true if use cnode_input's abstract, false if use real_input's abstract
static void AddArgList(AbstractBasePtrList *args_spec_list, const AnfNodePtr &cnode_input, static void AddArgList(AbstractBasePtrList *args_spec_list, const AnfNodePtr &cnode_input,
const AnfNodePtr &real_input, size_t index); const AnfNodePtr &real_input);
// Find real input nodes. // Find real input nodes.
static void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result, static void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result,
std::set<AnfNodePtr> *visited); std::set<AnfNodePtr> *visited);

View File

@@ -95,7 +95,7 @@ void KernelMod::InferShape() {
tuple_elements->set_value(out_tensor); tuple_elements->set_value(out_tensor);
} }
} }
common::AnfAlgo::AddArgList(&args_spec_list, cnode_input, real_input, i); common::AnfAlgo::AddArgList(&args_spec_list, cnode_input, real_input);
} }
auto eval_result = opt::CppInferShape(primitive, args_spec_list); auto eval_result = opt::CppInferShape(primitive, args_spec_list);
cnode->set_abstract(eval_result); cnode->set_abstract(eval_result);

View File

@@ -34,11 +34,11 @@ class BatchMatmulFusedMulAddFusionPass : public FusionBasePass {
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::BatchMatmulFusedMulAddFusionPass); PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::BatchMatmulFusedMulAddFusionPass);
} }
~BatchMatmulFusedMulAddFusionPass() override = default; ~BatchMatmulFusedMulAddFusionPass() override = default;
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
private: private:
void MatchBatchMatmulFusedMulAdd(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, void MatchBatchMatmulFusedMulAdd(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
FusedNodeRecord *candidate_fusion); FusedNodeRecord *candidate_fusion);
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
}; };
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore

View File

@@ -32,10 +32,10 @@ class MatmulDropoutDoMaskV3AddFusionPass : public FusionBasePass {
explicit MatmulDropoutDoMaskV3AddFusionPass(FusionIdAllocatorPtr idAllocator) explicit MatmulDropoutDoMaskV3AddFusionPass(FusionIdAllocatorPtr idAllocator)
: FusionBasePass("MatmulDropoutDoMaskV3AddFusionPass", idAllocator) {} : FusionBasePass("MatmulDropoutDoMaskV3AddFusionPass", idAllocator) {}
~MatmulDropoutDoMaskV3AddFusionPass() override = default; ~MatmulDropoutDoMaskV3AddFusionPass() override = default;
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
private: private:
void MatchMatmulDropoutDoMaskV3Add(const CNodePtr &cnode, FusedNodeRecord *candidate_fusion); void MatchMatmulDropoutDoMaskV3Add(const CNodePtr &cnode, FusedNodeRecord *candidate_fusion);
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
}; };
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore

View File

@@ -64,7 +64,7 @@ class SquaredDifferenceOpGpuKernelMod : public NativeGpuKernelMod {
InitSizeLists(); InitSizeLists();
return true; return true;
} }
need_broadcast_ = IsBroadcast(input_shape1, input_shape2); need_broadcast_ = common::AnfAlgo::IsTensorBroadcast(input_shape1, input_shape2);
if (need_broadcast_ && output_shape.size() > MAX_DIMS) { if (need_broadcast_ && output_shape.size() > MAX_DIMS) {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the dimension of output cannot be greater than " << MAX_DIMS MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the dimension of output cannot be greater than " << MAX_DIMS
<< ", but got " << output_shape.size(); << ", but got " << output_shape.size();
@@ -135,18 +135,6 @@ class SquaredDifferenceOpGpuKernelMod : public NativeGpuKernelMod {
} }
private: private:
bool IsBroadcast(const std::vector<size_t> &lhs, const std::vector<size_t> &rhs) {
if (lhs.size() != rhs.size()) {
return true;
}
for (size_t i = 0; i < lhs.size(); i++) {
if (lhs[i] != rhs[i]) {
return true;
}
}
return false;
}
BroadcastOpType op_type_; BroadcastOpType op_type_;
bool need_broadcast_; bool need_broadcast_;
bool is_comp_op_; bool is_comp_op_;

View File

@@ -105,7 +105,7 @@ void DynamicKernel::InferShape() {
tuple_elements->set_value(out_tensor); tuple_elements->set_value(out_tensor);
} }
} }
common::AnfAlgo::AddArgList(&args_spec_list, cnode_input, real_input, i); common::AnfAlgo::AddArgList(&args_spec_list, cnode_input, real_input);
} }
auto eval_result = opt::CppInferShape(primitive, args_spec_list); auto eval_result = opt::CppInferShape(primitive, args_spec_list);
cnode->set_abstract(eval_result); cnode->set_abstract(eval_result);

View File

@@ -1422,7 +1422,7 @@ bool AnfAlgo::IsHostKernel(const CNodePtr &kernel_node) {
} }
void AnfAlgo::AddArgList(AbstractBasePtrList *args_spec_list, const AnfNodePtr &cnode_input, void AnfAlgo::AddArgList(AbstractBasePtrList *args_spec_list, const AnfNodePtr &cnode_input,
const AnfNodePtr &real_input, size_t index) { const AnfNodePtr &real_input) {
if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimTupleGetItem)) { if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimTupleGetItem)) {
// cppcheck-suppress unreadVariable // cppcheck-suppress unreadVariable
auto lock = AnfUtils::GetAbstractLock(real_input.get()); auto lock = AnfUtils::GetAbstractLock(real_input.get());