forked from mindspore-Ecosystem/mindspore
!31647 clean code
Merge pull request !31647 from baihuawei/cleancode1.7
This commit is contained in:
commit
9c6fb2a67e
|
@ -185,6 +185,8 @@ bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communic
|
||||||
if (parallel_mode == parallel::kDataParallel && op_name_ == kAllReduceOpName) {
|
if (parallel_mode == parallel::kDataParallel && op_name_ == kAllReduceOpName) {
|
||||||
auto threshold = parallel_context->dp_fusion_threshold_mb();
|
auto threshold = parallel_context->dp_fusion_threshold_mb();
|
||||||
GetAllReduceSplitSegment(communication_op_info.communication_op_nodes, threshold, segment_index);
|
GetAllReduceSplitSegment(communication_op_info.communication_op_nodes, threshold, segment_index);
|
||||||
|
MS_LOG(INFO) << "The split threshold for AllReduce is " << threshold << ", the segment num is "
|
||||||
|
<< segment_index->size();
|
||||||
}
|
}
|
||||||
return CheckSegments(communication_op_node_size, segment_index);
|
return CheckSegments(communication_op_node_size, segment_index);
|
||||||
}
|
}
|
||||||
|
|
|
@ -1050,7 +1050,7 @@ void AnfRuntimeAlgorithm::InferShape(const CNodePtr &node, std::map<uint32_t, te
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
common::AnfAlgo::AddArgList(&args_spec_list, cnode_input, real_input, i);
|
common::AnfAlgo::AddArgList(&args_spec_list, cnode_input, real_input);
|
||||||
}
|
}
|
||||||
auto eval_result = opt::CppInferShape(primitive, args_spec_list);
|
auto eval_result = opt::CppInferShape(primitive, args_spec_list);
|
||||||
node->set_abstract(eval_result);
|
node->set_abstract(eval_result);
|
||||||
|
|
|
@ -170,7 +170,7 @@ class COMMON_EXPORT AnfAlgo {
|
||||||
static bool IsHostKernel(const CNodePtr &node);
|
static bool IsHostKernel(const CNodePtr &node);
|
||||||
// return true if use cnode_input's abstract, false if use real_input's abstract
|
// return true if use cnode_input's abstract, false if use real_input's abstract
|
||||||
static void AddArgList(AbstractBasePtrList *args_spec_list, const AnfNodePtr &cnode_input,
|
static void AddArgList(AbstractBasePtrList *args_spec_list, const AnfNodePtr &cnode_input,
|
||||||
const AnfNodePtr &real_input, size_t index);
|
const AnfNodePtr &real_input);
|
||||||
// Find real input nodes.
|
// Find real input nodes.
|
||||||
static void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result,
|
static void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result,
|
||||||
std::set<AnfNodePtr> *visited);
|
std::set<AnfNodePtr> *visited);
|
||||||
|
|
|
@ -95,7 +95,7 @@ void KernelMod::InferShape() {
|
||||||
tuple_elements->set_value(out_tensor);
|
tuple_elements->set_value(out_tensor);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
common::AnfAlgo::AddArgList(&args_spec_list, cnode_input, real_input, i);
|
common::AnfAlgo::AddArgList(&args_spec_list, cnode_input, real_input);
|
||||||
}
|
}
|
||||||
auto eval_result = opt::CppInferShape(primitive, args_spec_list);
|
auto eval_result = opt::CppInferShape(primitive, args_spec_list);
|
||||||
cnode->set_abstract(eval_result);
|
cnode->set_abstract(eval_result);
|
||||||
|
|
|
@ -34,11 +34,11 @@ class BatchMatmulFusedMulAddFusionPass : public FusionBasePass {
|
||||||
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::BatchMatmulFusedMulAddFusionPass);
|
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::BatchMatmulFusedMulAddFusionPass);
|
||||||
}
|
}
|
||||||
~BatchMatmulFusedMulAddFusionPass() override = default;
|
~BatchMatmulFusedMulAddFusionPass() override = default;
|
||||||
|
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void MatchBatchMatmulFusedMulAdd(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
|
void MatchBatchMatmulFusedMulAdd(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
|
||||||
FusedNodeRecord *candidate_fusion);
|
FusedNodeRecord *candidate_fusion);
|
||||||
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
|
|
||||||
};
|
};
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -32,10 +32,10 @@ class MatmulDropoutDoMaskV3AddFusionPass : public FusionBasePass {
|
||||||
explicit MatmulDropoutDoMaskV3AddFusionPass(FusionIdAllocatorPtr idAllocator)
|
explicit MatmulDropoutDoMaskV3AddFusionPass(FusionIdAllocatorPtr idAllocator)
|
||||||
: FusionBasePass("MatmulDropoutDoMaskV3AddFusionPass", idAllocator) {}
|
: FusionBasePass("MatmulDropoutDoMaskV3AddFusionPass", idAllocator) {}
|
||||||
~MatmulDropoutDoMaskV3AddFusionPass() override = default;
|
~MatmulDropoutDoMaskV3AddFusionPass() override = default;
|
||||||
|
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void MatchMatmulDropoutDoMaskV3Add(const CNodePtr &cnode, FusedNodeRecord *candidate_fusion);
|
void MatchMatmulDropoutDoMaskV3Add(const CNodePtr &cnode, FusedNodeRecord *candidate_fusion);
|
||||||
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
|
|
||||||
};
|
};
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -64,7 +64,7 @@ class SquaredDifferenceOpGpuKernelMod : public NativeGpuKernelMod {
|
||||||
InitSizeLists();
|
InitSizeLists();
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
need_broadcast_ = IsBroadcast(input_shape1, input_shape2);
|
need_broadcast_ = common::AnfAlgo::IsTensorBroadcast(input_shape1, input_shape2);
|
||||||
if (need_broadcast_ && output_shape.size() > MAX_DIMS) {
|
if (need_broadcast_ && output_shape.size() > MAX_DIMS) {
|
||||||
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the dimension of output cannot be greater than " << MAX_DIMS
|
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the dimension of output cannot be greater than " << MAX_DIMS
|
||||||
<< ", but got " << output_shape.size();
|
<< ", but got " << output_shape.size();
|
||||||
|
@ -135,18 +135,6 @@ class SquaredDifferenceOpGpuKernelMod : public NativeGpuKernelMod {
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
bool IsBroadcast(const std::vector<size_t> &lhs, const std::vector<size_t> &rhs) {
|
|
||||||
if (lhs.size() != rhs.size()) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
for (size_t i = 0; i < lhs.size(); i++) {
|
|
||||||
if (lhs[i] != rhs[i]) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
BroadcastOpType op_type_;
|
BroadcastOpType op_type_;
|
||||||
bool need_broadcast_;
|
bool need_broadcast_;
|
||||||
bool is_comp_op_;
|
bool is_comp_op_;
|
||||||
|
|
|
@ -105,7 +105,7 @@ void DynamicKernel::InferShape() {
|
||||||
tuple_elements->set_value(out_tensor);
|
tuple_elements->set_value(out_tensor);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
common::AnfAlgo::AddArgList(&args_spec_list, cnode_input, real_input, i);
|
common::AnfAlgo::AddArgList(&args_spec_list, cnode_input, real_input);
|
||||||
}
|
}
|
||||||
auto eval_result = opt::CppInferShape(primitive, args_spec_list);
|
auto eval_result = opt::CppInferShape(primitive, args_spec_list);
|
||||||
cnode->set_abstract(eval_result);
|
cnode->set_abstract(eval_result);
|
||||||
|
|
|
@ -1422,7 +1422,7 @@ bool AnfAlgo::IsHostKernel(const CNodePtr &kernel_node) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void AnfAlgo::AddArgList(AbstractBasePtrList *args_spec_list, const AnfNodePtr &cnode_input,
|
void AnfAlgo::AddArgList(AbstractBasePtrList *args_spec_list, const AnfNodePtr &cnode_input,
|
||||||
const AnfNodePtr &real_input, size_t index) {
|
const AnfNodePtr &real_input) {
|
||||||
if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimTupleGetItem)) {
|
if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimTupleGetItem)) {
|
||||||
// cppcheck-suppress unreadVariable
|
// cppcheck-suppress unreadVariable
|
||||||
auto lock = AnfUtils::GetAbstractLock(real_input.get());
|
auto lock = AnfUtils::GetAbstractLock(real_input.get());
|
||||||
|
|
Loading…
Reference in New Issue