!34615 Fuse CSR ops

Merge pull request !34615 from huangmengxi/fuse_gnn
i-robot 2022-06-10 09:09:06 +00:00 committed by Gitee
commit 0ea5714762
16 changed files with 369 additions and 35 deletions

View File

@ -45,6 +45,7 @@
#include "common/graph_kernel/axis_normalizer.h"
#include "common/graph_kernel/decrease_compute_precision.h"
#include "common/graph_kernel/decrease_transfer_precision.h"
#include "common/graph_kernel/csr_atomic_add.h"
#include "common/graph_kernel/tsa_atomic_add_to_first_tensor.h"
#include "common/graph_kernel/uss_atomic_add.h"
#include "backend/common/pass/getitem_tuple.h"
@ -169,7 +170,8 @@ PassManagerPtr GraphKernelOptimizer::HighLevelOpt2() const {
auto &flags = GraphKernelFlags::GetInstance();
// Auto recompute according to local memory burst.
auto recompute_lv = GetPassLevelByFlag(flags.recompute_increment_threshold > 0 || flags.recompute_peak_threshold > 0);
auto recompute_lv = GetPassLevelByFlag(flags.recompute_increment_threshold > 0 ||
flags.recompute_peak_threshold > 0 || flags.enable_csr_fusion);
pm->Add(std::make_shared<GraphKernelRecompute>(), recompute_lv);
// Enable atomic add
@ -191,6 +193,7 @@ PassManagerPtr GraphKernelOptimizer::HighLevelOpt2() const {
// Enable tsa and uss
pm->Add(std::make_shared<TsaAtomicAddToFirstTensor>(), OptLevel_1, is_gpu);
pm->Add(std::make_shared<UssAtomicAdd>(), OptLevel_1, is_gpu);
pm->Add(std::make_shared<CsrAtomicAdd>(), OptLevel_1, is_gpu);
// Replace Assign with InplaceAssign, and replace original output with overridden parameters
pm->Add(std::make_shared<OptimizeAssign>(), OptLevel_2);

View File

@ -45,6 +45,14 @@ class AtomicAddChecker {
PrimitivePtr target_type_{prim::kPrimReduceSum};
};
class TargetAtomicAddChecker : public AtomicAddChecker {
public:
explicit TargetAtomicAddChecker(const PrimitivePtr &target = prim::kPrimReduceSum) { target_type_ = target; }
protected:
bool CanActivateAtomicAdd(const AnfNodePtr &anf_node) override { return FindCandidate(anf_node); }
};
class AtomicAddCheckerGPU : public AtomicAddChecker {
public:
AtomicAddCheckerGPU() = default;

View File

@ -0,0 +1,133 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/graph_kernel/csr_atomic_add.h"
#include <memory>
#include "mindspore/core/ops/core_ops.h"
#include "ir/tensor.h"
#include "include/common/utils/utils.h"
#include "kernel/kernel.h"
#include "kernel/common_utils.h"
#include "common/graph_kernel/graph_kernel_helper.h"
#include "backend/common/session/kernel_graph.h"
namespace mindspore::graphkernel {
class ReduceSumCsrChecker : public AtomicAddChecker {
public:
ReduceSumCsrChecker() = default;
protected:
bool CanActivateAtomicAdd(const AnfNodePtr &node) override {
bool has_csr = false;
bool has_reduce_sum = false;
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto func_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(cnode);
MS_EXCEPTION_IF_NULL(func_graph);
for (auto n : func_graph->nodes()) {
if (n->isa<CNode>() && IsAKGSparseOP(n)) {
has_csr = true;
break;
} else if (IsPrimitiveCNode(n, prim::kPrimReduceSum)) {
has_reduce_sum = true;
}
}
if (has_csr && has_reduce_sum) {
return FindCandidate(node);
}
return false;
}
bool FindCandidate(const AnfNodePtr &anf_node) override {
atomic_add_infos_.clear();
auto node = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(node);
auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
auto mng_sub = sub_graph->manager();
if (mng_sub == nullptr) {
mng_sub = Manage(sub_graph, false);
sub_graph->set_manager(mng_sub);
}
auto CheckSuitableTarget = [&mng_sub](const InplaceAssignerInfo &atomic_add_info) {
// The target op should not be fused with any other op in the output direction, i.e. it must be in the output list.
return mng_sub->node_users()[atomic_add_info.op_node].size() <= 1;
};
auto real_return_node = sub_graph->get_return()->input(kFirstDataInputIndex);
InplaceAssignerInfo atomic_add_info;
if (IsPrimitiveCNode(real_return_node, prim::kPrimMakeTuple)) {
const auto &inputs = real_return_node->cast<CNodePtr>()->inputs();
for (size_t i = 1; i < inputs.size(); ++i) {
atomic_add_info.op_node = inputs[i]->cast<CNodePtr>();
atomic_add_info.real_output_index = i - 1;
atomic_add_info.real_output_num = inputs.size() - 1;
// The target op should not be fused with any other op in the output direction, i.e. it must be in the output list.
if (CheckSuitableTarget(atomic_add_info)) {
atomic_add_infos_.push_back(atomic_add_info);
}
}
} else if (real_return_node->isa<CNode>()) {
atomic_add_info.op_node = real_return_node->cast<CNodePtr>();
atomic_add_info.real_output_num = 1;
if (CheckSuitableTarget(atomic_add_info)) {
atomic_add_infos_.push_back(atomic_add_info);
}
} else {
return false;
}
return !atomic_add_infos_.empty();
}
};
bool CsrAtomicAdd::Run(const FuncGraphPtr &func_graph) {
auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(func_graph);
MS_EXCEPTION_IF_NULL(kernel_graph);
auto mng = kernel_graph->manager();
if (mng == nullptr) {
mng = Manage(kernel_graph, true);
kernel_graph->set_manager(mng);
}
bool changed = false;
std::shared_ptr<AtomicAddChecker> csr_reduce_sum_checker =
std::make_shared<TargetAtomicAddChecker>(prim::kPrimCSRReduceSum);
MS_EXCEPTION_IF_NULL(csr_reduce_sum_checker);
std::shared_ptr<AtomicAddChecker> reduce_sum_csr_checker = std::make_shared<ReduceSumCsrChecker>();
MS_EXCEPTION_IF_NULL(reduce_sum_csr_checker);
auto topo_nodes = TopoSort(kernel_graph->get_return());
for (const auto &node : topo_nodes) {
std::vector<InplaceAssignerInfo> atomic_add_infos;
if (csr_reduce_sum_checker->Check(node)) {
atomic_add_infos = csr_reduce_sum_checker->GetAtomicAddInfo();
} else if (reduce_sum_csr_checker->Check(node)) {
atomic_add_infos = reduce_sum_csr_checker->GetAtomicAddInfo();
}
if (!atomic_add_infos.empty()) {
InsertAtomicClean(kernel_graph, node, atomic_add_infos, mng);
changed = true;
}
}
if (changed) {
mng->RemoveRoots();
mng->KeepRoots({func_graph});
}
return changed;
}
} // namespace mindspore::graphkernel

View File

@ -0,0 +1,39 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CSR_ATOMIC_ADD_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CSR_ATOMIC_ADD_H_
#include <memory>
#include <tuple>
#include <utility>
#include <vector>
#include "backend/common/optimizer/optimizer.h"
#include "common/graph_kernel/add_atomic_clean.h"
#include "backend/common/session/kernel_graph.h"
namespace mindspore::graphkernel {
// Insert an atomic-clean node for ReduceSum if any CSR op is found in the graph.
class CsrAtomicAdd : public AtomicCleanInserter {
public:
CsrAtomicAdd() : AtomicCleanInserter("csr_atomic_add_process") {}
~CsrAtomicAdd() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
};
using CsrAtomicAddPtr = std::shared_ptr<CsrAtomicAdd>;
} // namespace mindspore::graphkernel
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CSR_ATOMIC_ADD_H_

View File

@ -271,6 +271,7 @@ void GraphKernelFlags::RegisterFlags(std::map<std::string, std::string> *flag_ma
reg.AddFlag("enable_horizontal_fusion", &enable_horizontal_fusion, false);
reg.AddFlag("enable_auto_tensor_inplace", &enable_auto_tensor_inplace, false);
reg.AddFlag("enable_low_precision", &enable_low_precision);
reg.AddFlag("enable_csr_fusion", &enable_csr_fusion);
// Integer flags
reg.AddFlag("online_tuning", &online_tuning);
@ -304,6 +305,7 @@ std::string GraphKernelFlags::DumpAllFlags() const {
json["enable_parallel_fusion"] = enable_parallel_fusion;
json["enable_horizontal_fusion"] = enable_horizontal_fusion;
json["enable_auto_tensor_inplace"] = enable_auto_tensor_inplace;
json["enable_csr_fusion"] = enable_csr_fusion;
json["enable_low_precision"] = enable_low_precision;
json["opt_level"] = opt_level;

View File

@ -106,6 +106,11 @@ class GraphKernelFlags {
*/
unsigned int fusion_ops_level{OpLevel_0};
/**
* Enable recompute fusion for CSR operations.
*/
bool enable_csr_fusion{false};
/**
* Optimization level, value from 0 to 3.
* 0: Disable GraphKernel
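For context, a minimal usage sketch for the new flag. It assumes the flag is passed through the usual graph_kernel_flags string accepted by mindspore.context.set_context; the exact invocation below is an assumption for illustration, not part of this change.

# Hypothetical usage sketch: the flag name follows the registration above, but
# this particular set_context call is an assumption, not introduced by this PR.
from mindspore import context

context.set_context(mode=context.GRAPH_MODE, device_target="GPU",
                    enable_graph_kernel=True,
                    graph_kernel_flags="--enable_csr_fusion=true")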

View File

@ -17,6 +17,7 @@
#include "common/graph_kernel/graph_kernel_recompute.h"
#include <algorithm>
#include <deque>
#include <functional>
#include <limits>
#include <map>
@ -200,7 +201,9 @@ void ElimRedundantInputsAndGraphParameters(const FuncGraphPtr &func_graph, AnfNo
std::vector<Candidate> AutoRecompute::Run(const FuncGraphPtr &func_graph) {
lifetime_threshold_ = GraphKernelFlags::GetInstance().recompute_increment_threshold;
local_peak_threshold_ = GraphKernelFlags::GetInstance().recompute_peak_threshold;
FindCandidates(func_graph);
if (!IsThresholdDefaultValue()) {
FindCandidates(func_graph);
}
return candidates_;
}
@ -420,11 +423,6 @@ void AutoRecompute::FindCandidates(const FuncGraphPtr &func_graph) {
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
// Do nothing when threshold is default value.
if (IsThresholdDefaultValue()) {
return;
}
auto topo_nodes = TopoSort(func_graph->get_return());
// Topo indices are used for early stopping in the predecessor check.
for (size_t i = 0; i < topo_nodes.size(); ++i) {
@ -469,6 +467,13 @@ AutoRecompute::NodeRecomputeCandidates AutoRecompute::FindNodeRecomputeCandidate
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(mng);
NodeRecomputeCandidates node_candidates;
auto graph_node = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
MS_EXCEPTION_IF_NULL(graph_node);
auto nodes = graph_node->nodes();
if (std::any_of(nodes.cbegin(), nodes.cend(),
[](const AnfNodePtr n) { return IsPrimitiveCNode(n, prim::kPrimReduceSum); })) {
return node_candidates;
}
for (auto [gt, gt_in_pos_vec, edge_life_time_type] : target_graphs) {
MemorySize threshold = SelectThreshold(edge_life_time_type);
for (auto gt_in_pos : gt_in_pos_vec) {
@ -505,7 +510,6 @@ AutoRecompute::NodeRecomputeCandidates AutoRecompute::FindNodeRecomputeCandidate
}
}
}
return node_candidates;
}
@ -536,6 +540,89 @@ void AutoRecompute::RecomputeCandidatesLog(const std::vector<Candidate> &candida
}
}
std::vector<Candidate> CSRRecompute::Run(const FuncGraphPtr &func_graph) {
FindCandidates(func_graph);
return candidates_;
}
bool CSRRecompute::CheckPrimitiveInput(AnfNodePtr base, PrimitivePtr prim_type) {
std::deque<AnfNodePtr> q{base};
std::set<AnfNodePtr> visited;
while (!q.empty()) {
auto node = q.front();
q.pop_front();
if (visited.count(node) > 0) continue;
visited.insert(node);
if (IsPrimitiveCNode(node, prim_type)) {
return true;
}
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) continue;
auto inputs = cnode->inputs();
q.insert(q.begin(), inputs.begin(), inputs.end());
}
return false;
}
AutoRecompute::NodeRecomputeCandidates CSRRecompute::FindNodeRecomputeCandidates(const AnfNodePtr &node,
const OutPosLinkList &target_graphs,
const FuncGraphManagerPtr &mng) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(mng);
NodeRecomputeCandidates node_candidates;
auto graph_node = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
MS_EXCEPTION_IF_NULL(graph_node);
auto nodes = graph_node->nodes();
// Subgraphs that output UnsortedSegmentSum or CSRReduceSum along with other ops
// (likely the results of Gather), or that contain CSRDiv without outputting
// UnsortedSegmentSum or CSRReduceSum, are selected as recompute candidates.
auto TargetTail = [](const AnfNodePtr n) {
return IsPrimitiveCNode(n, prim::kPrimUnsortedSegmentSum) || IsPrimitiveCNode(n, prim::kPrimCSRReduceSum);
};
auto TargetHead = [](const AnfNodePtr n) { return IsPrimitiveCNode(n, prim::kPrimCSRDiv); };
auto return_node = graph_node->get_return();
MS_EXCEPTION_IF_NULL(return_node);
auto return_cnode = return_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(return_cnode);
auto return_inputs = return_cnode->inputs();
auto return_tup = return_inputs[1]->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(return_tup);
auto tuple_inputs = return_tup->inputs();
std::set<size_t> candidate_idx;
if (std::any_of(tuple_inputs.cbegin(), tuple_inputs.cend(), TargetTail)) {
for (size_t i = 1; i < tuple_inputs.size(); ++i) {
if (!TargetTail(tuple_inputs[i])) {
candidate_idx.insert(i - 1);
}
}
} else if (std::any_of(tuple_inputs.cbegin(), tuple_inputs.cend(), TargetHead)) {
for (size_t i = 1; i < tuple_inputs.size(); ++i) {
if (CheckPrimitiveInput(tuple_inputs[i], prim::kPrimCSRDiv)) {
candidate_idx.insert(i - 1);
}
}
}
if (candidate_idx.empty()) return node_candidates;
for (size_t i = 0; i < target_graphs.size(); ++i) {
AnfNodePtr gt;
std::vector<int> gt_in_pos_vec;
std::tie(gt, gt_in_pos_vec, std::ignore) = target_graphs[i];
for (auto gt_in_pos : gt_in_pos_vec) {
auto gt_cnode = gt->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(gt_cnode);
auto edge = gt_cnode->input(IntToSize(gt_in_pos));
if (!IsPrimitiveCNode(edge, prim::kPrimTupleGetItem)) continue;
auto edge_cnode = edge->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(edge_cnode);
auto tuple_idx = common::AnfAlgo::GetTupleGetItemOutIndex(edge_cnode);
if (candidate_idx.count(tuple_idx) > 0) {
node_candidates[node][gt].second.push_back(edge);
}
}
}
return node_candidates;
}
std::pair<FuncGraphPtr, AnfNodePtrList> GraphKernelRecompute::CloneGraph(const CNodePtr &source_graph,
const AnfNodePtrList &recompute_edges) {
MS_EXCEPTION_IF_NULL(source_graph);
@ -675,11 +762,14 @@ void GraphKernelRecompute::Process(const Candidate &candidate) {
}
}
bool GraphKernelRecompute::Run(const FuncGraphPtr &func_graph) {
bool GraphKernelRecompute::DoRun(const FuncGraphPtr &func_graph, bool use_csr) {
int repeat_times = 2;
while ((repeat_times--) != 0) {
AutoRecompute recompute;
candidates_ = recompute.Run(func_graph);
if (use_csr) {
candidates_ = CSRRecompute().Run(func_graph);
} else {
candidates_ = AutoRecompute().Run(func_graph);
}
if (candidates_.empty()) {
return false;
}
@ -705,4 +795,12 @@ bool GraphKernelRecompute::Run(const FuncGraphPtr &func_graph) {
}
return true;
}
bool GraphKernelRecompute::Run(const FuncGraphPtr &func_graph) {
bool status = DoRun(func_graph);
if (GraphKernelFlags::GetInstance().enable_csr_fusion) {
status |= DoRun(func_graph, true);
}
return status;
}
} // namespace mindspore::graphkernel
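As a plain-Python illustration of the output-index selection performed by CSRRecompute::FindNodeRecomputeCandidates above, a minimal sketch follows; the helper name and structure are hypothetical, not MindSpore code.

# Hypothetical sketch: given the primitive names feeding a subgraph's MakeTuple
# output, return the output positions that become recompute candidates.
def select_candidate_indices(output_prims):
    tail = {"UnsortedSegmentSum", "CSRReduceSum"}
    if any(p in tail for p in output_prims):
        # Keep outputs that are not the heavy tail ops (e.g. Gather results).
        return {i for i, p in enumerate(output_prims) if p not in tail}
    # Otherwise the pass checks whether each output is reachable from a CSRDiv
    # (omitted here); with no match there are no candidates.
    return set()

assert select_candidate_indices(["Gather", "CSRReduceSum"]) == {0}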

View File

@ -69,17 +69,21 @@ class AutoRecompute {
AutoRecompute() = default;
~AutoRecompute() = default;
std::vector<Candidate> Run(const FuncGraphPtr &func_graph);
virtual std::vector<Candidate> Run(const FuncGraphPtr &func_graph);
private:
protected:
using NodeRecomputeCandidates =
OrderedMap<AnfNodePtr, OrderedMap<AnfNodePtr, std::pair<EdgeLifeTimeType, AnfNodePtrList>>>;
virtual NodeRecomputeCandidates FindNodeRecomputeCandidates(const AnfNodePtr &node,
const OutPosLinkList &target_graphs,
const FuncGraphManagerPtr &mng);
void FindCandidates(const FuncGraphPtr &func_graph);
std::vector<Candidate> candidates_;
private:
OutPosLinkList JudegeTargetAndCaptureSource(const AnfNodePtr &node, const FuncGraphManagerPtr &mng);
AnfNodePtrList Filter(const AnfNodePtr &source_node, const AnfNodePtr &end_node, int edge_pos,
const FuncGraphManagerPtr &mng);
NodeRecomputeCandidates FindNodeRecomputeCandidates(const AnfNodePtr &node, const OutPosLinkList &target_graphs,
const FuncGraphManagerPtr &mng);
void FindCandidates(const FuncGraphPtr &func_graph);
int GetSourceLinkOutPos(const AnfNodePtr &target, int pos);
std::tuple<OrderedSet<AnfNodePtr>, OutPosLinkMap, MemorySize> GetValidUsers(const AnfNodePtr &node,
const FuncGraphManagerPtr &mng);
@ -87,7 +91,6 @@ class AutoRecompute {
bool IsThresholdDefaultValue() const;
std::map<AnfNodePtr, MemorySize> topo_indice_;
std::vector<Candidate> candidates_;
MemorySize lifetime_threshold_{0};
MemorySize local_peak_threshold_{0};
@ -96,6 +99,18 @@ class AutoRecompute {
void RecomputeCandidatesLog(const std::vector<Candidate> &candidates) const;
};
class CSRRecompute : public AutoRecompute {
public:
std::vector<Candidate> Run(const FuncGraphPtr &func_graph) override;
protected:
NodeRecomputeCandidates FindNodeRecomputeCandidates(const AnfNodePtr &node, const OutPosLinkList &target_graphs,
const FuncGraphManagerPtr &mng) override;
private:
bool CheckPrimitiveInput(AnfNodePtr base, PrimitivePtr prim_type);
};
class GraphKernelRecompute : public opt::Pass {
public:
GraphKernelRecompute() : Pass("graph_kernel_recompute") {}
@ -103,6 +118,7 @@ class GraphKernelRecompute : public opt::Pass {
bool Run(const FuncGraphPtr &func_graph) override;
private:
bool DoRun(const FuncGraphPtr &func_graph, bool use_csr = false);
void Process(const Candidate &candidate);
std::pair<FuncGraphPtr, AnfNodePtrList> CloneGraph(const CNodePtr &source_graph,
const AnfNodePtrList &recompute_edges);

View File

@ -28,7 +28,7 @@ std::string CommonDimInfo::ToString() {
return buffer.str();
}
int ParallelCostModel::GetNodeCalAmount(const AnfNodePtr &node) const {
int64_t ParallelCostModel::GetNodeCalAmount(const AnfNodePtr &node) const {
nlohmann::json json_desc;
AnfNodePtrList nodes = {node};
DumpOption dump_option;
@ -38,7 +38,7 @@ int ParallelCostModel::GetNodeCalAmount(const AnfNodePtr &node) const {
auto json_desc_str = json_desc.dump();
auto ret = python_adapter::CallPyFn(kGraphKernelModule, kGraphKernelGetNodeCalAmount, json_desc_str);
auto bottleneck = py::cast<int>(ret);
auto bottleneck = py::cast<int64_t>(ret);
if (bottleneck == -1) {
MS_LOG(EXCEPTION) << "CallPyFn: [" << kGraphKernelGetNodeCalAmount << "] return invalid result. input json:\n"
<< json_desc_str;

View File

@ -95,7 +95,7 @@ class ParallelCostModel {
public:
ParallelCostModel() {}
~ParallelCostModel() {}
int GetNodeCalAmount(const AnfNodePtr &node) const;
int64_t GetNodeCalAmount(const AnfNodePtr &node) const;
std::tuple<std::vector<DimInfoPtr>, int, FusionInfoPtr> CalFuseInfo(const AnfNodePtrList &nodes) const;
private:

View File

@ -632,7 +632,7 @@ std::tuple<std::vector<bool>, std::vector<ParallelInfo>> ParallelOpFusion::Searc
}
// A computation-heavy node can cover the cost of more lighter nodes, so sort them first.
std::map<size_t, int> cal_amounts;
std::map<size_t, int64_t> cal_amounts;
for (auto id : indices) {
cal_amounts[id] = cost_model_ptr_->GetNodeCalAmount(cs[id]);
}

View File

@ -25,15 +25,6 @@
#include "backend/common/session/kernel_graph.h"
namespace mindspore::graphkernel {
class UssChecker : public AtomicAddChecker {
public:
explicit UssChecker(const PrimitivePtr &target) { target_type_ = target; }
virtual ~UssChecker() = default;
protected:
bool CanActivateAtomicAdd(const AnfNodePtr &anf_node) override { return FindCandidate(anf_node); }
};
bool UssAtomicAdd::Run(const FuncGraphPtr &func_graph) {
auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(func_graph);
MS_EXCEPTION_IF_NULL(kernel_graph);
@ -45,7 +36,7 @@ bool UssAtomicAdd::Run(const FuncGraphPtr &func_graph) {
bool has_change = false;
std::shared_ptr<AtomicAddChecker> checker =
std::make_shared<UssChecker>(std::make_shared<Primitive>("UnsortedSegmentSum"));
std::make_shared<TargetAtomicAddChecker>(std::make_shared<Primitive>("UnsortedSegmentSum"));
if (checker == nullptr) {
return has_change;
}

View File

@ -569,6 +569,7 @@ AbstractBasePtr InferImplCSRReduceSum(const AnalysisEnginePtr &, const Primitive
AbstractBasePtr InferImplCSRGather(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: the indptr and indices of a sparse csr tensor, a dense tensor, and the shape of the sparse tensor.
constexpr size_t csr_row_num = 2;
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, kSizeFour);
auto indptr = CheckArg<AbstractTensor>(op_name, args_spec_list, kIndexZero);
@ -592,6 +593,11 @@ AbstractBasePtr InferImplCSRGather(const AnalysisEnginePtr &, const PrimitivePtr
MS_EXCEPTION_IF_NULL(indices->shape());
ShapeVector out_shape = indices->shape()->shape();
MS_EXCEPTION_IF_NULL(dense->shape());
ShapeVector dense_shape = dense->shape()->shape();
for (size_t i = csr_row_num; i < dense_shape.size(); ++i) {
out_shape.push_back(dense_shape[i]);
}
MS_EXCEPTION_IF_NULL(dense->element());
auto ret = std::make_shared<AbstractTensor>(dense->element()->BuildType(), out_shape);
return ret;
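A small worked example of the shape rule added above, in plain Python with illustrative values:

# CSRGather output shape: start from the indices shape and append the dense
# dims from axis csr_row_num (= 2) onward.
csr_row_num = 2
indices_shape = [8]        # nnz = 8 gathered positions
dense_shape = [4, 5, 3]    # batched dense input
out_shape = indices_shape + dense_shape[csr_row_num:]
assert out_shape == [8, 3]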

View File

@ -1016,7 +1016,7 @@ class GraphSplitGpu(GraphSplitByPattern):
return fused, True
def _gather_output(dom, reduce_fusion=False):
gather_prims = ("Gather", "GatherNd")
gather_prims = ("Gather", "GatherNd", "CSRGather")
if not dom.dom_op().prim in gather_prims:
return []
@ -1049,14 +1049,22 @@ class GraphSplitGpu(GraphSplitByPattern):
visited = []
op_queue = [start_op]
def _remove_preceding_ones(shape):
i = 0
while i < len(shape) and shape[i] == 1:
i += 1
return shape[i:]
def _early_stop(cur_op):
if cur_op in end_ops:
# If the reduction is along the gather axis, stop early so the ops are not fused.
if cur_op.prim == "ReduceSum" and _reduce_exclude(cur_op, gather_axis):
return True
else:
shape1 = _remove_preceding_ones(consisten_shape)
shape2 = _remove_preceding_ones(cur_op.output.shape)
if (cur_op.prim in start_prims and cur_op != start_op) or \
consisten_shape != cur_op.output.shape:
shape1 != shape2:
return True
return False
@ -1108,7 +1116,10 @@ class GraphSplitGpu(GraphSplitByPattern):
return False
return True
appected_areas = {"ReduceSum"} if reduce_fusion else {"TensorScatterAdd", "UnsortedSegmentSum"}
if reduce_fusion:
appected_areas = {"ReduceSum", "CSRReduceSum"}
else:
appected_areas = {"TensorScatterAdd", "UnsortedSegmentSum"}
for a, _ in dom.out_relations.items():
if _shape_consistent(gather_prims, appected_areas, dom, a) and dom.check_acyclic(a):
@ -1206,12 +1217,30 @@ class GraphSplitGpu(GraphSplitByPattern):
return a.ops[0].prim == dom.ops[0].prim and dom.ops[0].output.shape == a.ops[0].output.shape and \
dom.ops[0].inputs[0].shape == a.ops[0].inputs[0].shape
def _link_csr(dom):
def _same_input(op1, op2):
return bool(set(op1.inputs.copy()) & set(op2.inputs.copy()))
fuse_arg = {"CSRReduceSum": slice(1, 3), "CSRGather": slice(2, 3)}
arg_idx = fuse_arg.get(dom.dom_op().prim, -1)
if arg_idx == -1:
return []
fuse_tensor = dom.dom_op().inputs[arg_idx]
for a, _ in dom.in_relations.items():
if (a.dom_op().prim == "CSRGather" and a.dom_op().prim == dom.dom_op().prim and
_same_input(dom.dom_op(), a.dom_op())):
return [a], True
if a.pattern <= PrimLib.BROADCAST and dom.check_acyclic(a) and \
any([op.output in fuse_tensor for op in a.ops]):
return [a], True
return []
def _fuse_loop():
self.fuse(CommonPattern.reshape)
self.fuse(CommonPattern.assign)
self.fuse(CommonPattern.elemwise_depth)
self.fuse(CommonPattern.elemwise_width)
self.fuse(_broadcast_tot)
self.fuse(_link_csr)
self.fuse(CommonPattern.broadcast_depth)
self.fuse(CommonPattern.broadcast_width)
self.fuse(_reduce_depth)
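A quick standalone check of the relaxed shape comparison introduced in _early_stop above; this re-states _remove_preceding_ones outside the class purely for illustration.

# Leading 1s are stripped before comparing shapes, so e.g. a (1, 1, 32) output
# no longer triggers an early stop against a (32,) output.
def _remove_preceding_ones(shape):
    i = 0
    while i < len(shape) and shape[i] == 1:
        i += 1
    return shape[i:]

assert _remove_preceding_ones([1, 1, 32]) == [32]
assert _remove_preceding_ones([2, 32]) == [2, 32]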

View File

@ -229,6 +229,10 @@ class PrimLib:
'TensorScatterAdd': Prim(OPAQUE),
'Gather': Prim(OPAQUE),
'GatherNd': Prim(OPAQUE),
'CSRReduceSum': Prim(OPAQUE),
'CSRMul': Prim(ELEMWISE),
'CSRDiv': Prim(ELEMWISE),
'CSRGather': Prim(OPAQUE),
'UnsortedSegmentSum': Prim(OPAQUE),
'StandardNormal': Prim(OPAQUE),
'OneHot': Prim(OPAQUE),

View File

@ -311,7 +311,7 @@ def test_batch_csr_ops():
assert np.allclose(graph_res_elemwise[0].values.asnumpy(), expect3)
assert np.allclose(graph_res_elemwise[1].values.asnumpy(), expect4)
expect5 = np.array([1., 1.], dtype=np.float32)
expect5 = np.array([[1, 1, 1], [2, 2, 2]], dtype=np.float32)
res_gather = test_ops_pynative_gather()
test_ops_graph_gather = ms_function(test_ops_pynative_gather)
graph_res_gather = test_ops_graph_gather()