forked from mindspore-Ecosystem/mindspore
!23676 [GraphKernel] Limit the number of inputs and outputs for parallel fusion in CUDA.
Merge pull request !23676 from TronZhang/add_parallel_limit
This commit is contained in:
commit
f482e2a5fc
|
@ -20,6 +20,7 @@
|
|||
#include <list>
|
||||
#include <queue>
|
||||
#include <utility>
|
||||
#include "backend/kernel_compiler/kernel.h"
|
||||
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "ir/func_graph_cloner.h"
|
||||
|
@ -28,6 +29,8 @@
|
|||
|
||||
namespace mindspore::graphkernel {
|
||||
namespace {
|
||||
// CUDA's kernel parameter table holds at most 4KB, so the number of parameters must not exceed 512.
|
||||
constexpr size_t CUDA_PARA_LIMIT = 512;
|
||||
bool IsOneOf(const AnfNodePtr &node, const std::vector<PrimitivePtr> &ops_prim) {
|
||||
return std::any_of(ops_prim.cbegin(), ops_prim.cend(),
|
||||
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
|
||||
|
@ -384,6 +387,31 @@ void DumpParallelFusionDetail(const AnfNodePtrList &source, const AnfNodePtr &ta
|
|||
<< "(" << DumpNode(target) << ")";
|
||||
MS_LOG(INFO) << buf.str();
|
||||
}
|
||||
|
||||
inline bool ParameterLimit(const AnfNodePtrList &nodes) {
|
||||
if (nodes.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Nodes is empty, can not check condition.";
|
||||
}
|
||||
|
||||
bool res = true;
|
||||
switch (AnfAlgo::GetProcessor(nodes[0])) {
|
||||
case kernel::Processor::CUDA: {
|
||||
// The number of inputs and outputs for a valid kernel should be less than cuda's limit.
|
||||
size_t para_count = 0;
|
||||
for (const auto &node : nodes) {
|
||||
para_count += AnfAlgo::GetInputTensorNum(node);
|
||||
para_count += AnfAlgo::GetOutputTensorNum(node);
|
||||
}
|
||||
res = para_count <= CUDA_PARA_LIMIT;
|
||||
} break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
// Additional gate applied to a candidate node group before it is considered for fusion.
// Currently the only condition is the parameter-count check done by ParameterLimit.
bool ExtraFusionCondition(const AnfNodePtrList &nodes) {
  const bool within_parameter_limit = ParameterLimit(nodes);
  return within_parameter_limit;
}
|
||||
} // namespace
|
||||
|
||||
OrderedMap<AnfNodePtr, NodeRelation> ParallelOpFusion::GenAnalysisGraph(const AnfNodePtrList &nodes) {
|
||||
|
@ -513,13 +541,15 @@ std::tuple<std::vector<bool>, std::vector<ParallelInfo>> ParallelOpFusion::DoSea
|
|||
AnfNodePtrList other_candidates;
|
||||
std::tie(other_candidates, std::ignore) =
|
||||
GetAvaliableNodesByOffset(SizeToInt(i), tc, sorted_candidates_used, candidates, std::set<int>());
|
||||
int benefit;
|
||||
std::tie(std::ignore, benefit, std::ignore) = cost_model_ptr_->CalFuseInfo(other_candidates);
|
||||
if (benefit > 0) {
|
||||
begin = mid + 1;
|
||||
} else {
|
||||
end = mid - 1;
|
||||
if (ExtraFusionCondition(other_candidates)) {
|
||||
int benefit;
|
||||
std::tie(std::ignore, benefit, std::ignore) = cost_model_ptr_->CalFuseInfo(other_candidates);
|
||||
if (benefit > 0) {
|
||||
begin = mid + 1;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
end = mid - 1;
|
||||
}
|
||||
|
||||
if (begin > 1) {
|
||||
|
|
Loading…
Reference in New Issue