Merge pull request !31647 from baihuawei/cleancode1.7
commit 9c6fb2a67e
i-robot 2022-03-22 02:44:17 +00:00, committed by Gitee
GPG Key ID: 173E9B9CA92EEF8F (no known key found for this signature in database)
9 changed files with 10 additions and 20 deletions

@@ -185,6 +185,8 @@ bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communic
if (parallel_mode == parallel::kDataParallel && op_name_ == kAllReduceOpName) {
auto threshold = parallel_context->dp_fusion_threshold_mb();
GetAllReduceSplitSegment(communication_op_info.communication_op_nodes, threshold, segment_index);
MS_LOG(INFO) << "The split threshold for AllReduce is " << threshold << ", the segment num is "
<< segment_index->size();
}
return CheckSegments(communication_op_node_size, segment_index);
}
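
For context, threshold-driven splitting of this kind is usually a greedy pass over the communication ops: per-op sizes are accumulated until the threshold (in MB) is crossed and the segment is closed. The standalone sketch below only illustrates that idea; it is not the GetAllReduceSplitSegment implementation, and the per-op sizes, values, and function name are assumptions.

#include <cstddef>
#include <vector>

// Illustrative only: close a segment whenever the accumulated size of
// consecutive ops reaches the threshold, recording the index of the last
// op in each segment (a hypothetical stand-in for the real splitter).
std::vector<size_t> SplitByThreshold(const std::vector<double> &op_sizes_mb, double threshold_mb) {
  std::vector<size_t> segment_index;
  if (op_sizes_mb.empty()) {
    return segment_index;
  }
  double accumulated = 0.0;
  for (size_t i = 0; i < op_sizes_mb.size(); ++i) {
    accumulated += op_sizes_mb[i];
    if (accumulated >= threshold_mb) {
      segment_index.push_back(i);
      accumulated = 0.0;
    }
  }
  if (segment_index.empty() || segment_index.back() != op_sizes_mb.size() - 1) {
    segment_index.push_back(op_sizes_mb.size() - 1);  // close the trailing segment
  }
  return segment_index;
}

int main() {
  // Example: with a 100 MB threshold these sizes form segments ending at ops 1, 3 and 4.
  const std::vector<double> sizes = {60.0, 70.0, 90.0, 30.0, 20.0};
  const std::vector<size_t> segments = SplitByThreshold(sizes, 100.0);
  return segments.size() == 3 ? 0 : 1;
}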

@@ -1050,7 +1050,7 @@ void AnfRuntimeAlgorithm::InferShape(const CNodePtr &node, std::map<uint32_t, te
}
}
}
common::AnfAlgo::AddArgList(&args_spec_list, cnode_input, real_input, i);
common::AnfAlgo::AddArgList(&args_spec_list, cnode_input, real_input);
}
auto eval_result = opt::CppInferShape(primitive, args_spec_list);
node->set_abstract(eval_result);

@@ -170,7 +170,7 @@ class COMMON_EXPORT AnfAlgo {
static bool IsHostKernel(const CNodePtr &node);
// return true if use cnode_input's abstract, false if use real_input's abstract
static void AddArgList(AbstractBasePtrList *args_spec_list, const AnfNodePtr &cnode_input,
const AnfNodePtr &real_input, size_t index);
const AnfNodePtr &real_input);
// Find real input nodes.
static void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result,
std::set<AnfNodePtr> *visited);
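
The AddArgList change above looks like the usual unused-parameter cleanup: the trailing size_t index is dropped from the declaration, the definition, and every call site in this commit. A generic, self-contained illustration of that pattern (hypothetical names, not MindSpore code):

#include <cstddef>
#include <vector>

// Before: void CollectArg(std::vector<int> *args, int value, size_t index);
// The index parameter was never read, so it is removed from the signature
// and from every call site in the same change.
void CollectArg(std::vector<int> *args, int value) { args->push_back(value); }

int main() {
  std::vector<int> args;
  for (size_t i = 0; i < 3; ++i) {
    CollectArg(&args, static_cast<int>(i));  // previously: CollectArg(&args, static_cast<int>(i), i);
  }
  return static_cast<int>(args.size()) == 3 ? 0 : 1;
}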

@@ -95,7 +95,7 @@ void KernelMod::InferShape() {
tuple_elements->set_value(out_tensor);
}
}
common::AnfAlgo::AddArgList(&args_spec_list, cnode_input, real_input, i);
common::AnfAlgo::AddArgList(&args_spec_list, cnode_input, real_input);
}
auto eval_result = opt::CppInferShape(primitive, args_spec_list);
cnode->set_abstract(eval_result);

@@ -34,11 +34,11 @@ class BatchMatmulFusedMulAddFusionPass : public FusionBasePass {
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::BatchMatmulFusedMulAddFusionPass);
}
~BatchMatmulFusedMulAddFusionPass() override = default;
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
private:
void MatchBatchMatmulFusedMulAdd(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
FusedNodeRecord *candidate_fusion);
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
};
} // namespace opt
} // namespace mindspore

@@ -32,10 +32,10 @@ class MatmulDropoutDoMaskV3AddFusionPass : public FusionBasePass {
explicit MatmulDropoutDoMaskV3AddFusionPass(FusionIdAllocatorPtr idAllocator)
: FusionBasePass("MatmulDropoutDoMaskV3AddFusionPass", idAllocator) {}
~MatmulDropoutDoMaskV3AddFusionPass() override = default;
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
private:
void MatchMatmulDropoutDoMaskV3Add(const CNodePtr &cnode, FusedNodeRecord *candidate_fusion);
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
};
} // namespace opt
} // namespace mindspore
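
Both fusion-pass headers only move the MatchSingleFusionPattern override relative to the private: label. In C++, the accessibility of an override is decided by where it is declared in the derived class, not by the base class, which the minimal standalone example below (unrelated to the MindSpore types) demonstrates:

#include <iostream>

class Base {
 public:
  virtual ~Base() = default;
  virtual void Run() const { std::cout << "Base::Run" << std::endl; }
};

class Derived : public Base {
 private:
  // Declared under private:, so Derived::Run is inaccessible when called
  // directly on a Derived object, even though Base::Run is public.
  void Run() const override { std::cout << "Derived::Run" << std::endl; }
};

int main() {
  Derived d;
  const Base &b = d;
  b.Run();    // fine: access is checked against Base, dispatch is virtual
  // d.Run(); // would not compile: Derived::Run is private
  return 0;
}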

@@ -64,7 +64,7 @@ class SquaredDifferenceOpGpuKernelMod : public NativeGpuKernelMod {
InitSizeLists();
return true;
}
need_broadcast_ = IsBroadcast(input_shape1, input_shape2);
need_broadcast_ = common::AnfAlgo::IsTensorBroadcast(input_shape1, input_shape2);
if (need_broadcast_ && output_shape.size() > MAX_DIMS) {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the dimension of output cannot be greater than " << MAX_DIMS
<< ", but got " << output_shape.size();
@@ -135,18 +135,6 @@ class SquaredDifferenceOpGpuKernelMod : public NativeGpuKernelMod {
}
private:
bool IsBroadcast(const std::vector<size_t> &lhs, const std::vector<size_t> &rhs) {
if (lhs.size() != rhs.size()) {
return true;
}
for (size_t i = 0; i < lhs.size(); i++) {
if (lhs[i] != rhs[i]) {
return true;
}
}
return false;
}
BroadcastOpType op_type_;
bool need_broadcast_;
bool is_comp_op_;
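
The kernel now delegates the broadcast check to common::AnfAlgo::IsTensorBroadcast instead of its own private IsBroadcast shown above. Assuming the shared helper keeps the removed helper's semantics, the predicate it computes is simply this (standalone sketch, hypothetical name):

#include <vector>

// Two shapes need broadcasting when their ranks differ or when any
// corresponding dimension differs (same logic as the removed IsBroadcast).
bool NeedsBroadcast(const std::vector<size_t> &lhs, const std::vector<size_t> &rhs) {
  if (lhs.size() != rhs.size()) {
    return true;
  }
  for (size_t i = 0; i < lhs.size(); ++i) {
    if (lhs[i] != rhs[i]) {
      return true;
    }
  }
  return false;
}

For example, shapes {4, 3} and {4, 1} need broadcasting, while two copies of {4, 3} do not.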

@@ -105,7 +105,7 @@ void DynamicKernel::InferShape() {
tuple_elements->set_value(out_tensor);
}
}
common::AnfAlgo::AddArgList(&args_spec_list, cnode_input, real_input, i);
common::AnfAlgo::AddArgList(&args_spec_list, cnode_input, real_input);
}
auto eval_result = opt::CppInferShape(primitive, args_spec_list);
cnode->set_abstract(eval_result);

@@ -1422,7 +1422,7 @@ bool AnfAlgo::IsHostKernel(const CNodePtr &kernel_node) {
}
void AnfAlgo::AddArgList(AbstractBasePtrList *args_spec_list, const AnfNodePtr &cnode_input,
const AnfNodePtr &real_input, size_t index) {
const AnfNodePtr &real_input) {
if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimTupleGetItem)) {
// cppcheck-suppress unreadVariable
auto lock = AnfUtils::GetAbstractLock(real_input.get());