diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_fusion.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_fusion.cc index 97f418e1bf6..e81b9d3f9fd 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_fusion.cc @@ -20,6 +20,7 @@ #include #include #include +#include "backend/kernel_compiler/kernel.h" #include "backend/optimizer/graph_kernel/graph_kernel_helper.h" #include "frontend/operator/ops.h" #include "ir/func_graph_cloner.h" @@ -28,6 +29,8 @@ namespace mindspore::graphkernel { namespace { +// Cuda's parameter table can accept maximum 4KB, so the number of parameters should be less than 512. +constexpr size_t CUDA_PARA_LIMIT = 512; bool IsOneOf(const AnfNodePtr &node, const std::vector &ops_prim) { return std::any_of(ops_prim.cbegin(), ops_prim.cend(), [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }); @@ -384,6 +387,31 @@ void DumpParallelFusionDetail(const AnfNodePtrList &source, const AnfNodePtr &ta << "(" << DumpNode(target) << ")"; MS_LOG(INFO) << buf.str(); } + +inline bool ParameterLimit(const AnfNodePtrList &nodes) { + if (nodes.empty()) { + MS_LOG(EXCEPTION) << "Nodes is empty, can not check condition."; + } + + bool res = true; + switch (AnfAlgo::GetProcessor(nodes[0])) { + case kernel::Processor::CUDA: { + // The number of inputs and outputs for a valid kernel should be less than cuda's limit. + size_t para_count = 0; + for (const auto &node : nodes) { + para_count += AnfAlgo::GetInputTensorNum(node); + para_count += AnfAlgo::GetOutputTensorNum(node); + } + res = para_count <= CUDA_PARA_LIMIT; + } break; + default: + break; + } + + return res; +} + +bool ExtraFusionCondition(const AnfNodePtrList &nodes) { return ParameterLimit(nodes); } } // namespace OrderedMap ParallelOpFusion::GenAnalysisGraph(const AnfNodePtrList &nodes) { @@ -513,13 +541,15 @@ std::tuple, std::vector> ParallelOpFusion::DoSea AnfNodePtrList other_candidates; std::tie(other_candidates, std::ignore) = GetAvaliableNodesByOffset(SizeToInt(i), tc, sorted_candidates_used, candidates, std::set()); - int benefit; - std::tie(std::ignore, benefit, std::ignore) = cost_model_ptr_->CalFuseInfo(other_candidates); - if (benefit > 0) { - begin = mid + 1; - } else { - end = mid - 1; + if (ExtraFusionCondition(other_candidates)) { + int benefit; + std::tie(std::ignore, benefit, std::ignore) = cost_model_ptr_->CalFuseInfo(other_candidates); + if (benefit > 0) { + begin = mid + 1; + continue; + } } + end = mid - 1; } if (begin > 1) {