Enhance GNN and support graph recompute.

tronzhang 2021-12-10 15:01:08 +08:00
parent d1d516e668
commit 9bcb3eb1ef
16 changed files with 1373 additions and 266 deletions

View File

@ -861,60 +861,114 @@ class GraphSplitGpu(GraphSplitByPattern):
if not dom.dom_op().prim in gather_prims:
return None
def _count_target_prim(ops, target_list):
count = 0
for op in ops:
if op.prim in target_list:
count += 1
return count
def _reduce_exclude(op, axis_list):
""" Whether this operator should be excluded.
Excluding condition:
1. Reduce the last axis.
2. There is at least one axis shared between the reduce axes and axis_list.
Args:
op (Operator): Target reduce operator.
axis_list (list): List to check whether it is intersected by reduce axis.
Returns:
Boolean. Whether this operator should be excluded.
"""
axis = op.attrs["reduce_axis"]
if isinstance(axis, int):
axis = [axis]
in_shape_len = len(op.inputs[0].shape)
for i, dim in enumerate(axis):
axis[i] = in_shape_len + dim if dim < 0 else dim
fix_axis = []
for ax in axis:
if op.inputs[0].shape[ax] == 1:
continue
fix_axis.append(ax)
return (in_shape_len - 1 in fix_axis) or bool(set(fix_axis) & set(axis_list))
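A worked (hypothetical) reading of the exclusion rule above: a ReduceSum over an input of shape (32, 1, 16) with reduce_axis=-1 normalizes the axis to 2, which is the last axis, so the reduce is excluded from fusion; with reduce_axis=1 that axis is dropped because the dimension is 1, leaving no effective reduce axis, so the op is not excluded.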
def _bfs_visit(start_op, start_prims, end_prims, total_ops):
def _bfs_visit(start_op, start_prims, total_ops, end_ops, gather_axis):
consisten_shape = start_op.output.shape
visited = []
op_queue = [start_op]
def _early_stop(cur_op):
if cur_op in end_ops:
# If the reduce covers the last axis or the gather axis, stop early and do not fuse.
if cur_op.prim == "ReduceSum" and _reduce_exclude(cur_op, gather_axis):
return True
else:
if (cur_op.prim in start_prims and cur_op != start_op) or \
consisten_shape != cur_op.output.shape:
return True
return False
while op_queue:
tmp_queue = []
for op in op_queue:
if op in visited:
if op in visited or not op in total_ops:
continue
if op.prim in end_prims or not op in total_ops:
continue
if (op.prim in start_prims and op != start_op) or consisten_shape != op.output.shape:
if _early_stop(op):
return False
if op in end_ops:
continue
for to_op in op.output.to_ops:
tmp_queue.append(to_op)
visited.append(op)
op_queue = tmp_queue
return True
def _shape_consistent(start_prims, end_prims, source, target):
"""Check whether it is always shape consistent from source nodes to target nodes."""
total_ops = source.ops + target.ops
start_ops = []
for op in source.ops:
if op.prim in start_prims:
start_ops.append(op)
end_ops = []
for op in total_ops:
if op.prim in end_prims and not any([to_op in total_ops for to_op in op.output.to_ops]):
end_ops.append(op)
total_ops = source.ops + target.ops
for start_op in start_ops:
is_consistent = _bfs_visit(start_op, start_prims, end_prims, total_ops)
gather_axis = start_op.attrs.get("axis", None)
if gather_axis is None:
# For GatherNd
gather_axis = list(range(len(start_op.inputs[1].shape)))
elif isinstance(gather_axis, int):
gather_axis = [gather_axis]
is_consistent = _bfs_visit(start_op, start_prims, total_ops, end_ops, gather_axis)
if not is_consistent:
return False
return True
appected_areas = {"TensorScatterAdd", "UnsortedSegmentSum"}
appected_areas = {"TensorScatterAdd", "UnsortedSegmentSum", "ReduceSum"}
for a, _ in dom.out_relations.items():
if _shape_consistent(gather_prims, appected_areas, dom, a) and \
_count_target_prim(a.ops + dom.ops, appected_areas) < 2 and dom.check_acyclic(a):
if _shape_consistent(gather_prims, appected_areas, dom, a) and dom.check_acyclic(a):
return [a], False
return None
def _broadcast_opaque(dom):
"""Fuse rule for TensorScatterAdd and UnsortedSegmentSum."""
def _same_input(op1, op2):
return bool(set(op1.inputs.copy()) & set(op2.inputs.copy()))
if len(dom.ops) != 1:
return None
# Only fuse the first input for `TensorScatterAdd` and the first and second inputs for `UnsortedSegmentSum`.
fuse_arg = {"TensorScatterAdd": slice(1, None), "UnsortedSegmentSum": slice(0, 2)}
arg_idx = fuse_arg.get(dom.dom_op().prim, -1)
if arg_idx == -1 or len(dom.ops) != 1:
if arg_idx == -1:
return None
fuse_tensor = dom.dom_op().inputs[arg_idx]
for a, _ in dom.in_relations.items():
# Rule 1: Same type with at least one shared input.
if a.dom_op().prim == dom.dom_op().prim and _same_input(dom.dom_op(), a.dom_op()):
return [a], True
# Rule 2: Fuse ops (reshape/elementwise/broadcast) that produce the inputs at the specified positions.
if a.pattern <= PrimLib.BROADCAST and dom.check_acyclic(a) and \
any([op.output in fuse_tensor for op in a.ops]):
return [a], True

View File

@ -17,6 +17,7 @@
#include "backend/optimizer/graph_kernel/add_atomic_clean.h"
#include <algorithm>
#include <functional>
#include <map>
#include <memory>
#include <utility>
#include <set>
@ -37,6 +38,7 @@
namespace mindspore::graphkernel {
namespace {
auto constexpr NUMBER_COND_FOR_FILTER_INPLACE = 2;
std::set<int64_t> GetUniqReduceAxes(const AnfNodePtr &node, bool is_ascend = false) {
if (!IsPrimitiveCNode(node, prim::kPrimReduceSum)) {
MS_LOG(EXCEPTION) << "Only process for reduce sum!";
@ -87,10 +89,13 @@ bool HaveReduceInPredecessors(const AnfNodePtr &node) {
return false;
}
inline int64_t CalNewIndex(int64_t old_index, int64_t reduce_index) {
return old_index - (old_index > reduce_index ? 1 : 0);
inline int64_t CalNewIndex(int64_t old_index, const std::set<int64_t> &reduce_indexs) {
int64_t count =
std::count_if(reduce_indexs.begin(), reduce_indexs.end(), [old_index](int i) { return i < old_index; });
return old_index - count;
}
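For intuition, a minimal standalone sketch of this remapping (hypothetical indices and a self-contained helper, not the MindSpore API): once the atomic-add outputs are dropped from the composite's output tuple, every surviving TupleGetItem index shifts down by the number of dropped outputs in front of it.
#include <cstdint>
#include <set>
// Hedged illustration of the index correction above (hypothetical values).
int64_t RemapIndex(int64_t old_index, const std::set<int64_t> &removed_outputs) {
  int64_t count = 0;
  for (int64_t removed : removed_outputs) {
    if (removed < old_index) {
      ++count;  // every removed slot in front of old_index shifts it left by one
    }
  }
  return old_index - count;
}
// RemapIndex(4, {1, 3}) == 2, RemapIndex(2, {1, 3}) == 1, RemapIndex(0, {1, 3}) == 0.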
} // namespace
std::shared_ptr<AtomicAddChecker> AtomicAddChecker::Init() {
auto processor = kernel::GetProcessorFromContext();
if (processor == kernel::Processor::AICORE) {
@ -102,6 +107,7 @@ std::shared_ptr<AtomicAddChecker> AtomicAddChecker::Init() {
}
bool AtomicAddChecker::FindCandidate(const AnfNodePtr &anf_node) {
atomic_add_infos_.clear();
auto node = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(node);
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
@ -111,32 +117,38 @@ bool AtomicAddChecker::FindCandidate(const AnfNodePtr &anf_node) {
sub_graph->set_manager(mng_sub);
}
// Rule: Only one ReduceSum inside sub-graph.
auto CheckSuitableTarget = [&mng_sub](const AtomicAddInfo &atomic_add_info) {
// Target type should not fuse any other ops in out direction, which means it should be in output list.
return mng_sub->node_users()[atomic_add_info.atomic_add_node].size() <= 1;
};
auto real_return_node = sub_graph->get_return()->input(kFirstDataInputIndex);
AtomicAddInfo atomic_add_info;
if (IsPrimitiveCNode(real_return_node, prim::kPrimMakeTuple)) {
size_t target_cnt = 0;
const auto &inputs = real_return_node->cast<CNodePtr>()->inputs();
for (size_t i = 1; i < inputs.size(); ++i) {
if (IsPrimitiveCNode(inputs[i], target_type_)) {
atomic_add_info_.atomic_add_node = inputs[i]->cast<CNodePtr>();
atomic_add_info_.reduce_real_output_index = i - 1;
target_cnt++;
atomic_add_info.atomic_add_node = inputs[i]->cast<CNodePtr>();
atomic_add_info.reduce_real_output_index = i - 1;
atomic_add_info.real_output_num = inputs.size() - 1;
// Target type should not fuse any other ops in out direction, which means it should be in output list.
if (!CheckSuitableTarget(atomic_add_info)) {
continue;
}
atomic_add_infos_.push_back(atomic_add_info);
}
}
if (target_cnt != 1) {
return false;
}
atomic_add_info_.real_output_num = inputs.size() - 1;
} else if (IsPrimitiveCNode(real_return_node, target_type_)) {
atomic_add_info_.atomic_add_node = real_return_node->cast<CNodePtr>();
atomic_add_info_.real_output_num = 1;
atomic_add_info.atomic_add_node = real_return_node->cast<CNodePtr>();
atomic_add_info.real_output_num = 1;
if (CheckSuitableTarget(atomic_add_info)) {
atomic_add_infos_.push_back(atomic_add_info);
}
} else {
return false;
}
// Rule: ReduceSum should not fuse any other ops in out direction, which means it should be in output list.
return (mng_sub->node_users()[atomic_add_info_.atomic_add_node].size() <= 1);
return !atomic_add_infos_.empty();
}
bool AtomicAddChecker::CanActivateAtomicAdd(const AnfNodePtr &anf_node) {
@ -150,17 +162,17 @@ bool AtomicAddChecker::CanActivateAtomicAdd(const AnfNodePtr &anf_node) {
// 3. No other ReduceSum as output ReduceSum's predecessors (reduce compile limitation).
// Rule 1.
if (!FindCandidate(anf_node)) {
if (!FindCandidate(anf_node) || atomic_add_infos_.size() > 1) {
return false;
}
// Rule 2.
if (!SuitableForAtomicAdd(atomic_add_info_.atomic_add_node)) {
if (!SuitableForAtomicAdd(atomic_add_infos_[0].atomic_add_node)) {
return false;
}
// Rule 3.
return !HaveReduceInPredecessors(atomic_add_info_.atomic_add_node);
return !HaveReduceInPredecessors(atomic_add_infos_[0].atomic_add_node);
}
bool AtomicAddChecker::Check(const AnfNodePtr &node) {
@ -228,8 +240,8 @@ bool AtomicAddCheckerAscend::SuitableForAtomicAdd(const AnfNodePtr &node) {
return false;
}
void AtomicCleanInsertter::CorrectKernelBuildInfo(const AnfNodePtr &composite_node, const AnfNodePtr &new_input,
bool bypass) {
void AtomicCleanInsertter::CorrectKernelBuildInfo(
const AnfNodePtr &composite_node, const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &clean_infos) {
// Change kernel build info.
auto kernel_info = dynamic_cast<device::KernelInfo *>(composite_node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
@ -244,17 +256,31 @@ void AtomicCleanInsertter::CorrectKernelBuildInfo(const AnfNodePtr &composite_no
std::vector<TypeId> &new_inputs_type = origin_inputs_type;
std::vector<std::string> new_outputs_format;
std::vector<TypeId> new_outputs_type;
for (size_t i = 0; i < origin_outputs_format.size(); ++i) {
if (bypass && real_output_num_ > 1 && i == reduce_real_output_index_) {
continue;
}
new_outputs_format.push_back(origin_outputs_format[i]);
new_outputs_type.push_back(origin_outputs_type[i]);
std::set<size_t> reduce_real_indices;
for (auto &info : clean_infos) {
reduce_real_indices.insert(info.first.reduce_real_output_index);
}
auto kernel_with_index = AnfAlgo::VisitKernel(new_input, 0);
new_inputs_format.push_back(AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second));
new_inputs_type.push_back(AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second));
if (clean_infos[0].first.real_output_num == reduce_real_indices.size()) {
new_outputs_format.push_back(origin_outputs_format[0]);
new_outputs_type.push_back(origin_outputs_type[0]);
} else {
for (size_t i = 0; i < origin_outputs_format.size(); ++i) {
if (reduce_real_indices.count(i) > 0) {
continue;
}
new_outputs_format.push_back(origin_outputs_format[i]);
new_outputs_type.push_back(origin_outputs_type[i]);
}
}
for (const auto &clean_info : clean_infos) {
auto &new_input = clean_info.second;
auto kernel_with_index = AnfAlgo::VisitKernel(new_input, 0);
new_inputs_format.push_back(AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second));
new_inputs_type.push_back(AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second));
}
kernel::KernelBuildInfo::KernelBuildInfoBuilder new_info_builder;
new_info_builder.SetInputsFormat(new_inputs_format);
@ -268,41 +294,55 @@ void AtomicCleanInsertter::CorrectKernelBuildInfo(const AnfNodePtr &composite_no
AnfAlgo::SetSelectKernelBuildInfo(new_selected_info, composite_node.get());
}
void AtomicCleanInsertter::CreateInplaceAssignNodeAndCorrectReturn(const FuncGraphPtr &sub_graph,
const AnfNodePtr &new_parameter) {
// add inplaceassign
AnfNodePtr out_node;
void AtomicCleanInsertter::CreateInplaceAssignNodeAndCorrectReturn(
const FuncGraphPtr &sub_graph, const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &parameters_infos) {
// Add inplaceassign
AnfNodePtr inplace_out_node;
bool fake_out = false;
std::set<size_t> reduce_indices;
for (auto &info : parameters_infos) {
reduce_indices.insert(info.first.reduce_real_output_index + 1);
}
size_t replace_index = 0;
auto retrun_node = sub_graph->get_return()->input(kFirstDataInputIndex);
if (IsPrimitiveCNode(retrun_node, prim::kPrimMakeTuple)) {
if (!IsPrimitiveCNode(retrun_node, prim::kPrimMakeTuple) ||
retrun_node->cast<CNodePtr>()->inputs().size() == parameters_infos.size() + 1) {
fake_out = true;
inplace_out_node = parameters_infos[0].first.atomic_add_node;
replace_index = parameters_infos[0].first.reduce_real_output_index + 1;
} else {
const auto &outs = retrun_node->cast<CNodePtr>()->inputs();
for (size_t i = 1; i < outs.size(); ++i) {
if (i != reduce_real_output_index_ + 1) {
out_node = outs[i];
if (reduce_indices.count(i) == 0) {
inplace_out_node = outs[i];
replace_index = i;
break;
}
}
} else {
out_node = atomic_add_node_; // Use result data itself, and set attr "fake_out" true.
fake_out = true;
}
auto inplace_assign_node =
CreateCNode({NewValueNode(prim::kPrimInplaceAssign), new_parameter, atomic_add_node_, out_node}, sub_graph,
{.format = GetFormat(out_node), .shape = GetShape(out_node), .type = GetType(out_node)});
SetNodeAttrSafely("fake_output", MakeValue(fake_out), inplace_assign_node);
for (const auto &[atomic_add_info, new_parameter] : parameters_infos) {
auto inplace_assign_node = CreateCNode(
{NewValueNode(prim::kPrimInplaceAssign), new_parameter, atomic_add_info.atomic_add_node, inplace_out_node},
sub_graph,
{.format = GetFormat(inplace_out_node), .shape = GetShape(inplace_out_node), .type = GetType(inplace_out_node)});
SetNodeAttrSafely("fake_output", MakeValue(fake_out), inplace_assign_node);
inplace_out_node = inplace_assign_node;
}
CNodePtr new_out_node;
if (real_output_num_ > 2) {
// If the real output number is less than or equal to two, there is no need to filter the inplace one out:
// 1. Two real outputs. After inplacing, only one is left and the `InplaceAssign` output will be that output.
// 2. One real output. After inplacing, there is no output left, so use a fake one.
if (parameters_infos[0].first.real_output_num > NUMBER_COND_FOR_FILTER_INPLACE) {
std::vector<AnfNodePtr> output_args = {NewValueNode(prim::kPrimMakeTuple)};
const auto &outs = retrun_node->cast<CNodePtr>()->inputs();
for (size_t i = 1; i < outs.size(); ++i) {
if (i == reduce_real_output_index_ + 1) {
if (reduce_indices.count(i) > 0) {
continue;
} else if (i == replace_index) {
output_args.push_back(inplace_assign_node);
output_args.push_back(inplace_out_node);
} else {
output_args.push_back(outs[i]);
}
@ -310,31 +350,46 @@ void AtomicCleanInsertter::CreateInplaceAssignNodeAndCorrectReturn(const FuncGra
// Set output for AnfGraph
new_out_node = sub_graph->NewCNode(output_args);
} else {
new_out_node = inplace_assign_node;
CNodePtr out_cnode = inplace_out_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(out_cnode);
new_out_node = out_cnode;
}
sub_graph->set_output(new_out_node);
}
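As a concrete (hypothetical) reading of the output handling above: with three real outputs, one of them the atomic-add ReduceSum, the ReduceSum's slot is dropped and the rebuilt MakeTuple returns the two remaining outputs, one of them routed through the InplaceAssign chain; with exactly two real outputs the InplaceAssign result itself becomes the single output; with one real output the InplaceAssign is kept as a fake output.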
void AtomicCleanInsertter::CorrectAbstract(const AnfNodePtr &composite_node) const {
// If there is only one output(ReduceSum), it should be a fake output with the same abstract with origin output.
if (real_output_num_ <= 1) {
void AtomicCleanInsertter::CorrectAbstract(
const AnfNodePtr &composite_node, const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &process_infos) const {
// If there is only one output, it should be a fake output with the same abstract as the origin output.
if (process_infos[0].first.real_output_num <= 1) {
return;
}
std::set<size_t> reduce_real_indices;
for (auto &info : process_infos) {
reduce_real_indices.insert(info.first.reduce_real_output_index);
}
// Change abstract.
auto origin_out_spec = composite_node->abstract()->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(origin_out_spec);
const auto &origin_out_specs = origin_out_spec->elements();
AbstractBasePtrList new_out_specs;
for (size_t i = 0; i < origin_out_specs.size(); ++i) {
if (i != reduce_real_output_index_) {
if (reduce_real_indices.count(i) == 0) {
new_out_specs.push_back(origin_out_specs[i]);
}
}
// If empty, there will be a fake out, so use the first target reduce information.
if (new_out_specs.empty()) {
new_out_specs.push_back(origin_out_specs[process_infos[0].first.reduce_real_output_index]);
}
composite_node->set_abstract(std::make_shared<abstract::AbstractTuple>(new_out_specs));
}
void AtomicCleanInsertter::ProcessOriginCNode(const AnfNodePtr &composite_node, const AnfNodePtr &new_input) {
void AtomicCleanInsertter::ProcessOriginCNode(
const AnfNodePtr &composite_node,
const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &info_and_broadcast_to_nodes) {
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(composite_node);
auto mng_sub = sub_graph->manager();
if (mng_sub == nullptr) {
@ -342,23 +397,27 @@ void AtomicCleanInsertter::ProcessOriginCNode(const AnfNodePtr &composite_node,
sub_graph->set_manager(mng_sub);
}
// Add atomic attribute to reducesum node.
SetNodeAttrSafely("enable_atomic_add", MakeValue(true), atomic_add_node_);
// Add input
std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> parameters_infos;
for (const auto &[atomic_add_info, new_input] : info_and_broadcast_to_nodes) {
// Add atomic attribute to reducesum node.
SetNodeAttrSafely("enable_atomic_add", MakeValue(true), atomic_add_info.atomic_add_node);
// add parameter
auto parameter = sub_graph->add_parameter();
parameter->set_abstract(new_input->abstract());
parameter->set_kernel_info(new_input->kernel_info_ptr());
parameters_infos.emplace_back(atomic_add_info, parameter);
}
// add input
auto inputs = composite_node->cast<CNodePtr>()->inputs();
inputs.push_back(new_input);
std::transform(info_and_broadcast_to_nodes.cbegin(), info_and_broadcast_to_nodes.cend(), std::back_inserter(inputs),
[](const std::pair<AtomicAddInfo, AnfNodePtr> &pair_item) { return pair_item.second; });
composite_node->cast<CNodePtr>()->set_inputs(inputs);
// add parameter
auto parameter = sub_graph->add_parameter();
parameter->set_abstract(new_input->abstract());
parameter->set_kernel_info(new_input->kernel_info_ptr());
CreateInplaceAssignNodeAndCorrectReturn(sub_graph, parameters_infos);
CreateInplaceAssignNodeAndCorrectReturn(sub_graph, parameter);
CorrectAbstract(composite_node);
CorrectKernelBuildInfo(composite_node, new_input);
CorrectAbstract(composite_node, info_and_broadcast_to_nodes);
CorrectKernelBuildInfo(composite_node, info_and_broadcast_to_nodes);
auto old_graph_name = GetValue<std::string>(sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
auto new_graph_name = GkUtils::ExtractGraphKernelName(TopoSort(sub_graph->get_return()), "", "atomic_add");
@ -366,19 +425,6 @@ void AtomicCleanInsertter::ProcessOriginCNode(const AnfNodePtr &composite_node,
MS_LOG(INFO) << "Convert " << old_graph_name << " to atomic add graph " << new_graph_name;
}
void AtomicCleanInsertter::AddDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &clean_node,
const AnfNodePtr &composite_node, const AnfNodePtr &user_node, int index) const {
// Create depend node to hold execution order.
AnfNodePtrList d_inputs = {NewValueNode(prim::kPrimDepend), clean_node, composite_node};
auto depend_cnode = main_graph->NewCNode(d_inputs);
depend_cnode->set_abstract(clean_node->abstract());
main_graph->AddNode(depend_cnode);
auto user_cnode = user_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(user_cnode);
user_cnode->set_input(IntToSize(index), depend_cnode);
}
CNodePtr AtomicCleanInsertter::InsertUpdateState(const KernelGraphPtr &main_graph,
const CNodePtr &composite_node) const {
// Insert update_state_node, need mount a monad node.
@ -391,7 +437,8 @@ CNodePtr AtomicCleanInsertter::InsertUpdateState(const KernelGraphPtr &main_grap
return update_state_cnode;
}
CNodePtr AtomicCleanInsertter::CreateAtomicCleanCompositeNode(const KernelGraphPtr &main_graph, TypeId dst_type) {
CNodePtr AtomicCleanInsertter::CreateAtomicCleanCompositeNode(const AtomicAddInfo &atomic_add_info,
const KernelGraphPtr &main_graph, TypeId dst_type) {
std::set<TypeId> data_support = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64};
if (!std::any_of(data_support.cbegin(), data_support.cend(), [&dst_type](TypeId type) { return dst_type == type; })) {
@ -399,7 +446,7 @@ CNodePtr AtomicCleanInsertter::CreateAtomicCleanCompositeNode(const KernelGraphP
}
// Create zero value which will be broadcast to target shape.
auto format = GetFormat(atomic_add_node_);
auto format = GetFormat(atomic_add_info.atomic_add_node);
auto dtype = (dst_type == kNumberTypeFloat16) ? kNumberTypeFloat32 : dst_type;
ValueNodePtr value_node;
if (dtype == kNumberTypeFloat32) {
@ -418,18 +465,19 @@ CNodePtr AtomicCleanInsertter::CreateAtomicCleanCompositeNode(const KernelGraphP
AnfNodePtrList cast_inputs = {NewValueNode(prim::kPrimCast), value_node};
auto cast_node_inner =
CreateCNode(cast_inputs, new_sub_graph, {.format = format, .shape = {1}, .type = TypeIdToType(dst_type)});
SetNodeAttrSafely(kAttrDstType, kFloat32, cast_node_inner);
SetNodeAttrSafely("dst_type", MakeValue("float32"), cast_node_inner);
broadcast_input_node = cast_node_inner;
} else {
broadcast_input_node = value_node;
}
// Create broadcast basic op.
auto dst_shape_vec = GetShape(atomic_add_node_);
auto dst_shape_vec = GetShape(atomic_add_info.atomic_add_node);
AnfNodePtrList atomic_clean_inputs = {NewValueNode(prim::kPrimBroadcastTo), broadcast_input_node};
auto broadcast_to_node_inner = CreateCNode(
atomic_clean_inputs, new_sub_graph, {.format = format, .shape = dst_shape_vec, .type = GetType(atomic_add_node_)});
SetNodeAttrSafely("shape", MakeValue(GetDeviceShape(atomic_add_node_)), broadcast_to_node_inner);
auto broadcast_to_node_inner =
CreateCNode(atomic_clean_inputs, new_sub_graph,
{.format = format, .shape = dst_shape_vec, .type = GetType(atomic_add_info.atomic_add_node)});
SetNodeAttrSafely("shape", MakeValue(GetDeviceShape(atomic_add_info.atomic_add_node)), broadcast_to_node_inner);
// Makeup sub-graph.
new_sub_graph->set_output(broadcast_to_node_inner);
@ -443,17 +491,26 @@ CNodePtr AtomicCleanInsertter::CreateAtomicCleanCompositeNode(const KernelGraphP
return broadcast_to_composite_node;
}
std::vector<std::pair<AnfNodePtr, int> > AtomicCleanInsertter::FindOriginCNodeUsers(const KernelGraphPtr &main_graph,
const AnfNodePtr &composite_node,
const FuncGraphManagerPtr &mng,
bool correct_index) const {
std::vector<std::pair<AnfNodePtr, int> > reduce_user_nodes;
if (real_output_num_ <= 1) {
std::vector<std::tuple<AnfNodePtr, int, AnfNodePtr>> AtomicCleanInsertter::FindOriginCNodeUsers(
const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node,
const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &info_and_broadcast_to_nodes, const FuncGraphManagerPtr &mng,
bool correct_index) const {
std::vector<std::tuple<AnfNodePtr, int, AnfNodePtr>> reduce_user_nodes;
std::set<int64_t> real_indices;
std::map<size_t, AnfNodePtr> real_indices_and_clean_node;
for (auto &[info, clean] : info_and_broadcast_to_nodes) {
real_indices_and_clean_node.insert({info.reduce_real_output_index, clean});
real_indices.insert(SizeToLong(info.reduce_real_output_index));
}
if (info_and_broadcast_to_nodes[0].first.real_output_num <= 1) {
auto users = mng->node_users()[composite_node];
(void)std::transform(users.cbegin(), users.cend(), std::back_inserter(reduce_user_nodes),
[](const std::pair<AnfNodePtr, int> &pair) { return pair; });
for (const auto &[user, index] : users) {
reduce_user_nodes.emplace_back(user, index, info_and_broadcast_to_nodes[0].second);
}
} else {
std::vector<std::pair<AnfNodePtr, int> > getitem_user_nodes;
std::vector<std::tuple<AnfNodePtr, AnfNodePtr>> getitem_user_nodes;
auto users = mng->node_users()[composite_node];
for (const auto &node_index : users) {
const auto &user_node = node_index.first;
@ -466,47 +523,43 @@ std::vector<std::pair<AnfNodePtr, int> > AtomicCleanInsertter::FindOriginCNodeUs
auto value_node = value_input->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
auto item_idx = GetValue<int64_t>(value_node->value());
if (item_idx == static_cast<int64_t>(reduce_real_output_index_)) {
getitem_user_nodes.push_back(node_index);
auto iter = real_indices_and_clean_node.find(IntToSize(item_idx));
if (iter != real_indices_and_clean_node.end()) {
getitem_user_nodes.push_back({node_index.first, iter->second});
} else if (correct_index) {
if (real_output_num_ > 2) {
// Recorrect other getitem index.
int64_t new_item_idx = CalNewIndex(item_idx, SizeToLong(reduce_real_output_index_));
AnfNodePtrList new_inputs = {NewValueNode(prim::kPrimTupleGetItem), composite_node,
NewValueNode(new_item_idx)};
auto new_out = main_graph->NewCNode(new_inputs);
new_out->set_abstract(get_item_cnode->abstract());
for (const auto &[user, index] : mng->node_users()[get_item_cnode]) {
auto user_cnode = user->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(user_cnode);
user_cnode->set_input(IntToSize(index), new_out);
}
} else {
for (const auto &[user, index] : mng->node_users()[node_index.first]) {
auto user_cnode = user->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(user_cnode);
user_cnode->set_input(IntToSize(index), composite_node);
}
// Recorrect other getitem index.
int64_t new_item_idx = CalNewIndex(item_idx, real_indices);
AnfNodePtrList new_inputs = {NewValueNode(prim::kPrimTupleGetItem), composite_node, NewValueNode(new_item_idx)};
auto new_out = main_graph->NewCNode(new_inputs);
new_out->set_abstract(get_item_cnode->abstract());
for (const auto &[user, index] : mng->node_users()[get_item_cnode]) {
auto user_cnode = user->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(user_cnode);
user_cnode->set_input(IntToSize(index), new_out);
}
}
}
for (auto &pair : getitem_user_nodes) {
for (auto &[getitem_node, broadcast_to_node] : getitem_user_nodes) {
// Directly find the real user.
auto real_users = mng->node_users()[pair.first];
(void)reduce_user_nodes.insert(reduce_user_nodes.end(), real_users.begin(), real_users.end());
auto real_users = mng->node_users()[getitem_node];
std::transform(real_users.cbegin(), real_users.cend(), std::back_inserter(reduce_user_nodes),
[&broadcast_to_node](const std::pair<AnfNodePtr, int> &pair) {
return std::make_tuple(pair.first, pair.second, broadcast_to_node);
});
}
}
return reduce_user_nodes;
}
void AtomicCleanInsertter::ProcessOriginCNodeUser(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node,
const AnfNodePtr &broadcast_to_node,
const AnfNodePtr &update_state_node, const FuncGraphManagerPtr &mng) {
void AtomicCleanInsertter::ProcessOriginCNodeUser(
const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node,
const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &info_and_broadcast_to_nodes,
const AnfNodePtr &update_state_node, const FuncGraphManagerPtr &mng) {
// 1. find users, change getitem index if needed.
std::vector<std::pair<AnfNodePtr, int> > reduce_user_nodes =
FindOriginCNodeUsers(main_graph, composite_node, mng, true);
for (const auto &[user_node, index] : reduce_user_nodes) {
std::vector<std::tuple<AnfNodePtr, int, AnfNodePtr>> reduce_user_nodes =
FindOriginCNodeUsers(main_graph, composite_node, info_and_broadcast_to_nodes, mng, true);
for (const auto &[user_node, index, broadcast_to_node] : reduce_user_nodes) {
// 2. Make sure the modified composite node runs first: create load_node first, then add edges to connect
// update_state_node, broadcast_node and load_node to keep the order.
AnfNodePtrList load_inputs = {NewValueNode(prim::kPrimLoad), broadcast_to_node, update_state_node};
@ -519,50 +572,38 @@ void AtomicCleanInsertter::ProcessOriginCNodeUser(const KernelGraphPtr &main_gra
}
}
void AtomicCleanInsertter::UpdateAtomicAddInfo(const AtomicAddInfo &atomic_add_info) {
atomic_add_node_ = atomic_add_info.atomic_add_node;
reduce_real_output_index_ = atomic_add_info.reduce_real_output_index;
real_output_num_ = atomic_add_info.real_output_num;
}
void AtomicCleanInsertter::InsertAtomicClean(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node,
const std::vector<AtomicAddInfo> &atomic_add_infos,
const FuncGraphManagerPtr &mng) {
auto origin_composite_node = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(origin_composite_node);
// Create broadcast node.
auto out_type = GetType(atomic_add_node_)->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(out_type);
auto broadcast_to_node = CreateAtomicCleanCompositeNode(main_graph, out_type->element()->type_id());
std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> info_and_broadcast_to_nodes;
for (auto atomic_add_info : atomic_add_infos) {
auto out_type = GetType(atomic_add_info.atomic_add_node)->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(out_type);
auto broadcast_to_node =
CreateAtomicCleanCompositeNode(atomic_add_info, main_graph, out_type->element()->type_id());
info_and_broadcast_to_nodes.emplace_back(atomic_add_info, broadcast_to_node);
}
// Insert extra input(broadcast node output) to composite node, and make Reducesum inplaceassign to it.
// Note: if it's single output, this will increase total memory because of a fake out.
ProcessOriginCNode(origin_composite_node, broadcast_to_node);
ProcessOriginCNode(origin_composite_node, info_and_broadcast_to_nodes);
// Insert update_state_node to keep execution order.
auto update_state_node = InsertUpdateState(main_graph, origin_composite_node);
// Replace origin ReduceSum's user with atomic clean output
ProcessOriginCNodeUser(main_graph, origin_composite_node, broadcast_to_node, update_state_node, mng);
MS_LOG(INFO) << "Target node: " << origin_composite_node->fullname_with_scope()
<< ", clean node: " << broadcast_to_node->fullname_with_scope();
}
ProcessOriginCNodeUser(main_graph, origin_composite_node, info_and_broadcast_to_nodes, update_state_node, mng);
std::stringstream ss;
ss << "Target node: " << origin_composite_node->fullname_with_scope() << ", clean nodes: ";
for (auto iter : info_and_broadcast_to_nodes) {
ss << iter.second->fullname_with_scope() << ", ";
}
MS_LOG(INFO) << ss.str();
}
bool AtomicCleanInsertter::IsExistStructuralObstacle(const KernelGraphPtr &main_graph, const AnfNodePtr &node,
const FuncGraphManagerPtr &mng) {
auto reduce_users = FindOriginCNodeUsers(main_graph, node, mng, false);
// If reduce user is MakeTuple and not last node, there is no cheap method to set right running order between reduce
// node and user node. If reduce is Depend node, the origin node may be wrong!
return std::all_of(
reduce_users.cbegin(), reduce_users.cend(), [&main_graph](const std::pair<AnfNodePtr, int> &user_info) -> bool {
auto &user = user_info.first;
if ((IsPrimitiveCNode(user, prim::kPrimMakeTuple) || IsPrimitiveCNode(user, prim::kPrimDepend)) &&
!(IsPrimitiveCNode(user, prim::kPrimReturn) || user == main_graph->output())) {
return false;
} else {
return true;
}
});
bool AtomicCleanInsertter::Run(const FuncGraphPtr &func_graph) {
@ -582,13 +623,12 @@ bool AtomicCleanInsertter::Run(const FuncGraphPtr &func_graph) {
auto topo_nodes = TopoSort(kernel_graph->get_return());
for (const auto &node : topo_nodes) {
if (!atomic_add_checker->Check(node) || !IsExistStructuralObstacle(kernel_graph, node, mng)) {
if (!atomic_add_checker->Check(node)) {
continue;
}
auto atomic_add_infos = atomic_add_checker->GetAtomicAddInfo();
InsertAtomicClean(kernel_graph, node, atomic_add_infos, mng);
changed = true;
auto atomic_add_info = atomic_add_checker->GetAtomicAddInfo();
UpdateAtomicAddInfo(atomic_add_info);
InsertAtomicClean(kernel_graph, node, mng);
}
if (changed) {

View File

@ -39,13 +39,13 @@ class AtomicAddChecker {
static std::shared_ptr<AtomicAddChecker> Init();
bool Check(const AnfNodePtr &node);
AtomicAddInfo GetAtomicAddInfo() { return atomic_add_info_; }
std::vector<AtomicAddInfo> GetAtomicAddInfo() { return atomic_add_infos_; }
protected:
virtual bool SuitableForAtomicAdd(const AnfNodePtr &node) { return false; }
virtual bool FindCandidate(const AnfNodePtr &anf_node);
virtual bool CanActivateAtomicAdd(const AnfNodePtr &anf_node);
AtomicAddInfo atomic_add_info_;
std::vector<AtomicAddInfo> atomic_add_infos_;
PrimitivePtr target_type_{prim::kPrimReduceSum};
};
@ -74,31 +74,28 @@ class AtomicCleanInsertter : public opt::Pass {
bool Run(const FuncGraphPtr &func_graph) override;
protected:
virtual void CorrectKernelBuildInfo(const AnfNodePtr &composite_node, const AnfNodePtr &new_input,
bool bypass = true);
virtual void ProcessOriginCNode(const AnfNodePtr &composite_node, const AnfNodePtr &new_input);
virtual CNodePtr CreateAtomicCleanCompositeNode(const KernelGraphPtr &main_graph, TypeId dst_type);
void AddDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &clean_node, const AnfNodePtr &composite_node,
const AnfNodePtr &user_node, int index) const;
void InsertAtomicClean(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node, const FuncGraphManagerPtr &mng);
virtual void CorrectKernelBuildInfo(const AnfNodePtr &composite_node,
const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &clean_infos);
virtual void ProcessOriginCNode(const AnfNodePtr &composite_node,
const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &info_and_broadcast_to_nodes);
virtual CNodePtr CreateAtomicCleanCompositeNode(const AtomicAddInfo &atomic_add_info,
const KernelGraphPtr &main_graph, TypeId dst_type);
void InsertAtomicClean(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node,
const std::vector<AtomicAddInfo> &atomic_add_infos, const FuncGraphManagerPtr &mng);
CNodePtr InsertUpdateState(const KernelGraphPtr &main_graph, const CNodePtr &composite_node) const;
void CorrectAbstract(const AnfNodePtr &composite_node) const;
void CreateInplaceAssignNodeAndCorrectReturn(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter);
void CorrectAbstract(const AnfNodePtr &composite_node,
const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &process_infos) const;
void CreateInplaceAssignNodeAndCorrectReturn(
const FuncGraphPtr &sub_graph, const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &parameters_infos);
void ProcessOriginCNodeUser(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node,
const AnfNodePtr &broadcast_to_node, const AnfNodePtr &update_state_node,
const FuncGraphManagerPtr &mng);
void UpdateAtomicAddInfo(const AtomicAddInfo &info);
CNodePtr atomic_add_node_{nullptr};
size_t reduce_real_output_index_{0};
size_t real_output_num_{0};
const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &info_and_broadcast_to_nodes,
const AnfNodePtr &update_state_node, const FuncGraphManagerPtr &mng);
private:
std::vector<std::pair<AnfNodePtr, int>> FindOriginCNodeUsers(const KernelGraphPtr &main_graph,
const AnfNodePtr &composite_node,
const FuncGraphManagerPtr &mng,
bool correct_index) const;
bool IsExistStructuralObstacle(const KernelGraphPtr &main_graph, const AnfNodePtr &node,
const FuncGraphManagerPtr &mng);
std::vector<std::tuple<AnfNodePtr, int, AnfNodePtr>> FindOriginCNodeUsers(
const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node,
const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &info_and_broadcast_to_nodes,
const FuncGraphManagerPtr &mng, bool correct_index) const;
};
using AtomicCleanInsertterPtr = std::shared_ptr<AtomicCleanInsertter>;
} // namespace mindspore::graphkernel

View File

@ -30,20 +30,74 @@
#include "backend/session/kernel_graph.h"
namespace mindspore::graphkernel {
void StitchAtomicCleanInsertter::CorrectKernelBuildInfo(
const AnfNodePtr &composite_node, const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &clean_infos) {
// Change kernel build info.
auto kernel_info = dynamic_cast<device::KernelInfo *>(composite_node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
const auto &origin_kernel_build_info = kernel_info->GetMutableSelectKernelBuildInfo();
auto origin_inputs_format = origin_kernel_build_info->GetAllInputFormats();
auto origin_outputs_format = origin_kernel_build_info->GetAllOutputFormats();
auto origin_inputs_type = origin_kernel_build_info->GetAllInputDeviceTypes();
auto origin_outputs_type = origin_kernel_build_info->GetAllOutputDeviceTypes();
auto origin_processor = origin_kernel_build_info->processor();
std::vector<std::string> &new_inputs_format = origin_inputs_format;
std::vector<TypeId> &new_inputs_type = origin_inputs_type;
std::vector<std::string> new_outputs_format;
std::vector<TypeId> new_outputs_type;
for (size_t i = 0; i < origin_outputs_format.size(); ++i) {
new_outputs_format.push_back(origin_outputs_format[i]);
new_outputs_type.push_back(origin_outputs_type[i]);
}
auto kernel_with_index = AnfAlgo::VisitKernel(clean_infos[0].second, 0);
new_inputs_format.push_back(AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second));
new_inputs_type.push_back(AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second));
kernel::KernelBuildInfo::KernelBuildInfoBuilder new_info_builder;
new_info_builder.SetInputsFormat(new_inputs_format);
new_info_builder.SetInputsDeviceType(new_inputs_type);
new_info_builder.SetOutputsFormat(new_outputs_format);
new_info_builder.SetOutputsDeviceType(new_outputs_type);
new_info_builder.SetProcessor(origin_processor);
new_info_builder.SetKernelType(KernelType::AKG_KERNEL);
new_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
auto new_selected_info = new_info_builder.Build();
AnfAlgo::SetSelectKernelBuildInfo(new_selected_info, composite_node.get());
}
void StitchAtomicCleanInsertter::AddDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &clean_node,
const AnfNodePtr &composite_node, const AnfNodePtr &user_node,
int index) const {
// Create depend node to hold execution order.
AnfNodePtrList d_inputs = {NewValueNode(prim::kPrimDepend), clean_node, composite_node};
auto depend_cnode = main_graph->NewCNode(d_inputs);
depend_cnode->set_abstract(clean_node->abstract());
main_graph->AddNode(depend_cnode);
auto user_cnode = user_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(user_cnode);
user_cnode->set_input(IntToSize(index), depend_cnode);
}
CNodePtr StitchAtomicCleanInsertter::CreateInplaceAssignNode(const FuncGraphPtr &sub_graph,
const AnfNodePtr &new_parameter) const {
const AnfNodePtr &new_parameter,
const AtomicAddInfo &info) const {
// add inplaceassign
AnfNodePtr out_node = atomic_add_node_; // Use result data itself, and set attr "fake_out" true.
AnfNodePtr out_node = info.atomic_add_node; // Use result data itself, and set attr "fake_out" true.
auto inplace_assign_node =
CreateCNode({NewValueNode(prim::kPrimInplaceAssign), new_parameter, atomic_add_node_, out_node}, sub_graph,
CreateCNode({NewValueNode(prim::kPrimInplaceAssign), new_parameter, out_node, out_node}, sub_graph,
{.format = GetFormat(out_node), .shape = GetShape(out_node), .type = GetType(out_node)});
SetNodeAttrSafely("fake_output", MakeValue(true), inplace_assign_node);
AnfAlgo::EraseNodeAttr(kAttrStitch, atomic_add_node_);
AnfAlgo::EraseNodeAttr(kAttrStitch, out_node);
SetNodeAttrSafely(kAttrStitch, MakeValue("common"), inplace_assign_node);
return inplace_assign_node;
}
void StitchAtomicCleanInsertter::ProcessOriginCNode(const AnfNodePtr &composite_node, const AnfNodePtr &new_input) {
void StitchAtomicCleanInsertter::ProcessOriginCNode(
const AnfNodePtr &composite_node,
const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &info_and_broadcast_to_nodes) {
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(composite_node);
auto mng_sub = sub_graph->manager();
if (mng_sub == nullptr) {
@ -51,6 +105,8 @@ void StitchAtomicCleanInsertter::ProcessOriginCNode(const AnfNodePtr &composite_
sub_graph->set_manager(mng_sub);
}
auto [atomic_add_info, new_input] = info_and_broadcast_to_nodes[0];
// add input
auto inputs = composite_node->cast<CNodePtr>()->inputs();
inputs.push_back(new_input);
@ -61,11 +117,12 @@ void StitchAtomicCleanInsertter::ProcessOriginCNode(const AnfNodePtr &composite_
parameter->set_abstract(new_input->abstract());
parameter->set_kernel_info(new_input->kernel_info_ptr());
auto inplace_assign = CreateInplaceAssignNode(sub_graph, parameter);
auto inplace_assign = CreateInplaceAssignNode(sub_graph, parameter, atomic_add_info);
// Replace atomic ReduceSum's user with atomic clean output, and add depend op after inplaceassign to avoid
// elimination.
std::vector<std::pair<AnfNodePtr, int>> reduce_user_nodes = FindInnerCNodeUsers(stitch_node_, atomic_add_node_);
std::vector<std::pair<AnfNodePtr, int>> reduce_user_nodes =
FindInnerCNodeUsers(stitch_node_, atomic_add_info.atomic_add_node);
bool connected = false;
for (const auto &[user_node, index] : reduce_user_nodes) {
auto user_cnode = user_node->cast<CNodePtr>();
@ -79,7 +136,7 @@ void StitchAtomicCleanInsertter::ProcessOriginCNode(const AnfNodePtr &composite_
}
connected = true;
}
CorrectKernelBuildInfo(composite_node, new_input, false);
CorrectKernelBuildInfo(composite_node, info_and_broadcast_to_nodes);
}
auto old_graph_name = GetValue<std::string>(sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
@ -105,8 +162,8 @@ std::vector<std::pair<AnfNodePtr, int>> StitchAtomicCleanInsertter::FindInnerCNo
return inner_user_nodes;
}
bool StitchAtomicCleanInsertter::IsStitchWithAtomic(const AnfNodePtr &anf_node) {
if (!AnfAlgo::IsGraphKernel(anf_node)) return false;
std::pair<bool, AtomicAddInfo> StitchAtomicCleanInsertter::IsStitchWithAtomic(const AnfNodePtr &anf_node) {
if (!AnfAlgo::IsGraphKernel(anf_node)) return {false, AtomicAddInfo()};
auto node = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(node);
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
@ -116,12 +173,13 @@ bool StitchAtomicCleanInsertter::IsStitchWithAtomic(const AnfNodePtr &anf_node)
if (AnfAlgo::HasNodeAttr(kAttrStitch, n->cast<CNodePtr>()) &&
AnfAlgo::GetNodeAttr<std::string>(n, kAttrStitch) == "atomic" && IsPrimitiveCNode(n, prim::kPrimReduceSum)) {
MS_LOG(INFO) << "GOT STITCH WITH ATOMIC!!!";
atomic_add_node_ = n->cast<CNodePtr>();
AtomicAddInfo info;
info.atomic_add_node = n->cast<CNodePtr>();
stitch_node_ = anf_node;
return true;
return {true, info};
}
}
return false;
return {false, AtomicAddInfo()};
}
bool StitchAtomicCleanInsertter::Run(const FuncGraphPtr &func_graph) {
@ -137,8 +195,9 @@ bool StitchAtomicCleanInsertter::Run(const FuncGraphPtr &func_graph) {
auto topo_nodes = TopoSort(kernel_graph->get_return());
for (const auto &node : topo_nodes) {
// if stitch attr exists, add atomic clean op depends on the attr
if (IsStitchWithAtomic(node)) {
InsertAtomicClean(kernel_graph, node, mng);
auto [is_stitch, atomic_add_info] = IsStitchWithAtomic(node);
if (is_stitch) {
InsertAtomicClean(kernel_graph, node, {atomic_add_info}, mng);
changed = true;
}
}

View File

@ -33,11 +33,19 @@ class StitchAtomicCleanInsertter : public AtomicCleanInsertter {
bool Run(const FuncGraphPtr &func_graph) override;
private:
CNodePtr CreateInplaceAssignNode(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter) const;
void CorrectKernelBuildInfo(const AnfNodePtr &composite_node,
const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &clean_infos) override;
CNodePtr CreateInplaceAssignNode(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter,
const AtomicAddInfo &info) const;
std::vector<std::pair<AnfNodePtr, int>> FindInnerCNodeUsers(const AnfNodePtr &inner_node,
const CNodePtr &target) const;
void ProcessOriginCNode(const AnfNodePtr &composite_node, const AnfNodePtr &new_input) override;
bool IsStitchWithAtomic(const AnfNodePtr &anf_node);
void ProcessOriginCNode(
const AnfNodePtr &composite_node,
const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &info_and_broadcast_to_nodes) override;
std::pair<bool, AtomicAddInfo> IsStitchWithAtomic(const AnfNodePtr &anf_node);
void AddDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &clean_node, const AnfNodePtr &composite_node,
const AnfNodePtr &user_node, int index) const;
AnfNodePtr stitch_node_{nullptr};
};

View File

@ -49,6 +49,8 @@
#include "backend/optimizer/graph_kernel/graph_kernel_pass_manager.h"
#include "backend/optimizer/graph_kernel/transform_op_optimizer.h"
#include "backend/optimizer/graph_kernel/rewrite_output_shape.h"
#include "backend/optimizer/graph_kernel/graph_kernel_recompute.h"
#include "backend/optimizer/graph_kernel/reduce_fake_out_mem.h"
namespace mindspore::graphkernel {
using opt::CommonSubexpressionElimination;
@ -154,6 +156,14 @@ PassManagerPtr GraphKernelOptimizer::Split() const {
PassManagerPtr GraphKernelOptimizer::HighLevelOpt2() const {
auto pm = std::make_shared<GraphKernelPassManager>(4, "highlevelopt2");
auto &flags = GraphKernelFlags::GetInstance();
// Auto recompute according to local memory burst.
auto recompute_lv = GetPassLevelByFlag(flags.recompute_increment_threshold > 0 || flags.recompute_peak_threshold > 0);
pm->AddPass(std::make_shared<GraphKernelRecompute>(), recompute_lv);
pm->AddPass(std::make_shared<ExtendOutputForUpdateState>(), recompute_lv);
pm->AddPass(std::make_shared<MergeOutputForUpdateState>(), recompute_lv);
// Enable atomic add
pm->AddPass(std::make_shared<AtomicCleanInsertter>(), OptLevel_2, is_gpu || is_ascend);
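Judging from the gating above, the recompute passes presumably take effect only when recompute_increment_threshold or recompute_peak_threshold is set to a positive value in GraphKernelFlags; with both left at their default of 0, FindCandidates also returns without collecting any candidates.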
@ -197,6 +207,9 @@ PassManagerPtr GraphKernelOptimizer::PostProcess() const {
pm->AddPass(std::make_shared<GetitemTuple>(), OptLevel_1);
pm->AddPass(std::make_shared<RewriteOutputShape>(), OptLevel_1);
// Reduce fake output memory.
pm->AddPass(std::make_shared<ReduceFakeOutMem>(), OptLevel_1);
// Add the new tensors to the kernel_graph
pm->AddPass(std::make_shared<BindValueToGraph>(), OptLevel_1);
return pm;

View File

@ -0,0 +1,636 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/graph_kernel/graph_kernel_recompute.h"
#include <algorithm>
#include <functional>
#include <map>
#include <memory>
#include <queue>
#include <set>
#include <sstream>
#include <stack>
#include <tuple>
#include <utility>
#include <vector>
#include "backend/kernel_compiler/common_utils.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/optimizer/graph_kernel/core/graph_builder.h"
namespace mindspore::graphkernel {
namespace {
int64_t GetGetitemIndex(const AnfNodePtr &getitem) {
auto vnode = GetValueNode(getitem->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem));
return GetValue<int64_t>(vnode);
}
AnfNodePtr GetOutput(const FuncGraphPtr &func_graph, size_t i) {
auto output_node = func_graph->output()->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(output_node);
if (IsPrimitiveCNode(output_node, prim::kPrimMakeTuple)) {
if (i + 1 >= output_node->size()) {
MS_LOG(EXCEPTION) << i << " is out of range of MakeTuple's size " << output_node->size();
}
return output_node->input(i + 1);
} else {
if (i > 0) {
MS_LOG(EXCEPTION) << "the graph is single output but i is not 0. it's " << i;
}
return output_node->cast<AnfNodePtr>();
}
}
bool IsExclude(const AnfNodePtr &node) {
static std::vector<PrimitivePtr> excludes = {prim::kPrimReturn, prim::kPrimUpdateState, prim::kPrimLoad,
prim::kPrimMakeTuple, prim::kPrimDepend};
return std::any_of(excludes.begin(), excludes.end(),
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
}
enum class VisitType : char { FOLLOW, STOP };
using VisitFunc = std::function<VisitType(const AnfNodePtr &)>;
using NextFunc = std::function<AnfNodePtrList(const AnfNodePtr &)>;
using ProcessFunc = std::function<void(const AnfNodePtr &)>;
void Dfs(const AnfNodePtr &current, const VisitFunc &visit_func, const NextFunc &next_func,
const ProcessFunc &before_func, const ProcessFunc &after_func, std::set<AnfNodePtr> *visited) {
if (visited->count(current) > 0) {
return;
}
visited->insert(current);
if (visit_func(current) != VisitType::FOLLOW) {
return;
}
for (const auto &next : next_func(current)) {
before_func(next);
Dfs(next, visit_func, next_func, before_func, after_func, visited);
after_func(next);
}
}
OrderedMap<AnfNodePtr, AnfNodePtrList> CollectLinkPaths(const std::map<AnfNodePtr, MemorySize> &topo_indice,
const OrderedSet<AnfNodePtr> &direct_users,
MemorySize max_topo_user_index,
const FuncGraphManagerPtr &mng) {
std::stack<AnfNodePtr> cur_stack;
OrderedMap<AnfNodePtr, AnfNodePtrList> link_paths;
auto TmpVisitFunc = [&topo_indice, max_topo_user_index](const AnfNodePtr &n) -> VisitType {
if (IsExclude(n)) {
return VisitType::STOP;
}
auto iter = topo_indice.find(n);
if (iter == topo_indice.end()) {
MS_LOG(EXCEPTION) << "Cannot find " << n->fullname_with_scope() << " in topo indices!";
}
if (iter->second > max_topo_user_index) {
return VisitType::STOP;
}
return VisitType::FOLLOW;
};
auto TmpNextFunc = [&mng](const AnfNodePtr &n) -> AnfNodePtrList {
auto users = mng->node_users()[n];
AnfNodePtrList nexts;
std::transform(users.cbegin(), users.cend(), std::back_inserter(nexts),
[](const std::pair<AnfNodePtr, int> &user) { return user.first; });
return nexts;
};
auto TmpBeforeFunc = [&link_paths, &cur_stack, &direct_users](const AnfNodePtr &next) -> void {
if (direct_users.count(next) == 0) {
return;
}
auto cur_node = cur_stack.top();
if (link_paths.find(cur_node) == link_paths.end()) {
link_paths.insert({cur_node, AnfNodePtrList()});
}
link_paths[cur_node].push_back(next);
cur_stack.push(next);
};
auto TmpAfterFunc = [&cur_stack, &direct_users](const AnfNodePtr &next) -> void {
if (direct_users.count(next) == 0) {
return;
}
cur_stack.push(next);
};
std::set<AnfNodePtr> visited;
for (auto user : direct_users) {
cur_stack.push(user);
Dfs(user, TmpVisitFunc, TmpNextFunc, TmpBeforeFunc, TmpAfterFunc, &visited);
cur_stack.pop();
}
return link_paths;
}
OrderedSet<AnfNodePtr> GetLongTermNodes(const AnfNodePtrList &nodes, const AnfNodePtr &end_node,
const std::map<AnfNodePtr, MemorySize> &topo_indices,
const FuncGraphManagerPtr &mng) {
OrderedSet<AnfNodePtr> long_term_nodes;
for (auto node : nodes) {
auto real_node = AnfAlgo::VisitKernelWithReturnType(node, 0).first;
// Parameters or values are long-term tensors.
if (!utils::isa<CNodePtr>(real_node)) {
long_term_nodes.insert(node);
continue;
}
auto users = mng->node_users()[node];
if (std::any_of(users.cbegin(), users.cend(), [&topo_indices, &end_node](const std::pair<AnfNodePtr, int> &user) {
auto user_topo = topo_indices.find(user.first);
auto end_topo = topo_indices.find(end_node);
return user_topo->second >= end_topo->second;
})) {
long_term_nodes.insert(node);
}
}
return long_term_nodes;
}
/**
* @brief Remove real input which is not used and change the related graph parameters.
*
* @param func_graph Graph.
* @param inputs Real inputs for graph cnode.
*/
void ElimRedundantInputsAndGraphParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs) {
MS_EXCEPTION_IF_NULL(inputs);
const auto &ori_parameter = func_graph->parameters();
auto nodes = TopoSort(func_graph->get_return());
std::set<AnfNodePtr> used_param;
for (auto node : nodes) {
if (node->isa<Parameter>()) {
(void)used_param.insert(node);
}
}
if (used_param.size() == ori_parameter.size()) {
return;
}
AnfNodePtrList new_parameter, new_inputs;
for (size_t i = 0; i < ori_parameter.size(); ++i) {
if (used_param.count(ori_parameter[i])) {
new_parameter.push_back(ori_parameter[i]);
new_inputs.push_back((*inputs)[i]);
}
}
func_graph->set_parameters(new_parameter);
*inputs = std::move(new_inputs);
}
} // namespace
/**
* @brief Filter out the input tensors that live longer than the end node and return valid inputs for memory calculation. \n
* If the topo index of an input's user is at least one greater than end_node's, \n
* that input will still be retained after end_node's execution.
*
* @param source_node
* @param end_node
* @param edge_pos
* @param mng
* @return AnfNodePtrList
*/
AnfNodePtrList AutoRecompute::Filter(const AnfNodePtr &source_node, const AnfNodePtr &end_node, int edge_pos,
const FuncGraphManagerPtr &mng) {
auto source_cnode = source_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(source_cnode);
AnfNodePtrList node_inputs(source_cnode->inputs().begin() + 1, source_cnode->inputs().end());
OrderedSet<AnfNodePtr> long_term_inputs = GetLongTermNodes(node_inputs, end_node, topo_indice_, mng);
AnfNodePtrList check_inputs;
if (IsPrimitiveCNode(end_node->cast<CNodePtr>()->input(edge_pos), prim::kPrimTupleGetItem)) {
auto out_index = GetSourceLinkOutPos(end_node, edge_pos);
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(source_node);
auto out = sub_graph->output();
if (!IsPrimitiveCNode(out, prim::kPrimMakeTuple)) {
MS_LOG(EXCEPTION) << "Filte input tensor error";
}
// Find subgraph's input according to edge node.
auto start_node = out->cast<CNodePtr>()->input(out_index + 1);
AnfNodePtrList sub_input_parameters;
std::queue<AnfNodePtr> node_q;
node_q.push(start_node);
while (!node_q.empty()) {
auto cur = node_q.front();
node_q.pop();
if (utils::isa<ParameterPtr>(cur)) {
sub_input_parameters.push_back(cur);
}
auto cur_cnode = cur->cast<CNodePtr>();
if (cur_cnode) {
for (size_t i = 1; i < cur_cnode->inputs().size(); ++i) {
node_q.push(cur_cnode->input(i));
}
}
}
// Filter inputs whose users' topo indices are greater than the source graph's.
for (auto para : sub_input_parameters) {
for (size_t i = 0; i < sub_graph->parameters().size(); ++i) {
if (para == sub_graph->parameters()[i]) {
check_inputs.push_back(node_inputs[i]);
}
}
}
} else {
check_inputs = node_inputs;
}
AnfNodePtrList res;
for (auto input : check_inputs) {
if (long_term_inputs.count(input) == 0) {
res.push_back(input);
}
}
return res;
}
/**
* @brief Get valid user information for the given node, excluding TupleGetItem, Load and so on.
*/
std::tuple<OrderedSet<AnfNodePtr>, OutPosLinkMap, MemorySize> AutoRecompute::GetValidUsers(
const AnfNodePtr &node, const FuncGraphManagerPtr &mng) {
auto &user_map = mng->node_users();
auto users = user_map[node];
MemorySize max_topo_user_index = 0;
std::queue<std::pair<AnfNodePtr, int>> users_queue;
for (auto user_index : users) {
users_queue.push(user_index);
}
OrderedSet<AnfNodePtr> direct_users;
OutPosLinkMap user_edge_pos;
while (!users_queue.empty()) {
auto [user, index] = users_queue.front();
users_queue.pop();
if (IsPrimitiveCNode(user, prim::kPrimTupleGetItem)) {
for (auto get_item_user : user_map[user]) {
users_queue.push(get_item_user);
}
continue;
} else if (IsExclude(user)) {
continue;
}
user_edge_pos[user].push_back(index);
direct_users.insert(user);
// Update maximum topo value.
if (topo_indice_[user] > max_topo_user_index) {
max_topo_user_index = topo_indice_[user];
}
}
return {direct_users, user_edge_pos, max_topo_user_index};
}
/**
* @brief Judge the target node for recompute according to the current node, and capture source node information when a \n
* target is found. There are two types of tensors on the edge between the source node and the target node, for example: \n
* source [Short-Term] A other \n
* \n
* [Long-Term] target \n
* For this example, \n
* 1. There are two paths from the source node to the target node, and target is a direct user of the source node, \n
* so the tensor of their edge is a long-term tensor. \n
* 2. From the source node to A, there is only one path, and A is a direct user of the source node, \n
* so the tensor of their edge is a short-term tensor.
*
* @param node Source node.
* @param mng Graph manager.
* @return OutPosLinkList Vector[Tuple(target node, input positions of target node for edge, edge type)].
*/
OutPosLinkList AutoRecompute::JudegeTargetAndCaptureSource(const AnfNodePtr &node, const FuncGraphManagerPtr &mng) {
auto [direct_users, user_edge_pos, max_topo_user_index] = GetValidUsers(node, mng);
OutPosLinkList target_link_infos;
OrderedSet<AnfNodePtr> long_term_users;
// If the number of direct users is less than 2, there is no side path to its users.
if (direct_users.size() >= 2) {
OrderedMap<AnfNodePtr, AnfNodePtrList> link_paths =
CollectLinkPaths(topo_indice_, direct_users, max_topo_user_index, mng);
for (const auto &[source, paths] : link_paths) {
for (auto target : paths) {
if (target != source) {
target_link_infos.emplace_back(target, user_edge_pos[target], EdgeLifeTimeType::LongTerm);
long_term_users.insert(target);
}
}
}
}
// Direct users include long term users and short term users.
// If the short term user is a graph kernel composite node, it may be absorbed and reduce the local peak memory.
for (auto user : direct_users) {
if (long_term_users.count(user) == 0 && AnfAlgo::IsGraphKernel(user)) {
target_link_infos.emplace_back(user, user_edge_pos[user], EdgeLifeTimeType::ShortTerm);
}
}
return target_link_infos;
}
/**
* @brief Get position of edge tensor between source node and target node. \n
* For example, given the target node and edge position 0, 1 will be returned: \n
* source node \n
* [0] [1] [2] <- output position \n
* | \n
* | \n
* / \n
* [0] [1] <- input position \n
* target node
*
* @param target Target node.
* @param pos The input position of target node for edge.
* @return int The output position of source node for edge.
*/
int AutoRecompute::GetSourceLinkOutPos(const AnfNodePtr &target, int pos) {
// If the input is get-item, then use get-item's index, otherwise zero.
auto cnode = target->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto prenode = cnode->input(pos);
if (!IsPrimitiveCNode(prenode, prim::kPrimTupleGetItem)) {
return 0;
}
auto get_item_cnode = prenode->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(get_item_cnode);
auto value_input = get_item_cnode->input(kInputNodeOutputIndexInTupleGetItem);
MS_EXCEPTION_IF_NULL(value_input);
auto value_node = value_input->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
return static_cast<int>(GetValue<int64_t>(value_node->value()));
}
MemorySize AutoRecompute::SelectThreshold(EdgeLifeTimeType type) {
MemorySize threshold = 0;
switch (type) {
case EdgeLifeTimeType::ShortTerm:
threshold = local_peak_threshold_;
break;
case EdgeLifeTimeType::LongTerm:
threshold =
local_peak_threshold_ == 0 ? lifetime_threshold_ : std::min(local_peak_threshold_, lifetime_threshold_);
break;
default:
MS_LOG(EXCEPTION) << "Unknown edge type!";
}
return threshold;
}
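// Threshold selection sketch (values are hypothetical, unit-agnostic): with local_peak_threshold_ = 32 and
// lifetime_threshold_ = 64, a ShortTerm edge uses 32 and a LongTerm edge uses min(32, 64) = 32;
// with local_peak_threshold_ = 0, a LongTerm edge falls back to lifetime_threshold_ = 64.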
/**
 * @brief Find recompute candidates (source node, target node, edge and its type) in func_graph. \n
 * Results will be added to candidates_.
*
* @param func_graph
*/
void AutoRecompute::FindCandidates(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
candidates_.clear();
auto mng = func_graph->manager();
if (mng == nullptr) {
mng = Manage(func_graph, true);
func_graph->set_manager(mng);
}
// Do nothing when both thresholds are at the default value 0.
if (SelectThreshold(EdgeLifeTimeType::ShortTerm) == 0 && SelectThreshold(EdgeLifeTimeType::LongTerm) == 0) {
return;
}
auto topo_nodes = TopoSort(func_graph->get_return());
// Topo indices are used for early stopping in the predecessor check.
for (size_t i = 0; i < topo_nodes.size(); ++i) {
topo_indice_.insert({topo_nodes[i], i});
}
// Candidate conditions:
// 1. A candidate target node can also reach its graph_kernel input (the current node) through another input's backward path.
// 2. The memory difference between the split-out result and the origin exceeds the threshold:
//    `Size(gs_direct_outs_to_gt) - filter(gs_inputs, its) > threshold`.
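// For instance (sizes here are hypothetical): if the edge tensor from GS to GT is 64MB and the GS inputs
// that GT could absorb after recompute sum to 16MB, the candidate is kept only when
// 64MB - 16MB exceeds the selected threshold.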
for (auto node : topo_nodes) {
if (!AnfAlgo::IsGraphKernel(node)) {
continue;
}
auto target_graphs = JudegeTargetAndCaptureSource(node, mng);
if (target_graphs.empty()) {
continue;
}
OrderedMap<AnfNodePtr, OrderedMap<AnfNodePtr, std::pair<EdgeLifeTimeType, AnfNodePtrList>>> tmp_candidates;
for (auto [gt, gt_in_pos_vec, edge_life_time_type] : target_graphs) {
MemorySize threshold = SelectThreshold(edge_life_time_type);
for (auto gt_in_pos : gt_in_pos_vec) {
MemorySize out_tensor_size =
static_cast<MemorySize>(AnfAlgo::GetOutputTensorMemSize(node, GetSourceLinkOutPos(gt, gt_in_pos)));
MemorySize absorb_input_tensor_size = 0;
for (auto input : Filter(node, gt, gt_in_pos, mng)) {
absorb_input_tensor_size += static_cast<MemorySize>(AnfAlgo::GetOutputTensorMemSize(input, 0));
}
auto gt_cnode = gt->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(gt_cnode);
auto edge = gt_cnode->input(gt_in_pos);
if (out_tensor_size < absorb_input_tensor_size) {
continue;
}
if (out_tensor_size - absorb_input_tensor_size > threshold) {
if (tmp_candidates[node].find(gt) == tmp_candidates[node].end()) {
tmp_candidates[node][gt] = {edge_life_time_type, AnfNodePtrList{}};
}
// Only add the getitem node as an edge; if GS is a single-output node, there will be no edges.
if (IsPrimitiveCNode(edge, prim::kPrimTupleGetItem)) {
tmp_candidates[node][gt].second.push_back(edge);
}
}
}
}
// Delete duplicated link.
for (const auto &[source, target_and_link] : tmp_candidates) {
for (const auto &[target, link] : target_and_link) {
candidates_.push_back({source, target, link.first, link.second});
}
}
}
RecomputeCandidatesLog(candidates_);
}
void AutoRecompute::RecomputeCandidatesLog(const std::vector<Candidate> &candidates) {
MS_LOG(INFO) << "Recompute candidates: ";
for (auto candidate : candidates) {
MS_LOG(INFO) << " └─ GS: " << candidate.source_graph->fullname_with_scope();
MS_LOG(INFO) << " └─ GT: " << candidate.target_graph->fullname_with_scope();
for (auto edge : candidate.recompute_edges) {
MS_LOG(INFO) << " └─[Edge]─> " << edge->fullname_with_scope();
}
}
}
std::pair<FuncGraphPtr, AnfNodePtrList> GraphKernelRecompute::CloneGraph(const CNodePtr &source_graph,
const AnfNodePtrList &recompute_edges) {
MS_EXCEPTION_IF_NULL(source_graph);
auto gs = AnfAlgo::GetCNodeFuncGraphPtr(source_graph);
MS_EXCEPTION_IF_NULL(gs);
AnfNodePtrList inputs(source_graph->inputs().begin() + 1, source_graph->inputs().end());
auto new_funcgraph = BasicClone(gs);
auto output_node = new_funcgraph->output()->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(output_node);
if (!IsPrimitiveCNode(output_node, prim::kPrimMakeTuple)) {
return {new_funcgraph, inputs};
}
// Remove outputs that are not in the recompute edges.
AnfNodePtrList new_outputs;
for (auto &edge : recompute_edges) {
auto idx = GetGetitemIndex(edge);
new_outputs.push_back(GetOutput(new_funcgraph, idx));
}
if (new_outputs.size() + 1 == output_node->size()) {
return {new_funcgraph, inputs};
}
new_outputs.insert(new_outputs.begin(), output_node->input(0));
auto new_output_node = new_funcgraph->NewCNode(new_outputs);
// Use the old abstract, since the new_funcgraph will be deleted in a later process.
new_output_node->set_abstract(output_node->abstract());
new_output_node->set_kernel_info(std::make_shared<device::KernelInfo>());
new_funcgraph->set_output(new_output_node);
ElimRedundantInputsAndGraphParameters(new_funcgraph, &inputs);
return {new_funcgraph, inputs};
}
void GraphKernelRecompute::LinkIntoTargetFuncGraph(const Candidate &candidate, const FuncGraphPtr &cloned_func,
const AnfNodePtrList &cloned_inputs) {
auto cloned_nodes = TopoSort(cloned_func->get_return());
auto gt = AnfAlgo::GetCNodeFuncGraphPtr(candidate.target_graph);
MS_EXCEPTION_IF_NULL(gt);
auto mng = gt->manager();
if (mng == nullptr) {
mng = Manage(gt, true);
gt->set_manager(mng);
}
// link the outputs to gt
auto gt_node = candidate.target_graph->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(gt_node);
AnfNodePtrList new_parameters;
AnfNodePtrList new_inputs;
auto &params = gt->parameters();
for (size_t i = 0; i < params.size(); i++) {
// If the parameter is a recompute edge, link the param to the cloned_func's output.
auto iter = std::find(candidate.recompute_edges.begin(), candidate.recompute_edges.end(), gt_node->input(i + 1));
if (iter != candidate.recompute_edges.end()) {
auto out_index = iter - candidate.recompute_edges.begin();
mng->Replace(params[i], GetOutput(cloned_func, out_index));
} else {
new_parameters.push_back(params[i]);
new_inputs.push_back(gt_node->input(i + 1));
}
}
// add new parameters
auto &cloned_func_params = cloned_func->parameters();
for (size_t i = 0; i < cloned_func_params.size(); i++) {
auto iter = std::find(new_inputs.begin(), new_inputs.end(), cloned_inputs[i]);
if (iter != new_inputs.end()) {
auto idx = iter - new_inputs.begin();
cloned_func->manager()->Replace(cloned_func_params[i], new_parameters[idx]);
} else {
new_parameters.push_back(gt->add_parameter());
new_inputs.push_back(cloned_inputs[i]);
cloned_func->manager()->Replace(cloned_func_params[i], new_parameters.back());
}
}
// reset the func_graph for cloned_nodes.
for (auto &node : cloned_nodes) {
if (node->isa<CNode>()) {
node->set_func_graph(gt);
}
}
AnfNodePtrList new_node_inputs = {gt_node->input(0)};
new_node_inputs.insert(new_node_inputs.end(), new_inputs.begin(), new_inputs.end());
gt->set_parameters(new_parameters);
gt_node->set_inputs(new_node_inputs);
AnfNodePtrList outputs;
kernel::GetFuncGraphOutputNodes(gt, &outputs);
gt_node->set_kernel_info(std::make_shared<device::KernelInfo>());
SetNewKernelInfo(gt_node, gt, new_inputs, outputs);
mng->RemoveRoots();
mng->KeepRoots({gt});
}
void GraphKernelRecompute::Process(const Candidate &candidate) {
auto gs = AnfAlgo::GetCNodeFuncGraphPtr(candidate.source_graph);
MS_EXCEPTION_IF_NULL(gs);
if (candidate.recompute_edges.empty()) {
// single output, clone the whole source_graph.
return;
}
auto [new_funcgraph, inputs] = CloneGraph(candidate.source_graph->cast<CNodePtr>(), candidate.recompute_edges);
auto mng = new_funcgraph->manager();
if (mng == nullptr) {
mng = Manage(new_funcgraph, true);
new_funcgraph->set_manager(mng);
}
if (AnfAlgo::IsGraphKernel(candidate.target_graph)) {
// the target graph is a GraphKernel, push the new_funcgraph into the target graph.
LinkIntoTargetFuncGraph(candidate, new_funcgraph, inputs);
} else {
// The target graph is not a GraphKernel, build the new_funcgraph to a CNode.
MS_LOG(WARNING) << "Target node " << candidate.target_graph->fullname_with_scope()
<< " is not a graph kernel node, cannot absort the link edge!";
return;
}
}
bool GraphKernelRecompute::Run(const FuncGraphPtr &func_graph) {
int repeat_times = 2;
while (repeat_times--) {
candidates_ = AutoRecompute().Run(func_graph);
if (candidates_.empty()) {
return false;
}
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
for (auto &c : candidates_) {
if (!AnfAlgo::IsGraphKernel(c.target_graph)) {
continue;
}
std::ostringstream oss;
for (auto &e : c.recompute_edges) {
if (!IsPrimitiveCNode(e, prim::kPrimTupleGetItem)) {
MS_LOG(EXCEPTION) << "The edge should be GetItem but got " << e->fullname_with_scope();
}
oss << e->fullname_with_scope() << ", ";
}
MS_LOG(INFO) << "Clone " << c.source_graph->fullname_with_scope() << " to "
<< c.target_graph->fullname_with_scope() << ", edges [" << oss.str() << "]";
Process(c);
}
mng->RemoveRoots();
mng->KeepRoots({func_graph});
}
return true;
}
} // namespace mindspore::graphkernel

View File

@ -0,0 +1,110 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_RECOMPUTE_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_RECOMPUTE_H_
#include <map>
#include <ostream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "utils/context/graph_kernel_flags.h"
#include "backend/optimizer/common/pass.h"
#include "ir/func_graph.h"
namespace mindspore::graphkernel {
/*
 * Recompute some operators to reduce the temporary memory peak.
*
* (a) (b) (a) (b)
* \ / | |
* Gs Gs1 |
* (c)/ | (c)| |
* / | Go |
* Go |(d) =========> depend
* \ |
* (e)\ | (e) Gs2
* \|
* Gt (d)
* Gt
*
 * Here, Gs is split into Gs1 and Gs2, and (x) denotes a temporary tensor.
 * For the left graph, the memory footprint is (a+b) -> (c+d) -> (d+e).
 * For the right graph, the memory footprint is (a+b) -> (b+c) -> (b+e) -> (d+e).
 * If (c+d) reaches the memory threshold while (b+c) and (b+e) are both smaller than it,
 * recomputing may ease the memory burden.
*/
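/*
 * A rough numeric illustration (the sizes below are hypothetical, in arbitrary units): with
 * a = b = e = 1, c = 4 and d = 8, the left graph peaks at max(a+b, c+d, d+e) = 12, while the
 * right graph peaks at max(a+b, b+c, b+e, d+e) = 9, so recomputing d from b lowers the peak.
 */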
enum class EdgeLifeTimeType : char { ShortTerm, LongTerm };
inline std::ostream &operator<<(std::ostream &os, EdgeLifeTimeType type) {
std::map<EdgeLifeTimeType, std::string> out_str = {{EdgeLifeTimeType::ShortTerm, "[ShortTerm]"},
{EdgeLifeTimeType::LongTerm, "[LongTerm]"}};
return os << out_str[type];
}
using OutPosLinkList = std::vector<std::tuple<AnfNodePtr, std::vector<int>, EdgeLifeTimeType>>;
using OutPosLinkMap = std::map<AnfNodePtr, std::vector<int>>;
using MemorySize = int64_t;
struct Candidate {
AnfNodePtr source_graph;
AnfNodePtr target_graph;
EdgeLifeTimeType type;
AnfNodePtrList recompute_edges; // getitem list for recompute edges.
};
class AutoRecompute {
public:
std::vector<Candidate> Run(const FuncGraphPtr &func_graph) {
lifetime_threshold_ = GraphKernelFlags::GetInstance().recompute_increment_threshold;
local_peak_threshold_ = GraphKernelFlags::GetInstance().recompute_peak_threshold;
FindCandidates(func_graph);
return candidates_;
}
private:
OutPosLinkList JudegeTargetAndCaptureSource(const AnfNodePtr &node, const FuncGraphManagerPtr &mng);
AnfNodePtrList Filter(const AnfNodePtr &source_node, const AnfNodePtr &end_node, int edge_pos,
const FuncGraphManagerPtr &mng);
void FindCandidates(const FuncGraphPtr &func_graph);
int GetSourceLinkOutPos(const AnfNodePtr &target, int pos);
std::tuple<OrderedSet<AnfNodePtr>, OutPosLinkMap, MemorySize> GetValidUsers(const AnfNodePtr &node,
const FuncGraphManagerPtr &mng);
MemorySize SelectThreshold(EdgeLifeTimeType type);
std::map<AnfNodePtr, MemorySize> topo_indice_;
std::vector<Candidate> candidates_;
MemorySize lifetime_threshold_{0};
MemorySize local_peak_threshold_{0};
void RecomputeCandidatesLog(const std::vector<Candidate> &candidates);
};
class GraphKernelRecompute : public opt::Pass {
public:
GraphKernelRecompute() : Pass("graph_kernel_recompute") {}
~GraphKernelRecompute() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
private:
void Process(const Candidate &candidate);
std::pair<FuncGraphPtr, AnfNodePtrList> CloneGraph(const CNodePtr &source_graph,
const AnfNodePtrList &recompute_edge);
void LinkIntoTargetFuncGraph(const Candidate &candidate, const FuncGraphPtr &cloned_func,
const AnfNodePtrList &cloned_inputs);
std::vector<Candidate> candidates_;
};
} // namespace mindspore::graphkernel
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_RECOMPUTE_H_

View File

@ -0,0 +1,90 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/graph_kernel/reduce_fake_out_mem.h"
#include <memory>
#include <set>
#include <vector>
#include "backend/kernel_compiler/common_utils.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
namespace mindspore::graphkernel {
constexpr auto kFakeOut = "fake_output";
void ReduceFakeOutMem::ModifyAbstract(const AnfNodePtr &composite_node, const std::set<size_t> &fake_real_indices,
const AnfNodePtrList &output_list) {
if (fake_real_indices.empty()) {
return;
}
if (output_list.empty()) {
MS_LOG(EXCEPTION) << "Output size should not be zero while there is at least one fake output!";
}
std::vector<AbstractBasePtr> out_specs;
for (size_t i = 0; i < output_list.size(); ++i) {
if (fake_real_indices.count(i)) {
std::vector<int64_t> shape_vec_shape = {1};
AbstractBasePtr abstract = std::make_shared<abstract::AbstractTensor>(kInt64, shape_vec_shape);
out_specs.push_back(abstract);
} else {
out_specs.push_back(output_list[i]->abstract());
}
}
AbstractBasePtr out_spec;
if (output_list.size() > 1) {
out_spec = std::make_shared<abstract::AbstractTuple>(out_specs);
} else {
out_spec = output_list[0]->abstract();
}
composite_node->set_abstract(out_spec);
}
bool ReduceFakeOutMem::Run(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
bool changed = false;
auto todos = TopoSort(func_graph->get_return());
for (auto node : todos) {
if (!AnfAlgo::IsGraphKernel(node)) {
continue;
}
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
MS_EXCEPTION_IF_NULL(sub_graph);
AnfNodePtrList output_list;
kernel::GetFuncGraphOutputNodes(sub_graph, &output_list);
std::set<size_t> fake_real_indices;
for (size_t i = 0; i < output_list.size(); ++i) {
auto &out = output_list[i];
auto out_cnode = out->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(out_cnode);
if (AnfAlgo::HasNodeAttr(kFakeOut, out_cnode) && AnfAlgo::GetNodeAttr<bool>(out_cnode, kFakeOut)) {
fake_real_indices.insert(i);
}
}
if (fake_real_indices.empty()) {
continue;
}
ModifyAbstract(node, fake_real_indices, output_list);
changed = true;
}
return changed;
}
} // namespace mindspore::graphkernel

View File

@ -0,0 +1,42 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_REDUCE_FAKE_OUT_MEM_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_REDUCE_FAKE_OUT_MEM_H_
#include <memory>
#include <set>
#include "utils/context/graph_kernel_flags.h"
#include "backend/optimizer/common/pass.h"
#include "backend/optimizer/graph_kernel/add_atomic_clean.h"
#include "ir/func_graph.h"
namespace mindspore::graphkernel {
/**
 * @brief Reduce a fake output's memory from its original size to a single element.
*/
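// For example (sizes are illustrative): a fake output whose original abstract is a float32 tensor of
// shape [1024, 1024] is rewritten by ModifyAbstract to an int64 tensor of shape [1], so the runtime
// only needs to reserve a single element for it.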
class ReduceFakeOutMem : public opt::Pass {
public:
ReduceFakeOutMem() : Pass("reduce_fake_output_memory") {}
~ReduceFakeOutMem() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
private:
void ModifyAbstract(const AnfNodePtr &composite_node, const std::set<size_t> &fake_real_indices,
const AnfNodePtrList &output_list);
};
using ReduceFakeOutMemPtr = std::shared_ptr<ReduceFakeOutMem>;
} // namespace mindspore::graphkernel
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_REDUCE_FAKE_OUT_MEM_H_

View File

@ -16,6 +16,7 @@
#include "backend/optimizer/graph_kernel/tsa_atomic_add_to_first_tensor.h"
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "base/core_ops.h"
@ -41,16 +42,20 @@ class TsaChecker : public AtomicAddChecker {
return false;
}
auto tsa_cnode = atomic_add_info_.atomic_add_node;
if (!utils::isa<ParameterPtr>(tsa_cnode->input(1))) {
return false;
for (auto atomic_add_info : atomic_add_infos_) {
auto tsa_cnode = atomic_add_info.atomic_add_node;
if (!utils::isa<ParameterPtr>(tsa_cnode->input(1))) {
return false;
}
}
return true;
}
};
AnfNodePtr TsaAtomicAddToFirstTensor::FindTsaFirstRealInputInGraph(const KernelGraphPtr &, const AnfNodePtr &node) {
std::pair<AnfNodePtr, size_t> TsaAtomicAddToFirstTensor::FindTsaFirstRealInputInGraph(const KernelGraphPtr &main_graph,
const CNodePtr &tsa_node,
const AnfNodePtr &node) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
@ -60,13 +65,14 @@ AnfNodePtr TsaAtomicAddToFirstTensor::FindTsaFirstRealInputInGraph(const KernelG
sub_graph->set_manager(mng_sub);
}
auto first_input = atomic_add_node_->input(1)->cast<ParameterPtr>();
auto first_input = tsa_node->input(1)->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(first_input);
auto parameters = sub_graph->parameters();
bool hit = false;
size_t tsa_first_input_index = 0;
for (size_t i = 0; i < parameters.size(); ++i) {
if (parameters[i] == first_input) {
tsa_first_input_index_ = i;
tsa_first_input_index = i;
hit = true;
break;
}
@ -75,25 +81,31 @@ AnfNodePtr TsaAtomicAddToFirstTensor::FindTsaFirstRealInputInGraph(const KernelG
MS_LOG(EXCEPTION) << "Cannot find tensor scatter add first input in sub-graph parameters!";
}
return cnode->input(tsa_first_input_index_ + 1); // CNode input have a primitive, so add 1.
return {cnode->input(tsa_first_input_index + 1), tsa_first_input_index}; // CNode input have a primitive, so add 1.
}
AnfNodePtr TsaAtomicAddToFirstTensor::ProcessTsaFirstNode(const KernelGraphPtr &main_graph, const AnfNodePtr &node) {
std::pair<AnfNodePtr, size_t> TsaAtomicAddToFirstTensor::GetOrCreateNewTsaFirstNode(
const KernelGraphPtr &main_graph, const AtomicAddInfo &atomic_add_info, const AnfNodePtr &node) {
auto mng = main_graph->manager();
if (mng == nullptr) {
mng = Manage(main_graph, true);
main_graph->set_manager(mng);
}
// find first input of tsa
auto tsa_first_input = FindTsaFirstRealInputInGraph(main_graph, node);
auto users = mng->node_users()[tsa_first_input];
if (users.size() == 1 && !(utils::isa<ValueNodePtr>(tsa_first_input) || utils::isa<ParameterPtr>(tsa_first_input))) {
// Find first input of tsa
auto tsa_first_input = FindTsaFirstRealInputInGraph(main_graph, atomic_add_info.atomic_add_node, node);
auto users = mng->node_users()[tsa_first_input.first];
if (users.size() == 1 &&
!(utils::isa<ValueNodePtr>(tsa_first_input.first) || utils::isa<ParameterPtr>(tsa_first_input.first))) {
// If current composite node is only user, and first input is not Parameter or Tensor Value, then use itself.
return tsa_first_input;
}
// Create a copy of first input to atomic add to.
// Create composite op's sub-graph.
auto new_sub_graph = std::make_shared<FuncGraph>();
auto parameter = new_sub_graph->add_parameter();
auto kernel_with_index = AnfAlgo::VisitKernel(tsa_first_input, 0);
auto kernel_with_index = AnfAlgo::VisitKernel(tsa_first_input.first, 0);
parameter->set_abstract(GetOutputAbstract(kernel_with_index.first, kernel_with_index.second));
parameter->set_kernel_info(std::make_shared<device::KernelInfo>());
std::string parameter_format;
@ -123,18 +135,18 @@ AnfNodePtr TsaAtomicAddToFirstTensor::ProcessTsaFirstNode(const KernelGraphPtr &
// Makeup sub-graph.
new_sub_graph->set_output(identity_node);
auto new_composite_node = main_graph->NewCNode({NewValueNode(new_sub_graph), tsa_first_input});
new_composite_node->set_abstract(identity_node->abstract());
SetNewKernelInfo(new_composite_node, new_sub_graph, {tsa_first_input}, {identity_node});
auto new_copy_composite_node = main_graph->NewCNode({NewValueNode(new_sub_graph), tsa_first_input.first});
new_copy_composite_node->set_abstract(identity_node->abstract());
SetNewKernelInfo(new_copy_composite_node, new_sub_graph, {tsa_first_input.first}, {identity_node});
auto graph_attr = GkUtils::ExtractGraphKernelName(TopoSort(new_sub_graph->get_return()), "", "tsa_identity");
new_sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(graph_attr));
new_sub_graph->set_attr("composite_type", MakeValue("tsa_identity"));
return new_composite_node;
return {new_copy_composite_node, tsa_first_input.second};
}
void TsaAtomicAddToFirstTensor::CorrectKernelBuildInfo(const AnfNodePtr &composite_node,
const AnfNodePtr &modified_input, bool) {
void TsaAtomicAddToFirstTensor::CorrectKernelBuildInfo(
const AnfNodePtr &composite_node, const std::vector<std::tuple<AtomicAddInfo, AnfNodePtr, size_t>> &outer_infos) {
// Change kernel build info with the modified inputs.
auto kernel_info = static_cast<device::KernelInfo *>(composite_node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
@ -149,19 +161,29 @@ void TsaAtomicAddToFirstTensor::CorrectKernelBuildInfo(const AnfNodePtr &composi
std::vector<TypeId> &modified_inputs_type = origin_inputs_type;
std::vector<std::string> new_outputs_format;
std::vector<TypeId> new_outputs_type;
std::set<size_t> reduce_real_indices;
for (auto &info : outer_infos) {
reduce_real_indices.insert(std::get<0>(info).reduce_real_output_index);
}
for (size_t i = 0; i < origin_outputs_format.size(); ++i) {
if (real_output_num_ > 1 && i == reduce_real_output_index_) {
if (std::get<0>(outer_infos[0]).real_output_num > 1 && reduce_real_indices.count(i) > 0) {
continue;
}
new_outputs_format.push_back(origin_outputs_format[i]);
new_outputs_type.push_back(origin_outputs_type[i]);
}
auto kernel_with_index = AnfAlgo::VisitKernel(modified_input, 0);
modified_inputs_format[tsa_first_input_index_] =
AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
modified_inputs_type[tsa_first_input_index_] =
AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
for (const auto &outer_info : outer_infos) {
auto &modified_input = std::get<1>(outer_info);
auto tsa_first_input_index = std::get<2>(outer_info);
auto kernel_with_index = AnfAlgo::VisitKernel(modified_input, 0);
modified_inputs_format[tsa_first_input_index] =
AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
modified_inputs_type[tsa_first_input_index] =
AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
}
kernel::KernelBuildInfo::KernelBuildInfoBuilder new_info_builder;
new_info_builder.SetInputsFormat(modified_inputs_format);
@ -175,7 +197,8 @@ void TsaAtomicAddToFirstTensor::CorrectKernelBuildInfo(const AnfNodePtr &composi
AnfAlgo::SetSelectKernelBuildInfo(new_selected_info, composite_node.get());
}
void TsaAtomicAddToFirstTensor::ProcessOriginCNode(const AnfNodePtr &composite_node, const AnfNodePtr &outter_node) {
void TsaAtomicAddToFirstTensor::ProcessOriginCNode(
const AnfNodePtr &composite_node, const std::vector<std::tuple<AtomicAddInfo, AnfNodePtr, size_t>> &outer_nodes) {
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(composite_node);
auto mng_sub = sub_graph->manager();
if (mng_sub == nullptr) {
@ -183,12 +206,20 @@ void TsaAtomicAddToFirstTensor::ProcessOriginCNode(const AnfNodePtr &composite_n
sub_graph->set_manager(mng_sub);
}
// modify input
composite_node->cast<CNodePtr>()->set_input(tsa_first_input_index_ + 1, outter_node);
CreateInplaceAssignNodeAndCorrectReturn(sub_graph, sub_graph->parameters()[tsa_first_input_index_]);
// Modify input
std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> parameters_infos;
std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> info_and_tsa_outers;
for (const auto &[atomic_add_info, outer_node, tsa_first_input_index] : outer_nodes) {
composite_node->cast<CNodePtr>()->set_input(tsa_first_input_index + 1, outer_node);
auto parameter = sub_graph->parameters()[tsa_first_input_index];
parameters_infos.emplace_back(atomic_add_info, parameter);
info_and_tsa_outers.emplace_back(atomic_add_info, outer_node);
}
CorrectAbstract(composite_node);
CorrectKernelBuildInfo(composite_node, outter_node);
CreateInplaceAssignNodeAndCorrectReturn(sub_graph, parameters_infos);
CorrectAbstract(composite_node, info_and_tsa_outers);
CorrectKernelBuildInfo(composite_node, outer_nodes);
auto old_graph_name = GetValue<std::string>(sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
auto new_graph_name =
@ -198,24 +229,34 @@ void TsaAtomicAddToFirstTensor::ProcessOriginCNode(const AnfNodePtr &composite_n
}
void TsaAtomicAddToFirstTensor::ProcessTsa(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node,
const std::vector<AtomicAddInfo> &atomic_add_infos,
const FuncGraphManagerPtr &mng) {
auto origin_composite_node = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(origin_composite_node);
// Create identity node.
auto outter_node = ProcessTsaFirstNode(main_graph, anf_node);
// Create identity node.
// Create broadcast node.
std::vector<std::tuple<AtomicAddInfo, AnfNodePtr, size_t>> info_and_outer_nodes_with_index;
std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> info_and_outer_nodes;
for (auto atomic_add_info : atomic_add_infos) {
auto outer = GetOrCreateNewTsaFirstNode(main_graph, atomic_add_info, anf_node);
info_and_outer_nodes_with_index.emplace_back(atomic_add_info, outer.first, outer.second);
info_and_outer_nodes.emplace_back(atomic_add_info, outer.first);
}
// Insert extra input(broadcast node output) to composite node, and make origin TensorScatterAdd inplaceassign to it.
// Note: if it's single output, this will increase total memory because of a fake out.
ProcessOriginCNode(origin_composite_node, outter_node);
ProcessOriginCNode(origin_composite_node, info_and_outer_nodes_with_index);
// Insert update_state_node to keep execution order.
auto update_state_node = InsertUpdateState(main_graph, origin_composite_node);
// Replace origin ReduceSum's user with atomic clean output
ProcessOriginCNodeUser(main_graph, origin_composite_node, outter_node, update_state_node, mng);
MS_LOG(INFO) << "Target node: " << origin_composite_node->fullname_with_scope()
<< ", outer node: " << outter_node->fullname_with_scope();
ProcessOriginCNodeUser(main_graph, origin_composite_node, info_and_outer_nodes, update_state_node, mng);
std::stringstream ss;
ss << "Target node: " << origin_composite_node->fullname_with_scope() << ", outer nodes: ";
for (auto iter : info_and_outer_nodes) {
ss << iter.second->fullname_with_scope() << ", ";
}
MS_LOG(INFO) << ss.str();
}
bool TsaAtomicAddToFirstTensor::Run(const FuncGraphPtr &func_graph) {
@ -239,11 +280,8 @@ bool TsaAtomicAddToFirstTensor::Run(const FuncGraphPtr &func_graph) {
if (!atomic_add_checker->Check(node)) {
continue;
}
auto atomic_add_info = atomic_add_checker->GetAtomicAddInfo();
atomic_add_node_ = atomic_add_info.atomic_add_node;
reduce_real_output_index_ = atomic_add_info.reduce_real_output_index;
real_output_num_ = atomic_add_info.real_output_num;
ProcessTsa(kernel_graph, node, mng);
auto atomic_add_infos = atomic_add_checker->GetAtomicAddInfo();
ProcessTsa(kernel_graph, node, atomic_add_infos, mng);
changed = true;
}

View File

@ -48,12 +48,17 @@ class TsaAtomicAddToFirstTensor : public AtomicCleanInsertter {
bool Run(const FuncGraphPtr &func_graph) override;
private:
void ProcessOriginCNode(const AnfNodePtr &composite_node, const AnfNodePtr &new_input) override;
void CorrectKernelBuildInfo(const AnfNodePtr &composite_node, const AnfNodePtr &new_input,
bool bypass = true) override;
void ProcessTsa(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node, const FuncGraphManagerPtr &mng);
AnfNodePtr ProcessTsaFirstNode(const KernelGraphPtr &main_graph, const AnfNodePtr &node);
AnfNodePtr FindTsaFirstRealInputInGraph(const KernelGraphPtr &main_graph, const AnfNodePtr &node);
void ProcessOriginCNode(const AnfNodePtr &composite_node,
const std::vector<std::tuple<AtomicAddInfo, AnfNodePtr, size_t>> &outer_nodes);
void CorrectKernelBuildInfo(const AnfNodePtr &composite_node,
const std::vector<std::tuple<AtomicAddInfo, AnfNodePtr, size_t>> &outer_infos);
void ProcessTsa(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node,
const std::vector<AtomicAddInfo> &atomic_add_infos, const FuncGraphManagerPtr &mng);
std::pair<AnfNodePtr, size_t> GetOrCreateNewTsaFirstNode(const KernelGraphPtr &main_graph,
const AtomicAddInfo &atomic_add_info,
const AnfNodePtr &node);
std::pair<AnfNodePtr, size_t> FindTsaFirstRealInputInGraph(const KernelGraphPtr &main_graph, const CNodePtr &tsa_node,
const AnfNodePtr &node);
size_t tsa_first_input_index_{0}; // sub-graph parameter index.
};

View File

@ -133,6 +133,7 @@ bool ExtendOutputForUpdateState::Run(const FuncGraphPtr &func_graph) {
}
}
if (changed) {
UpdateMng(mng, func_graph);
std::make_shared<SpreadUpdateState>()->Run(func_graph);
std::make_shared<EliminateHangingOutput>()->Run(func_graph);
}

View File

@ -55,9 +55,9 @@ bool UssAtomicAdd::Run(const FuncGraphPtr &func_graph) {
if (!atomic_add_checker->Check(node)) {
continue;
}
auto info = atomic_add_checker->GetAtomicAddInfo();
UpdateAtomicAddInfo(info);
InsertAtomicClean(kernel_graph, node, mng);
auto atomic_add_infos = atomic_add_checker->GetAtomicAddInfo();
InsertAtomicClean(kernel_graph, node, atomic_add_infos, mng);
changed = true;
}

View File

@ -230,6 +230,8 @@ void GraphKernelFlags::RegisterFlags(std::map<std::string, std::string> *flag_ma
reg.AddFlag("online_tuning", &online_tuning);
reg.AddFlag("fusion_ops_level", &fusion_ops_level, is_ascend ? OpLevel_0 : OpLevel_MAX);
reg.AddFlag("parallel_ops_level", &parallel_ops_level);
reg.AddFlag("recompute_increment_threshold", &recompute_increment_threshold);
reg.AddFlag("recompute_peak_threshold", &recompute_peak_threshold);
// String flags
reg.AddFlag("repository_path", &repository_path);
@ -261,6 +263,8 @@ std::string GraphKernelFlags::DumpAllFlags() const {
json["fusion_ops_level"] = fusion_ops_level;
json["parallel_ops_level"] = parallel_ops_level;
json["online_tuning"] = online_tuning;
json["recompute_increment_threshold"] = recompute_increment_threshold;
json["recompute_peak_threshold"] = recompute_peak_threshold;
json["repository_path"] = repository_path;

View File

@ -127,6 +127,16 @@ class GraphKernelFlags {
*/
unsigned int online_tuning{0};
/**
 * Threshold for detection of recompute's memory increment case.
*/
int64_t recompute_increment_threshold{0};
/**
 * Threshold for detection of recompute's memory peak case.
*/
int64_t recompute_peak_threshold{0};
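// Note: both thresholds are registered via RegisterFlags, so they are presumably configured through the
// graph_kernel_flags string (e.g. "--recompute_increment_threshold=xxx"); their unit follows the tensor
// memory size that the recompute pass compares them against.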
/**
* AKG's operator repository file path.
*/