!34615 fuse csr ops

Merge pull request !34615 from huangmengxi/fuse_gnn

commit 0ea5714762
@@ -45,6 +45,7 @@
 #include "common/graph_kernel/axis_normalizer.h"
 #include "common/graph_kernel/decrease_compute_precision.h"
 #include "common/graph_kernel/decrease_transfer_precision.h"
+#include "common/graph_kernel/csr_atomic_add.h"
 #include "common/graph_kernel/tsa_atomic_add_to_first_tensor.h"
 #include "common/graph_kernel/uss_atomic_add.h"
 #include "backend/common/pass/getitem_tuple.h"

@@ -169,7 +170,8 @@ PassManagerPtr GraphKernelOptimizer::HighLevelOpt2() const {
   auto &flags = GraphKernelFlags::GetInstance();
   // Auto recompute according to local memory burst.
-  auto recompute_lv = GetPassLevelByFlag(flags.recompute_increment_threshold > 0 || flags.recompute_peak_threshold > 0);
+  auto recompute_lv = GetPassLevelByFlag(flags.recompute_increment_threshold > 0 ||
+                                         flags.recompute_peak_threshold > 0 || flags.enable_csr_fusion);
   pm->Add(std::make_shared<GraphKernelRecompute>(), recompute_lv);

   // Enable atomic add

@@ -191,6 +193,7 @@ PassManagerPtr GraphKernelOptimizer::HighLevelOpt2() const {
   // Enable tsa and uss
   pm->Add(std::make_shared<TsaAtomicAddToFirstTensor>(), OptLevel_1, is_gpu);
   pm->Add(std::make_shared<UssAtomicAdd>(), OptLevel_1, is_gpu);
+  pm->Add(std::make_shared<CsrAtomicAdd>(), OptLevel_1, is_gpu);

   // Replace Assign with InplaceAssign, and replace original output with overridden parameters
   pm->Add(std::make_shared<OptimizeAssign>(), OptLevel_2);
@@ -45,6 +45,14 @@ class AtomicAddChecker {
   PrimitivePtr target_type_{prim::kPrimReduceSum};
 };

+class TargetAtomicAddChecker : public AtomicAddChecker {
+ public:
+  explicit TargetAtomicAddChecker(const PrimitivePtr &target = prim::kPrimReduceSum) { target_type_ = target; }
+
+ protected:
+  bool CanActivateAtomicAdd(const AnfNodePtr &anf_node) override { return FindCandidate(anf_node); }
+};
+
 class AtomicAddCheckerGPU : public AtomicAddChecker {
  public:
   AtomicAddCheckerGPU() = default;
@@ -0,0 +1,133 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "common/graph_kernel/csr_atomic_add.h"
+#include <memory>
+#include "mindspore/core/ops/core_ops.h"
+#include "ir/tensor.h"
+#include "include/common/utils/utils.h"
+#include "kernel/kernel.h"
+#include "kernel/common_utils.h"
+#include "common/graph_kernel/graph_kernel_helper.h"
+#include "backend/common/session/kernel_graph.h"
+
+namespace mindspore::graphkernel {
+class ReduceSumCsrChecker : public AtomicAddChecker {
+ public:
+  ReduceSumCsrChecker() = default;
+
+ protected:
+  bool CanActivateAtomicAdd(const AnfNodePtr &node) override {
+    bool has_csr = false;
+    bool has_reduce_sum = false;
+    auto cnode = node->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(cnode);
+    auto func_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(cnode);
+    MS_EXCEPTION_IF_NULL(func_graph);
+    for (auto n : func_graph->nodes()) {
+      if (n->isa<CNode>() && IsAKGSparseOP(n)) {
+        has_csr = true;
+        break;
+      } else if (IsPrimitiveCNode(n, prim::kPrimReduceSum)) {
+        has_reduce_sum = true;
+      }
+    }
+    if (has_csr && has_reduce_sum) {
+      return FindCandidate(node);
+    }
+    return false;
+  }
+
+  bool FindCandidate(const AnfNodePtr &anf_node) override {
+    atomic_add_infos_.clear();
+    auto node = anf_node->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(node);
+    auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
+    auto mng_sub = sub_graph->manager();
+    if (mng_sub == nullptr) {
+      mng_sub = Manage(sub_graph, false);
+      sub_graph->set_manager(mng_sub);
+    }
+
+    auto CheckSuitableTarget = [&mng_sub](const InplaceAssignerInfo &atomic_add_info) {
+      // Target type should not fuse any other ops in out direction, which means it should be in output list.
+      return mng_sub->node_users()[atomic_add_info.op_node].size() <= 1;
+    };
+
+    auto real_return_node = sub_graph->get_return()->input(kFirstDataInputIndex);
+    InplaceAssignerInfo atomic_add_info;
+    if (IsPrimitiveCNode(real_return_node, prim::kPrimMakeTuple)) {
+      const auto &inputs = real_return_node->cast<CNodePtr>()->inputs();
+      for (size_t i = 1; i < inputs.size(); ++i) {
+        atomic_add_info.op_node = inputs[i]->cast<CNodePtr>();
+        atomic_add_info.real_output_index = i - 1;
+        atomic_add_info.real_output_num = inputs.size() - 1;
+        // Target type should not fuse any other ops in out direction, which means it should be in output list.
+        if (CheckSuitableTarget(atomic_add_info)) {
+          atomic_add_infos_.push_back(atomic_add_info);
+        }
+      }
+    } else if (real_return_node->isa<CNode>()) {
+      atomic_add_info.op_node = real_return_node->cast<CNodePtr>();
+      atomic_add_info.real_output_num = 1;
+      if (CheckSuitableTarget(atomic_add_info)) {
+        atomic_add_infos_.push_back(atomic_add_info);
+      }
+    } else {
+      return false;
+    }
+    return !atomic_add_infos_.empty();
+  }
+};
+
+bool CsrAtomicAdd::Run(const FuncGraphPtr &func_graph) {
+  auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(func_graph);
+  MS_EXCEPTION_IF_NULL(kernel_graph);
+  auto mng = kernel_graph->manager();
+  if (mng == nullptr) {
+    mng = Manage(kernel_graph, true);
+    kernel_graph->set_manager(mng);
+  }
+
+  bool changed = false;
+  std::shared_ptr<AtomicAddChecker> csr_reduce_sum_checker =
+    std::make_shared<TargetAtomicAddChecker>(prim::kPrimCSRReduceSum);
+  MS_EXCEPTION_IF_NULL(csr_reduce_sum_checker);
+  std::shared_ptr<AtomicAddChecker> reduce_sum_csr_checker = std::make_shared<ReduceSumCsrChecker>();
+  MS_EXCEPTION_IF_NULL(reduce_sum_csr_checker);
+
+  auto topo_nodes = TopoSort(kernel_graph->get_return());
+  for (const auto &node : topo_nodes) {
+    std::vector<InplaceAssignerInfo> atomic_add_infos;
+    if (csr_reduce_sum_checker->Check(node)) {
+      atomic_add_infos = csr_reduce_sum_checker->GetAtomicAddInfo();
+    } else if (reduce_sum_csr_checker->Check(node)) {
+      atomic_add_infos = reduce_sum_csr_checker->GetAtomicAddInfo();
+    }
+    if (!atomic_add_infos.empty()) {
+      InsertAtomicClean(kernel_graph, node, atomic_add_infos, mng);
+      changed = true;
+    }
+  }

+  if (changed) {
+    mng->RemoveRoots();
+    mng->KeepRoots({func_graph});
+  }
+
+  return changed;
+}
+}  // namespace mindspore::graphkernel
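Reviewer note: like the USS/TSA passes, this pass exists because an atomic reduce kernel only accumulates into its output buffer; correctness depends on a separate clean op zeroing that buffer first. A small NumPy sketch of the failure mode this guards against (illustrative only, not MindSpore code):

import numpy as np

def atomic_reduce_sum(values, out):
    # Models a reduce kernel whose threads only do atomic `out += v`;
    # nothing inside the kernel ever initializes `out`.
    for v in values:
        out[0] += v

buf = np.array([7.0])         # stale contents from an earlier launch
atomic_reduce_sum([1.0, 2.0], buf)
print(buf)                    # [10.] -- wrong without a clean
buf[:] = 0.0                  # what the inserted atomic-clean op does
atomic_reduce_sum([1.0, 2.0], buf)
print(buf)                    # [3.] -- correct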
@@ -0,0 +1,39 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CSR_ATOMIC_ADD_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CSR_ATOMIC_ADD_H_
+
+#include <memory>
+#include <tuple>
+#include <utility>
+#include <vector>
+#include "backend/common/optimizer/optimizer.h"
+#include "common/graph_kernel/add_atomic_clean.h"
+#include "backend/common/session/kernel_graph.h"
+
+namespace mindspore::graphkernel {
+// Insert atomic clean node for reduce sum if any csr op is found in the graph.
+class CsrAtomicAdd : public AtomicCleanInserter {
+ public:
+  CsrAtomicAdd() : AtomicCleanInserter("csr_atomic_add_process") {}
+  ~CsrAtomicAdd() override = default;
+  bool Run(const FuncGraphPtr &func_graph) override;
+};
+using CsrAtomicAddPtr = std::shared_ptr<CsrAtomicAdd>;
+}  // namespace mindspore::graphkernel
+
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CSR_ATOMIC_ADD_H_
@@ -271,6 +271,7 @@ void GraphKernelFlags::RegisterFlags(std::map<std::string, std::string> *flag_ma
   reg.AddFlag("enable_horizontal_fusion", &enable_horizontal_fusion, false);
   reg.AddFlag("enable_auto_tensor_inplace", &enable_auto_tensor_inplace, false);
   reg.AddFlag("enable_low_precision", &enable_low_precision);
+  reg.AddFlag("enable_csr_fusion", &enable_csr_fusion);

   // Integer flags
   reg.AddFlag("online_tuning", &online_tuning);

@@ -304,6 +305,7 @@ std::string GraphKernelFlags::DumpAllFlags() const {
   json["enable_parallel_fusion"] = enable_parallel_fusion;
   json["enable_horizontal_fusion"] = enable_horizontal_fusion;
   json["enable_auto_tensor_inplace"] = enable_auto_tensor_inplace;
+  json["enable_csr_fusion"] = enable_csr_fusion;
   json["enable_low_precision"] = enable_low_precision;

   json["opt_level"] = opt_level;
@@ -106,6 +106,11 @@ class GraphKernelFlags {
    */
   unsigned int fusion_ops_level{OpLevel_0};

+  /**
+   * Enable recompute fusion for CSR operations.
+   */
+  bool enable_csr_fusion{false};
+
   /**
    * Optimization level, value from 0 to 3.
    * 0: Disable GraphKernel
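Reviewer note: with `enable_csr_fusion` registered above, the flag becomes user-visible through the existing `graph_kernel_flags` context option. A minimal sketch of turning it on from a training script (illustrative usage, not part of this PR; assumes a GPU build with graph kernel enabled):

from mindspore import context

# Graph kernel must be active for GraphKernelOptimizer to run its passes;
# --enable_csr_fusion=true sets the boolean flag added in this commit.
context.set_context(mode=context.GRAPH_MODE, device_target="GPU",
                    enable_graph_kernel=True,
                    graph_kernel_flags="--enable_csr_fusion=true")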
@@ -17,6 +17,7 @@
 #include "common/graph_kernel/graph_kernel_recompute.h"

 #include <algorithm>
+#include <deque>
 #include <functional>
 #include <limits>
 #include <map>

@@ -200,7 +201,9 @@ void ElimRedundantInputsAndGraphParameters(const FuncGraphPtr &func_graph, AnfNo
 std::vector<Candidate> AutoRecompute::Run(const FuncGraphPtr &func_graph) {
   lifetime_threshold_ = GraphKernelFlags::GetInstance().recompute_increment_threshold;
   local_peak_threshold_ = GraphKernelFlags::GetInstance().recompute_peak_threshold;
-  FindCandidates(func_graph);
+  if (!IsThresholdDefaultValue()) {
+    FindCandidates(func_graph);
+  }
   return candidates_;
 }

@@ -420,11 +423,6 @@ void AutoRecompute::FindCandidates(const FuncGraphPtr &func_graph) {
   auto mng = func_graph->manager();
   MS_EXCEPTION_IF_NULL(mng);

-  // Do nothing when threshold is default value.
-  if (IsThresholdDefaultValue()) {
-    return;
-  }
-
   auto topo_nodes = TopoSort(func_graph->get_return());
   // Topo index is used to stop early in the predecessor check.
   for (size_t i = 0; i < topo_nodes.size(); ++i) {

@@ -469,6 +467,13 @@ AutoRecompute::NodeRecomputeCandidates AutoRecompute::FindNodeRecomputeCandidate
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(mng);
   NodeRecomputeCandidates node_candidates;
   auto graph_node = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
   MS_EXCEPTION_IF_NULL(graph_node);
+  auto nodes = graph_node->nodes();
+  if (std::any_of(nodes.cbegin(), nodes.cend(),
+                  [](const AnfNodePtr n) { return IsPrimitiveCNode(n, prim::kPrimReduceSum); })) {
+    return node_candidates;
+  }
   for (auto [gt, gt_in_pos_vec, edge_life_time_type] : target_graphs) {
     MemorySize threshold = SelectThreshold(edge_life_time_type);
     for (auto gt_in_pos : gt_in_pos_vec) {

@@ -505,7 +510,6 @@ AutoRecompute::NodeRecomputeCandidates AutoRecompute::FindNodeRecomputeCandidate
       }
     }
   }

   return node_candidates;
 }

@@ -536,6 +540,89 @@ void AutoRecompute::RecomputeCandidatesLog(const std::vector<Candidate> &candida
   }
 }

+std::vector<Candidate> CSRRecompute::Run(const FuncGraphPtr &func_graph) {
+  FindCandidates(func_graph);
+  return candidates_;
+}
+
+bool CSRRecompute::CheckPrimitiveInput(AnfNodePtr base, PrimitivePtr prim_type) {
+  std::deque<AnfNodePtr> q{base};
+  std::set<AnfNodePtr> visited;
+  while (!q.empty()) {
+    auto node = q.front();
+    q.pop_front();
+    if (visited.count(node) > 0) continue;
+    visited.insert(node);
+    if (IsPrimitiveCNode(node, prim_type)) {
+      return true;
+    }
+    auto cnode = node->cast<CNodePtr>();
+    if (cnode == nullptr) continue;
+    auto inputs = cnode->inputs();
+    q.insert(q.begin(), inputs.begin(), inputs.end());
+  }
+  return false;
+}
+
+AutoRecompute::NodeRecomputeCandidates CSRRecompute::FindNodeRecomputeCandidates(const AnfNodePtr &node,
+                                                                                 const OutPosLinkList &target_graphs,
+                                                                                 const FuncGraphManagerPtr &mng) {
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(mng);
+  NodeRecomputeCandidates node_candidates;
+  auto graph_node = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
+  MS_EXCEPTION_IF_NULL(graph_node);
+  auto nodes = graph_node->nodes();
+  // Subgraphs outputting UnsortedSegmentSum or CSRReduceSum along with other ops
+  // (likely the result of Gather), or containing CSRDiv without outputting
+  // UnsortedSegmentSum or CSRReduceSum, are selected as candidates for recompute.
+  auto TargetTail = [](const AnfNodePtr n) {
+    return IsPrimitiveCNode(n, prim::kPrimUnsortedSegmentSum) || IsPrimitiveCNode(n, prim::kPrimCSRReduceSum);
+  };
+  auto TargetHead = [](const AnfNodePtr n) { return IsPrimitiveCNode(n, prim::kPrimCSRDiv); };
+  auto return_node = graph_node->get_return();
+  MS_EXCEPTION_IF_NULL(return_node);
+  auto return_cnode = return_node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(return_cnode);
+  auto return_inputs = return_cnode->inputs();
+  auto return_tup = return_inputs[1]->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(return_tup);
+  auto tuple_inputs = return_tup->inputs();
+  std::set<size_t> candidate_idx;
+  if (std::any_of(tuple_inputs.cbegin(), tuple_inputs.cend(), TargetTail)) {
+    for (size_t i = 1; i < tuple_inputs.size(); ++i) {
+      if (!TargetTail(tuple_inputs[i])) {
+        candidate_idx.insert(i - 1);
+      }
+    }
+  } else if (std::any_of(tuple_inputs.cbegin(), tuple_inputs.cend(), TargetHead)) {
+    for (size_t i = 1; i < tuple_inputs.size(); ++i) {
+      if (CheckPrimitiveInput(tuple_inputs[i], prim::kPrimCSRDiv)) {
+        candidate_idx.insert(i - 1);
+      }
+    }
+  }
+  if (candidate_idx.empty()) return node_candidates;
+  for (size_t i = 0; i < target_graphs.size(); ++i) {
+    AnfNodePtr gt;
+    std::vector<int> gt_in_pos_vec;
+    std::tie(gt, gt_in_pos_vec, std::ignore) = target_graphs[i];
+    for (auto gt_in_pos : gt_in_pos_vec) {
+      auto gt_cnode = gt->cast<CNodePtr>();
+      MS_EXCEPTION_IF_NULL(gt_cnode);
+      auto edge = gt_cnode->input(IntToSize(gt_in_pos));
+      if (!IsPrimitiveCNode(edge, prim::kPrimTupleGetItem)) continue;
+      auto edge_cnode = edge->cast<CNodePtr>();
+      MS_EXCEPTION_IF_NULL(edge_cnode);
+      auto tuple_idx = common::AnfAlgo::GetTupleGetItemOutIndex(edge_cnode);
+      if (candidate_idx.count(tuple_idx) > 0) {
+        node_candidates[node][gt].second.push_back(edge);
+      }
+    }
+  }
+  return node_candidates;
+}
+
 std::pair<FuncGraphPtr, AnfNodePtrList> GraphKernelRecompute::CloneGraph(const CNodePtr &source_graph,
                                                                          const AnfNodePtrList &recompute_edges) {
   MS_EXCEPTION_IF_NULL(source_graph);

@@ -675,11 +762,14 @@ void GraphKernelRecompute::Process(const Candidate &candidate) {
   }
 }

-bool GraphKernelRecompute::Run(const FuncGraphPtr &func_graph) {
+bool GraphKernelRecompute::DoRun(const FuncGraphPtr &func_graph, bool use_csr) {
   int repeat_times = 2;
   while ((repeat_times--) != 0) {
-    AutoRecompute recompute;
-    candidates_ = recompute.Run(func_graph);
+    if (use_csr) {
+      candidates_ = CSRRecompute().Run(func_graph);
+    } else {
+      candidates_ = AutoRecompute().Run(func_graph);
+    }
     if (candidates_.empty()) {
       return false;
     }

@@ -705,4 +795,12 @@ bool GraphKernelRecompute::Run(const FuncGraphPtr &func_graph) {
   }
   return true;
 }
+
+bool GraphKernelRecompute::Run(const FuncGraphPtr &func_graph) {
+  bool status = DoRun(func_graph);
+  if (GraphKernelFlags::GetInstance().enable_csr_fusion) {
+    status |= DoRun(func_graph, true);
+  }
+  return status;
+}
 }  // namespace mindspore::graphkernel
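Reviewer note: `CheckPrimitiveInput` is a plain backward reachability walk: starting from one subgraph output, it follows `inputs()` edges until it meets the requested primitive. A Python sketch of the same traversal (the `is_target` and `get_inputs` callables are illustrative stand-ins, not MindSpore APIs):

from collections import deque

def reaches_primitive(base, is_target, get_inputs):
    """Return True if a node satisfying `is_target` is reachable from
    `base` by walking input edges; mirrors CSRRecompute::CheckPrimitiveInput."""
    queue, visited = deque([base]), set()
    while queue:
        node = queue.popleft()
        if node in visited:
            continue
        visited.add(node)
        if is_target(node):
            return True
        # The C++ code prepends inputs, making the walk depth-first-ish;
        # the visit order does not change the result.
        queue.extendleft(get_inputs(node))
    return False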
@@ -69,17 +69,21 @@ class AutoRecompute {
   AutoRecompute() = default;
   ~AutoRecompute() = default;

-  std::vector<Candidate> Run(const FuncGraphPtr &func_graph);
+  virtual std::vector<Candidate> Run(const FuncGraphPtr &func_graph);

- private:
+ protected:
   using NodeRecomputeCandidates =
     OrderedMap<AnfNodePtr, OrderedMap<AnfNodePtr, std::pair<EdgeLifeTimeType, AnfNodePtrList>>>;
+  virtual NodeRecomputeCandidates FindNodeRecomputeCandidates(const AnfNodePtr &node,
+                                                              const OutPosLinkList &target_graphs,
+                                                              const FuncGraphManagerPtr &mng);
+  void FindCandidates(const FuncGraphPtr &func_graph);
+  std::vector<Candidate> candidates_;
+
+ private:
   OutPosLinkList JudegeTargetAndCaptureSource(const AnfNodePtr &node, const FuncGraphManagerPtr &mng);
   AnfNodePtrList Filter(const AnfNodePtr &source_node, const AnfNodePtr &end_node, int edge_pos,
                         const FuncGraphManagerPtr &mng);
-  NodeRecomputeCandidates FindNodeRecomputeCandidates(const AnfNodePtr &node, const OutPosLinkList &target_graphs,
-                                                      const FuncGraphManagerPtr &mng);
-  void FindCandidates(const FuncGraphPtr &func_graph);
   int GetSourceLinkOutPos(const AnfNodePtr &target, int pos);
   std::tuple<OrderedSet<AnfNodePtr>, OutPosLinkMap, MemorySize> GetValidUsers(const AnfNodePtr &node,
                                                                               const FuncGraphManagerPtr &mng);

@@ -87,7 +91,6 @@ class AutoRecompute {
   bool IsThresholdDefaultValue() const;

   std::map<AnfNodePtr, MemorySize> topo_indice_;
-  std::vector<Candidate> candidates_;
   MemorySize lifetime_threshold_{0};
   MemorySize local_peak_threshold_{0};

@@ -96,6 +99,18 @@ class AutoRecompute {
   void RecomputeCandidatesLog(const std::vector<Candidate> &candidates) const;
 };

+class CSRRecompute : public AutoRecompute {
+ public:
+  std::vector<Candidate> Run(const FuncGraphPtr &func_graph) override;
+
+ protected:
+  NodeRecomputeCandidates FindNodeRecomputeCandidates(const AnfNodePtr &node, const OutPosLinkList &target_graphs,
+                                                      const FuncGraphManagerPtr &mng) override;
+
+ private:
+  bool CheckPrimitiveInput(AnfNodePtr base, PrimitivePtr prim_type);
+};
+
 class GraphKernelRecompute : public opt::Pass {
  public:
   GraphKernelRecompute() : Pass("graph_kernel_recompute") {}

@@ -103,6 +118,7 @@ class GraphKernelRecompute : public opt::Pass {
   bool Run(const FuncGraphPtr &func_graph) override;

  private:
+  bool DoRun(const FuncGraphPtr &func_graph, bool use_csr = false);
   void Process(const Candidate &candidate);
   std::pair<FuncGraphPtr, AnfNodePtrList> CloneGraph(const CNodePtr &source_graph,
                                                      const AnfNodePtrList &recompute_edges);
@@ -28,7 +28,7 @@ std::string CommonDimInfo::ToString() {
   return buffer.str();
 }

-int ParallelCostModel::GetNodeCalAmount(const AnfNodePtr &node) const {
+int64_t ParallelCostModel::GetNodeCalAmount(const AnfNodePtr &node) const {
   nlohmann::json json_desc;
   AnfNodePtrList nodes = {node};
   DumpOption dump_option;

@@ -38,7 +38,7 @@ int ParallelCostModel::GetNodeCalAmount(const AnfNodePtr &node) const {
   auto json_desc_str = json_desc.dump();
   auto ret = python_adapter::CallPyFn(kGraphKernelModule, kGraphKernelGetNodeCalAmount, json_desc_str);
-  auto bottleneck = py::cast<int>(ret);
+  auto bottleneck = py::cast<int64_t>(ret);
   if (bottleneck == -1) {
     MS_LOG(EXCEPTION) << "CallPyFn: [" << kGraphKernelGetNodeCalAmount << "] return invalid result. input json:\n"
                       << json_desc_str;

@@ -95,7 +95,7 @@ class ParallelCostModel {
  public:
   ParallelCostModel() {}
   ~ParallelCostModel() {}
-  int GetNodeCalAmount(const AnfNodePtr &node) const;
+  int64_t GetNodeCalAmount(const AnfNodePtr &node) const;
   std::tuple<std::vector<DimInfoPtr>, int, FusionInfoPtr> CalFuseInfo(const AnfNodePtrList &nodes) const;

  private:
@@ -632,7 +632,7 @@ std::tuple<std::vector<bool>, std::vector<ParallelInfo>> ParallelOpFusion::Searc
   }

   // A compute-heavy node can cover the cost of more lighter nodes, so sort them first.
-  std::map<size_t, int> cal_amounts;
+  std::map<size_t, int64_t> cal_amounts;
   for (auto id : indices) {
     cal_amounts[id] = cost_model_ptr_->GetNodeCalAmount(cs[id]);
   }
@@ -25,15 +25,6 @@
 #include "backend/common/session/kernel_graph.h"

 namespace mindspore::graphkernel {
-class UssChecker : public AtomicAddChecker {
- public:
-  explicit UssChecker(const PrimitivePtr &target) { target_type_ = target; }
-  virtual ~UssChecker() = default;
-
- protected:
-  bool CanActivateAtomicAdd(const AnfNodePtr &anf_node) override { return FindCandidate(anf_node); }
-};
-
 bool UssAtomicAdd::Run(const FuncGraphPtr &func_graph) {
   auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(func_graph);
   MS_EXCEPTION_IF_NULL(kernel_graph);

@@ -45,7 +36,7 @@ bool UssAtomicAdd::Run(const FuncGraphPtr &func_graph) {
   bool has_change = false;
   std::shared_ptr<AtomicAddChecker> checker =
-    std::make_shared<UssChecker>(std::make_shared<Primitive>("UnsortedSegmentSum"));
+    std::make_shared<TargetAtomicAddChecker>(std::make_shared<Primitive>("UnsortedSegmentSum"));
   if (checker == nullptr) {
     return has_change;
   }
@@ -569,6 +569,7 @@ AbstractBasePtr InferImplCSRReduceSum(const AnalysisEnginePtr &, const Primitive
 AbstractBasePtr InferImplCSRGather(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                    const AbstractBasePtrList &args_spec_list) {
   // Inputs: the indptr and indices of a sparse csr tensor, a dense tensor, and the shape of the sparse tensor.
+  constexpr size_t csr_row_num = 2;
   const std::string op_name = primitive->name();
   CheckArgsSize(op_name, args_spec_list, kSizeFour);
   auto indptr = CheckArg<AbstractTensor>(op_name, args_spec_list, kIndexZero);

@@ -592,6 +593,11 @@ AbstractBasePtr InferImplCSRGather(const AnalysisEnginePtr &, const PrimitivePtr
   MS_EXCEPTION_IF_NULL(indices->shape());
   ShapeVector out_shape = indices->shape()->shape();
+  MS_EXCEPTION_IF_NULL(dense->shape());
+  ShapeVector dense_shape = dense->shape()->shape();
+  for (size_t i = csr_row_num; i < dense_shape.size(); ++i) {
+    out_shape.push_back(dense_shape[i]);
+  }
   MS_EXCEPTION_IF_NULL(dense->element());
   auto ret = std::make_shared<AbstractTensor>(dense->element()->BuildType(), out_shape);
   return ret;
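Reviewer note: the added lines make CSRGather's inferred output shape the indices shape plus the trailing (batch) dimensions of the dense input, skipping the first `csr_row_num = 2` dense dims. The same rule in a few lines of Python (the helper name is illustrative):

def infer_csr_gather_shape(indices_shape, dense_shape, csr_row_num=2):
    # Output keeps one entry per stored nonzero (the indices layout) and
    # appends every dense dimension past the 2-D sparse row/col dims.
    return list(indices_shape) + list(dense_shape[csr_row_num:])

assert infer_csr_gather_shape([4], [3, 5]) == [4]
assert infer_csr_gather_shape([4], [3, 5, 7]) == [4, 7]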
@@ -1016,7 +1016,7 @@ class GraphSplitGpu(GraphSplitByPattern):
             return fused, True

         def _gather_output(dom, reduce_fusion=False):
-            gather_prims = ("Gather", "GatherNd")
+            gather_prims = ("Gather", "GatherNd", "CSRGather")
             if not dom.dom_op().prim in gather_prims:
                 return []

@@ -1049,14 +1049,22 @@ class GraphSplitGpu(GraphSplitByPattern):
                 visited = []
                 op_queue = [start_op]

+                def _remove_preceding_ones(shape):
+                    i = 0
+                    while i < len(shape) and shape[i] == 1:
+                        i += 1
+                    return shape[i:]
+
                 def _early_stop(cur_op):
                     if cur_op in end_ops:
                         # If the gather axis is reduced, stop early to avoid fusion.
                         if cur_op.prim == "ReduceSum" and _reduce_exclude(cur_op, gather_axis):
                             return True
                     else:
+                        shape1 = _remove_preceding_ones(consisten_shape)
+                        shape2 = _remove_preceding_ones(cur_op.output.shape)
                         if (cur_op.prim in start_prims and cur_op != start_op) or \
-                                consisten_shape != cur_op.output.shape:
+                                shape1 != shape2:
                             return True
                     return False

@@ -1108,7 +1116,10 @@ class GraphSplitGpu(GraphSplitByPattern):
                        return False
                return True

-            appected_areas = {"ReduceSum"} if reduce_fusion else {"TensorScatterAdd", "UnsortedSegmentSum"}
+            if reduce_fusion:
+                appected_areas = {"ReduceSum", "CSRReduceSum"}
+            else:
+                appected_areas = {"TensorScatterAdd", "UnsortedSegmentSum"}

             for a, _ in dom.out_relations.items():
                 if _shape_consistent(gather_prims, appected_areas, dom, a) and dom.check_acyclic(a):

@@ -1206,12 +1217,30 @@ class GraphSplitGpu(GraphSplitByPattern):
             return a.ops[0].prim == dom.ops[0].prim and dom.ops[0].output.shape == a.ops[0].output.shape and \
                 dom.ops[0].inputs[0].shape == a.ops[0].inputs[0].shape

+        def _link_csr(dom):
+            def _same_input(op1, op2):
+                return bool(set(op1.inputs.copy()) & set(op2.inputs.copy()))
+
+            fuse_arg = {"CSRReduceSum": slice(1, 3), "CSRGather": slice(2, 3)}
+            arg_idx = fuse_arg.get(dom.dom_op().prim, -1)
+            if arg_idx == -1:
+                return []
+            fuse_tensor = dom.dom_op().inputs[arg_idx]
+            for a, _ in dom.in_relations.items():
+                if (a.dom_op().prim == "CSRGather" and a.dom_op().prim == dom.dom_op().prim and
+                        _same_input(dom.dom_op(), a.dom_op())):
+                    return [a], True
+                if a.pattern <= PrimLib.BROADCAST and dom.check_acyclic(a) and \
+                        any([op.output in fuse_tensor for op in a.ops]):
+                    return [a], True
+            return []
+
         def _fuse_loop():
             self.fuse(CommonPattern.reshape)
             self.fuse(CommonPattern.assign)
             self.fuse(CommonPattern.elemwise_depth)
             self.fuse(CommonPattern.elemwise_width)
             self.fuse(_broadcast_tot)
+            self.fuse(_link_csr)
             self.fuse(CommonPattern.broadcast_depth)
             self.fuse(CommonPattern.broadcast_width)
             self.fuse(_reduce_depth)
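Reviewer note: `_remove_preceding_ones` in the hunks above strips leading broadcast dimensions so shapes like (1, 1, 32) and (32,) compare equal in `_early_stop`. The guard order matters: `i < len(shape)` must be checked before `shape[i]`, otherwise an all-ones shape indexes past the end (the diff above uses the guarded order). A quick self-contained check:

def _remove_preceding_ones(shape):
    i = 0
    while i < len(shape) and shape[i] == 1:
        i += 1
    return shape[i:]

assert _remove_preceding_ones([1, 1, 2, 3]) == [2, 3]
assert _remove_preceding_ones([2, 3]) == [2, 3]
assert _remove_preceding_ones([1, 1]) == []  # no IndexError with the guard first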
@@ -229,6 +229,10 @@ class PrimLib:
         'TensorScatterAdd': Prim(OPAQUE),
         'Gather': Prim(OPAQUE),
         'GatherNd': Prim(OPAQUE),
+        'CSRReduceSum': Prim(OPAQUE),
+        'CSRMul': Prim(ELEMWISE),
+        'CSRDiv': Prim(ELEMWISE),
+        'CSRGather': Prim(OPAQUE),
         'UnsortedSegmentSum': Prim(OPAQUE),
         'StandardNormal': Prim(OPAQUE),
         'OneHot': Prim(OPAQUE),
@@ -311,7 +311,7 @@ def test_batch_csr_ops():
     assert np.allclose(graph_res_elemwise[0].values.asnumpy(), expect3)
     assert np.allclose(graph_res_elemwise[1].values.asnumpy(), expect4)

-    expect5 = np.array([1., 1.], dtype=np.float32)
+    expect5 = np.array([[1, 1, 1], [2, 2, 2]], dtype=np.float32)
     res_gather = test_ops_pynative_gather()
     test_ops_graph_gather = ms_function(test_ops_pynative_gather)
     graph_res_gather = test_ops_graph_gather()
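Reviewer note: the new `expect5` is consistent with CSRGather gathering dense values at the stored nonzero positions while keeping trailing batch dimensions. A NumPy sketch of that semantic (the helper and data are illustrative, not the test's actual inputs):

import numpy as np

def csr_gather(indptr, indices, dense):
    # For each stored element (row r, col c) of the CSR pattern, pick
    # dense[r, c, ...]; trailing dense dimensions are preserved.
    rows = np.repeat(np.arange(len(indptr) - 1), np.diff(indptr))
    return dense[rows, indices]

indptr = np.array([0, 1, 2])          # one nonzero in each of two rows
indices = np.array([0, 1])
dense = np.array([[[1, 1, 1], [9, 9, 9]],
                  [[9, 9, 9], [2, 2, 2]]], dtype=np.float32)
print(csr_gather(indptr, indices, dense))  # [[1. 1. 1.] [2. 2. 2.]]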