forked from mindspore-Ecosystem/mindspore
Enhance GNN and support graph recompute.
This commit is contained in:
parent
d1d516e668
commit
9bcb3eb1ef
|
@ -861,60 +861,114 @@ class GraphSplitGpu(GraphSplitByPattern):
|
|||
if not dom.dom_op().prim in gather_prims:
|
||||
return None
|
||||
|
||||
def _count_target_prim(ops, target_list):
|
||||
count = 0
|
||||
for op in ops:
|
||||
if op.prim in target_list:
|
||||
count += 1
|
||||
return count
|
||||
def _reduce_exclude(op, axis_list):
|
||||
""" Whether this operator should be excluded.
|
||||
Excluding condition:
|
||||
1. Reduce the last axis.
|
||||
2. There are at least one same axis between reduce axes and axis_list.
|
||||
|
||||
def _bfs_visit(start_op, start_prims, end_prims, total_ops):
|
||||
Args:
|
||||
op (Operator): Target reduce operator.
|
||||
axis_list (list): List to check whether it is intersected by reduce axis.
|
||||
Returns:
|
||||
Boolean. Whether this operator should be excluded.
|
||||
"""
|
||||
axis = op.attrs["reduce_axis"]
|
||||
if isinstance(axis, int):
|
||||
axis = [axis]
|
||||
in_shape_len = len(op.inputs[0].shape)
|
||||
for i, dim in enumerate(axis):
|
||||
axis[i] = in_shape_len + dim if dim < 0 else dim
|
||||
fix_axis = []
|
||||
for ax in axis:
|
||||
if op.inputs[0].shape[ax] == 1:
|
||||
continue
|
||||
fix_axis.append(ax)
|
||||
return (in_shape_len - 1 in fix_axis) or bool(set(fix_axis) & set(axis_list))
|
||||
|
||||
def _bfs_visit(start_op, start_prims, total_ops, end_ops, gather_axis):
|
||||
consisten_shape = start_op.output.shape
|
||||
visited = []
|
||||
op_queue = [start_op]
|
||||
|
||||
def _early_stop(cur_op):
|
||||
if cur_op in end_ops:
|
||||
# If reduce last axis or reduce the gather axis, stop early for not fusion.
|
||||
if cur_op.prim == "ReduceSum" and _reduce_exclude(cur_op, gather_axis):
|
||||
return True
|
||||
else:
|
||||
if (cur_op.prim in start_prims and cur_op != start_op) or \
|
||||
consisten_shape != cur_op.output.shape:
|
||||
return True
|
||||
return False
|
||||
|
||||
while op_queue:
|
||||
tmp_queue = []
|
||||
for op in op_queue:
|
||||
if op in visited:
|
||||
if op in visited or not op in total_ops:
|
||||
continue
|
||||
if op.prim in end_prims or not op in total_ops:
|
||||
continue
|
||||
if (op.prim in start_prims and op != start_op) or consisten_shape != op.output.shape:
|
||||
if _early_stop(op):
|
||||
return False
|
||||
if op in end_ops:
|
||||
continue
|
||||
for to_op in op.output.to_ops:
|
||||
tmp_queue.append(to_op)
|
||||
visited.append(op)
|
||||
op_queue = tmp_queue
|
||||
|
||||
return True
|
||||
|
||||
def _shape_consistent(start_prims, end_prims, source, target):
|
||||
"""Check whether it is always shape consistent from source nodes to target nodes."""
|
||||
total_ops = source.ops + target.ops
|
||||
|
||||
start_ops = []
|
||||
for op in source.ops:
|
||||
if op.prim in start_prims:
|
||||
start_ops.append(op)
|
||||
end_ops = []
|
||||
for op in total_ops:
|
||||
if op.prim in end_prims and not any([to_op in total_ops for to_op in op.output.to_ops]):
|
||||
end_ops.append(op)
|
||||
|
||||
total_ops = source.ops + target.ops
|
||||
for start_op in start_ops:
|
||||
is_consistent = _bfs_visit(start_op, start_prims, end_prims, total_ops)
|
||||
gather_axis = start_op.attrs.get("axis", None)
|
||||
if gather_axis is None:
|
||||
# For GatherNd
|
||||
gather_axis = list(range(len(start_op.inputs[1].shape)))
|
||||
elif isinstance(gather_axis, int):
|
||||
gather_axis = [gather_axis]
|
||||
|
||||
is_consistent = _bfs_visit(start_op, start_prims, total_ops, end_ops, gather_axis)
|
||||
if not is_consistent:
|
||||
return False
|
||||
return True
|
||||
|
||||
appected_areas = {"TensorScatterAdd", "UnsortedSegmentSum"}
|
||||
appected_areas = {"TensorScatterAdd", "UnsortedSegmentSum", "ReduceSum"}
|
||||
for a, _ in dom.out_relations.items():
|
||||
if _shape_consistent(gather_prims, appected_areas, dom, a) and \
|
||||
_count_target_prim(a.ops + dom.ops, appected_areas) < 2 and dom.check_acyclic(a):
|
||||
if _shape_consistent(gather_prims, appected_areas, dom, a) and dom.check_acyclic(a):
|
||||
return [a], False
|
||||
return None
|
||||
|
||||
def _broadcast_opaque(dom):
|
||||
"""Fuse rule for TensorScatterAdd and UnsortedSegmentSum."""
|
||||
def _same_input(op1, op2):
|
||||
return bool(set(op1.inputs.copy()) & set(op2.inputs.copy()))
|
||||
|
||||
if len(dom.ops) != 1:
|
||||
return None
|
||||
|
||||
# Only fuse the first input for `TensorScatterAdd`` and the first and second input for `UnsortedSegmentSum`.
|
||||
fuse_arg = {"TensorScatterAdd": slice(1, None), "UnsortedSegmentSum": slice(0, 2)}
|
||||
arg_idx = fuse_arg.get(dom.dom_op().prim, -1)
|
||||
if arg_idx == -1 or len(dom.ops) != 1:
|
||||
if arg_idx == -1:
|
||||
return None
|
||||
fuse_tensor = dom.dom_op().inputs[arg_idx]
|
||||
|
||||
for a, _ in dom.in_relations.items():
|
||||
# Rule 1: Same type with at lease one same input.
|
||||
if a.dom_op().prim == dom.dom_op().prim and _same_input(dom.dom_op(), a.dom_op()):
|
||||
return [a], True
|
||||
# Rule 2: Fuse op(reshape/elementwise/broadcast) in specified position inputs.
|
||||
if a.pattern <= PrimLib.BROADCAST and dom.check_acyclic(a) and \
|
||||
any([op.output in fuse_tensor for op in a.ops]):
|
||||
return [a], True
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include "backend/optimizer/graph_kernel/add_atomic_clean.h"
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <set>
|
||||
|
@ -37,6 +38,7 @@
|
|||
|
||||
namespace mindspore::graphkernel {
|
||||
namespace {
|
||||
auto constexpr NUMBER_COND_FOR_FILTER_INPLACE = 2;
|
||||
std::set<int64_t> GetUniqReduceAxes(const AnfNodePtr &node, bool is_ascend = false) {
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimReduceSum)) {
|
||||
MS_LOG(EXCEPTION) << "Only process for reduce sum!";
|
||||
|
@ -87,10 +89,13 @@ bool HaveReduceInPredecessors(const AnfNodePtr &node) {
|
|||
return false;
|
||||
}
|
||||
|
||||
inline int64_t CalNewIndex(int64_t old_index, int64_t reduce_index) {
|
||||
return old_index - (old_index > reduce_index ? 1 : 0);
|
||||
inline int64_t CalNewIndex(int64_t old_index, const std::set<int64_t> &reduce_indexs) {
|
||||
int64_t count =
|
||||
std::count_if(reduce_indexs.begin(), reduce_indexs.end(), [old_index](int i) { return i < old_index; });
|
||||
return old_index - count;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
std::shared_ptr<AtomicAddChecker> AtomicAddChecker::Init() {
|
||||
auto processor = kernel::GetProcessorFromContext();
|
||||
if (processor == kernel::Processor::AICORE) {
|
||||
|
@ -102,6 +107,7 @@ std::shared_ptr<AtomicAddChecker> AtomicAddChecker::Init() {
|
|||
}
|
||||
|
||||
bool AtomicAddChecker::FindCandidate(const AnfNodePtr &anf_node) {
|
||||
atomic_add_infos_.clear();
|
||||
auto node = anf_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
||||
|
@ -111,32 +117,38 @@ bool AtomicAddChecker::FindCandidate(const AnfNodePtr &anf_node) {
|
|||
sub_graph->set_manager(mng_sub);
|
||||
}
|
||||
|
||||
// Rule: Only one ReduceSum inside sub-graph.
|
||||
auto CheckSuitableTarget = [&mng_sub](const AtomicAddInfo &atomic_add_info) {
|
||||
// Target type should not fuse any other ops in out direction, which means it should be in output list.
|
||||
return mng_sub->node_users()[atomic_add_info.atomic_add_node].size() <= 1;
|
||||
};
|
||||
|
||||
auto real_return_node = sub_graph->get_return()->input(kFirstDataInputIndex);
|
||||
AtomicAddInfo atomic_add_info;
|
||||
if (IsPrimitiveCNode(real_return_node, prim::kPrimMakeTuple)) {
|
||||
size_t target_cnt = 0;
|
||||
const auto &inputs = real_return_node->cast<CNodePtr>()->inputs();
|
||||
for (size_t i = 1; i < inputs.size(); ++i) {
|
||||
if (IsPrimitiveCNode(inputs[i], target_type_)) {
|
||||
atomic_add_info_.atomic_add_node = inputs[i]->cast<CNodePtr>();
|
||||
atomic_add_info_.reduce_real_output_index = i - 1;
|
||||
target_cnt++;
|
||||
atomic_add_info.atomic_add_node = inputs[i]->cast<CNodePtr>();
|
||||
atomic_add_info.reduce_real_output_index = i - 1;
|
||||
atomic_add_info.real_output_num = inputs.size() - 1;
|
||||
// Target type should not fuse any other ops in out direction, which means it should be in output list.
|
||||
if (!CheckSuitableTarget(atomic_add_info)) {
|
||||
continue;
|
||||
}
|
||||
atomic_add_infos_.push_back(atomic_add_info);
|
||||
}
|
||||
}
|
||||
|
||||
if (target_cnt != 1) {
|
||||
return false;
|
||||
}
|
||||
atomic_add_info_.real_output_num = inputs.size() - 1;
|
||||
} else if (IsPrimitiveCNode(real_return_node, target_type_)) {
|
||||
atomic_add_info_.atomic_add_node = real_return_node->cast<CNodePtr>();
|
||||
atomic_add_info_.real_output_num = 1;
|
||||
atomic_add_info.atomic_add_node = real_return_node->cast<CNodePtr>();
|
||||
atomic_add_info.real_output_num = 1;
|
||||
if (CheckSuitableTarget(atomic_add_info)) {
|
||||
atomic_add_infos_.push_back(atomic_add_info);
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Rule: ReduceSum should not fuse any other ops in out direction, which means it should be in output list.
|
||||
return (mng_sub->node_users()[atomic_add_info_.atomic_add_node].size() <= 1);
|
||||
return !atomic_add_infos_.empty();
|
||||
}
|
||||
|
||||
bool AtomicAddChecker::CanActivateAtomicAdd(const AnfNodePtr &anf_node) {
|
||||
|
@ -150,17 +162,17 @@ bool AtomicAddChecker::CanActivateAtomicAdd(const AnfNodePtr &anf_node) {
|
|||
// 3. No other ReduceSum as output ReduceSum's predecessors (reduce compile limitation).
|
||||
|
||||
// Rule 1.
|
||||
if (!FindCandidate(anf_node)) {
|
||||
if (!FindCandidate(anf_node) || atomic_add_infos_.size() > 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Rule 2.
|
||||
if (!SuitableForAtomicAdd(atomic_add_info_.atomic_add_node)) {
|
||||
if (!SuitableForAtomicAdd(atomic_add_infos_[0].atomic_add_node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Rule 3.
|
||||
return !HaveReduceInPredecessors(atomic_add_info_.atomic_add_node);
|
||||
return !HaveReduceInPredecessors(atomic_add_infos_[0].atomic_add_node);
|
||||
}
|
||||
|
||||
bool AtomicAddChecker::Check(const AnfNodePtr &node) {
|
||||
|
@ -228,8 +240,8 @@ bool AtomicAddCheckerAscend::SuitableForAtomicAdd(const AnfNodePtr &node) {
|
|||
return false;
|
||||
}
|
||||
|
||||
void AtomicCleanInsertter::CorrectKernelBuildInfo(const AnfNodePtr &composite_node, const AnfNodePtr &new_input,
|
||||
bool bypass) {
|
||||
void AtomicCleanInsertter::CorrectKernelBuildInfo(
|
||||
const AnfNodePtr &composite_node, const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &clean_infos) {
|
||||
// Change kernel build info.
|
||||
auto kernel_info = dynamic_cast<device::KernelInfo *>(composite_node->kernel_info());
|
||||
MS_EXCEPTION_IF_NULL(kernel_info);
|
||||
|
@ -244,17 +256,31 @@ void AtomicCleanInsertter::CorrectKernelBuildInfo(const AnfNodePtr &composite_no
|
|||
std::vector<TypeId> &new_inputs_type = origin_inputs_type;
|
||||
std::vector<std::string> new_outputs_format;
|
||||
std::vector<TypeId> new_outputs_type;
|
||||
for (size_t i = 0; i < origin_outputs_format.size(); ++i) {
|
||||
if (bypass && real_output_num_ > 1 && i == reduce_real_output_index_) {
|
||||
continue;
|
||||
}
|
||||
new_outputs_format.push_back(origin_outputs_format[i]);
|
||||
new_outputs_type.push_back(origin_outputs_type[i]);
|
||||
|
||||
std::set<size_t> reduce_real_indices;
|
||||
for (auto &info : clean_infos) {
|
||||
reduce_real_indices.insert(info.first.reduce_real_output_index);
|
||||
}
|
||||
|
||||
auto kernel_with_index = AnfAlgo::VisitKernel(new_input, 0);
|
||||
new_inputs_format.push_back(AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second));
|
||||
new_inputs_type.push_back(AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second));
|
||||
if (clean_infos[0].first.real_output_num == reduce_real_indices.size()) {
|
||||
new_outputs_format.push_back(origin_outputs_format[0]);
|
||||
new_outputs_type.push_back(origin_outputs_type[0]);
|
||||
} else {
|
||||
for (size_t i = 0; i < origin_outputs_format.size(); ++i) {
|
||||
if (reduce_real_indices.count(i) > 0) {
|
||||
continue;
|
||||
}
|
||||
new_outputs_format.push_back(origin_outputs_format[i]);
|
||||
new_outputs_type.push_back(origin_outputs_type[i]);
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto &clean_info : clean_infos) {
|
||||
auto &new_input = clean_info.second;
|
||||
auto kernel_with_index = AnfAlgo::VisitKernel(new_input, 0);
|
||||
new_inputs_format.push_back(AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second));
|
||||
new_inputs_type.push_back(AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second));
|
||||
}
|
||||
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder new_info_builder;
|
||||
new_info_builder.SetInputsFormat(new_inputs_format);
|
||||
|
@ -268,41 +294,55 @@ void AtomicCleanInsertter::CorrectKernelBuildInfo(const AnfNodePtr &composite_no
|
|||
AnfAlgo::SetSelectKernelBuildInfo(new_selected_info, composite_node.get());
|
||||
}
|
||||
|
||||
void AtomicCleanInsertter::CreateInplaceAssignNodeAndCorrectReturn(const FuncGraphPtr &sub_graph,
|
||||
const AnfNodePtr &new_parameter) {
|
||||
// add inplaceassign
|
||||
AnfNodePtr out_node;
|
||||
void AtomicCleanInsertter::CreateInplaceAssignNodeAndCorrectReturn(
|
||||
const FuncGraphPtr &sub_graph, const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> ¶meters_infos) {
|
||||
// Add inplaceassign
|
||||
AnfNodePtr inplace_out_node;
|
||||
bool fake_out = false;
|
||||
|
||||
std::set<size_t> reduce_indices;
|
||||
for (auto &info : parameters_infos) {
|
||||
reduce_indices.insert(info.first.reduce_real_output_index + 1);
|
||||
}
|
||||
size_t replace_index = 0;
|
||||
auto retrun_node = sub_graph->get_return()->input(kFirstDataInputIndex);
|
||||
if (IsPrimitiveCNode(retrun_node, prim::kPrimMakeTuple)) {
|
||||
if (!IsPrimitiveCNode(retrun_node, prim::kPrimMakeTuple) ||
|
||||
retrun_node->cast<CNodePtr>()->inputs().size() == parameters_infos.size() + 1) {
|
||||
fake_out = true;
|
||||
inplace_out_node = parameters_infos[0].first.atomic_add_node;
|
||||
replace_index = parameters_infos[0].first.reduce_real_output_index + 1;
|
||||
} else {
|
||||
const auto &outs = retrun_node->cast<CNodePtr>()->inputs();
|
||||
for (size_t i = 1; i < outs.size(); ++i) {
|
||||
if (i != reduce_real_output_index_ + 1) {
|
||||
out_node = outs[i];
|
||||
if (reduce_indices.count(i) == 0) {
|
||||
inplace_out_node = outs[i];
|
||||
replace_index = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
out_node = atomic_add_node_; // Use result data itself, and set attr "fake_out" true.
|
||||
fake_out = true;
|
||||
}
|
||||
|
||||
auto inplace_assign_node =
|
||||
CreateCNode({NewValueNode(prim::kPrimInplaceAssign), new_parameter, atomic_add_node_, out_node}, sub_graph,
|
||||
{.format = GetFormat(out_node), .shape = GetShape(out_node), .type = GetType(out_node)});
|
||||
SetNodeAttrSafely("fake_output", MakeValue(fake_out), inplace_assign_node);
|
||||
for (const auto &[atomic_add_info, new_parameter] : parameters_infos) {
|
||||
auto inplace_assign_node = CreateCNode(
|
||||
{NewValueNode(prim::kPrimInplaceAssign), new_parameter, atomic_add_info.atomic_add_node, inplace_out_node},
|
||||
sub_graph,
|
||||
{.format = GetFormat(inplace_out_node), .shape = GetShape(inplace_out_node), .type = GetType(inplace_out_node)});
|
||||
SetNodeAttrSafely("fake_output", MakeValue(fake_out), inplace_assign_node);
|
||||
inplace_out_node = inplace_assign_node;
|
||||
}
|
||||
|
||||
CNodePtr new_out_node;
|
||||
if (real_output_num_ > 2) {
|
||||
// If the real output number is less than or equal to two, it's no need to filter the inplace one out:
|
||||
// 1. Two real outputs. After inplacing, only one left and `Inplace` out will be that output.
|
||||
// 2. One real output. After inplacing, there is no output left, use fake one.
|
||||
if (parameters_infos[0].first.real_output_num > NUMBER_COND_FOR_FILTER_INPLACE) {
|
||||
std::vector<AnfNodePtr> output_args = {NewValueNode(prim::kPrimMakeTuple)};
|
||||
const auto &outs = retrun_node->cast<CNodePtr>()->inputs();
|
||||
for (size_t i = 1; i < outs.size(); ++i) {
|
||||
if (i == reduce_real_output_index_ + 1) {
|
||||
if (reduce_indices.count(i) > 0) {
|
||||
continue;
|
||||
} else if (i == replace_index) {
|
||||
output_args.push_back(inplace_assign_node);
|
||||
output_args.push_back(inplace_out_node);
|
||||
} else {
|
||||
output_args.push_back(outs[i]);
|
||||
}
|
||||
|
@ -310,31 +350,46 @@ void AtomicCleanInsertter::CreateInplaceAssignNodeAndCorrectReturn(const FuncGra
|
|||
// Set output for AnfGraph
|
||||
new_out_node = sub_graph->NewCNode(output_args);
|
||||
} else {
|
||||
new_out_node = inplace_assign_node;
|
||||
CNodePtr out_cnode = inplace_out_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(out_cnode);
|
||||
new_out_node = out_cnode;
|
||||
}
|
||||
sub_graph->set_output(new_out_node);
|
||||
}
|
||||
|
||||
void AtomicCleanInsertter::CorrectAbstract(const AnfNodePtr &composite_node) const {
|
||||
// If there is only one output(ReduceSum), it should be a fake output with the same abstract with origin output.
|
||||
if (real_output_num_ <= 1) {
|
||||
void AtomicCleanInsertter::CorrectAbstract(
|
||||
const AnfNodePtr &composite_node, const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &process_infos) const {
|
||||
// If there is only one output, it should be a fake output with the same abstract with origin output.
|
||||
if (process_infos[0].first.real_output_num <= 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::set<size_t> reduce_real_indices;
|
||||
for (auto &info : process_infos) {
|
||||
reduce_real_indices.insert(info.first.reduce_real_output_index);
|
||||
}
|
||||
|
||||
// Change abstract.
|
||||
auto origin_out_spec = composite_node->abstract()->cast<abstract::AbstractTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(origin_out_spec);
|
||||
const auto &origin_out_specs = origin_out_spec->elements();
|
||||
AbstractBasePtrList new_out_specs;
|
||||
for (size_t i = 0; i < origin_out_specs.size(); ++i) {
|
||||
if (i != reduce_real_output_index_) {
|
||||
if (reduce_real_indices.count(i) == 0) {
|
||||
new_out_specs.push_back(origin_out_specs[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// If empty, there will be a fake out, so use the first target reduce information.
|
||||
if (new_out_specs.empty()) {
|
||||
new_out_specs.push_back(origin_out_specs[process_infos[0].first.reduce_real_output_index]);
|
||||
}
|
||||
composite_node->set_abstract(std::make_shared<abstract::AbstractTuple>(new_out_specs));
|
||||
}
|
||||
|
||||
void AtomicCleanInsertter::ProcessOriginCNode(const AnfNodePtr &composite_node, const AnfNodePtr &new_input) {
|
||||
void AtomicCleanInsertter::ProcessOriginCNode(
|
||||
const AnfNodePtr &composite_node,
|
||||
const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &info_and_broadcast_to_nodes) {
|
||||
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(composite_node);
|
||||
auto mng_sub = sub_graph->manager();
|
||||
if (mng_sub == nullptr) {
|
||||
|
@ -342,23 +397,27 @@ void AtomicCleanInsertter::ProcessOriginCNode(const AnfNodePtr &composite_node,
|
|||
sub_graph->set_manager(mng_sub);
|
||||
}
|
||||
|
||||
// Add atomic attribute to reducesum node.
|
||||
SetNodeAttrSafely("enable_atomic_add", MakeValue(true), atomic_add_node_);
|
||||
// Add input
|
||||
std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> parameters_infos;
|
||||
for (const auto &[atomic_add_info, new_input] : info_and_broadcast_to_nodes) {
|
||||
// Add atomic attribute to reducesum node.
|
||||
SetNodeAttrSafely("enable_atomic_add", MakeValue(true), atomic_add_info.atomic_add_node);
|
||||
// add parameter
|
||||
auto parameter = sub_graph->add_parameter();
|
||||
parameter->set_abstract(new_input->abstract());
|
||||
parameter->set_kernel_info(new_input->kernel_info_ptr());
|
||||
parameters_infos.emplace_back(atomic_add_info, parameter);
|
||||
}
|
||||
|
||||
// add input
|
||||
auto inputs = composite_node->cast<CNodePtr>()->inputs();
|
||||
inputs.push_back(new_input);
|
||||
std::transform(info_and_broadcast_to_nodes.cbegin(), info_and_broadcast_to_nodes.cend(), std::back_inserter(inputs),
|
||||
[](const std::pair<AtomicAddInfo, AnfNodePtr> &pair_item) { return pair_item.second; });
|
||||
composite_node->cast<CNodePtr>()->set_inputs(inputs);
|
||||
|
||||
// add parameter
|
||||
auto parameter = sub_graph->add_parameter();
|
||||
parameter->set_abstract(new_input->abstract());
|
||||
parameter->set_kernel_info(new_input->kernel_info_ptr());
|
||||
CreateInplaceAssignNodeAndCorrectReturn(sub_graph, parameters_infos);
|
||||
|
||||
CreateInplaceAssignNodeAndCorrectReturn(sub_graph, parameter);
|
||||
|
||||
CorrectAbstract(composite_node);
|
||||
CorrectKernelBuildInfo(composite_node, new_input);
|
||||
CorrectAbstract(composite_node, info_and_broadcast_to_nodes);
|
||||
CorrectKernelBuildInfo(composite_node, info_and_broadcast_to_nodes);
|
||||
|
||||
auto old_graph_name = GetValue<std::string>(sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
|
||||
auto new_graph_name = GkUtils::ExtractGraphKernelName(TopoSort(sub_graph->get_return()), "", "atomic_add");
|
||||
|
@ -366,19 +425,6 @@ void AtomicCleanInsertter::ProcessOriginCNode(const AnfNodePtr &composite_node,
|
|||
MS_LOG(INFO) << "Convert " << old_graph_name << " to atomic add graph " << new_graph_name;
|
||||
}
|
||||
|
||||
void AtomicCleanInsertter::AddDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &clean_node,
|
||||
const AnfNodePtr &composite_node, const AnfNodePtr &user_node, int index) const {
|
||||
// Create depend node to hold execution order.
|
||||
AnfNodePtrList d_inputs = {NewValueNode(prim::kPrimDepend), clean_node, composite_node};
|
||||
auto depend_cnode = main_graph->NewCNode(d_inputs);
|
||||
depend_cnode->set_abstract(clean_node->abstract());
|
||||
main_graph->AddNode(depend_cnode);
|
||||
|
||||
auto user_cnode = user_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(user_cnode);
|
||||
user_cnode->set_input(IntToSize(index), depend_cnode);
|
||||
}
|
||||
|
||||
CNodePtr AtomicCleanInsertter::InsertUpdateState(const KernelGraphPtr &main_graph,
|
||||
const CNodePtr &composite_node) const {
|
||||
// Insert update_state_node, need mount a monad node.
|
||||
|
@ -391,7 +437,8 @@ CNodePtr AtomicCleanInsertter::InsertUpdateState(const KernelGraphPtr &main_grap
|
|||
return update_state_cnode;
|
||||
}
|
||||
|
||||
CNodePtr AtomicCleanInsertter::CreateAtomicCleanCompositeNode(const KernelGraphPtr &main_graph, TypeId dst_type) {
|
||||
CNodePtr AtomicCleanInsertter::CreateAtomicCleanCompositeNode(const AtomicAddInfo &atomic_add_info,
|
||||
const KernelGraphPtr &main_graph, TypeId dst_type) {
|
||||
std::set<TypeId> data_support = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64};
|
||||
|
||||
if (!std::any_of(data_support.cbegin(), data_support.cend(), [&dst_type](TypeId type) { return dst_type == type; })) {
|
||||
|
@ -399,7 +446,7 @@ CNodePtr AtomicCleanInsertter::CreateAtomicCleanCompositeNode(const KernelGraphP
|
|||
}
|
||||
|
||||
// Create zero value which will be broadcast to target shape.
|
||||
auto format = GetFormat(atomic_add_node_);
|
||||
auto format = GetFormat(atomic_add_info.atomic_add_node);
|
||||
auto dtype = (dst_type == kNumberTypeFloat16) ? kNumberTypeFloat32 : dst_type;
|
||||
ValueNodePtr value_node;
|
||||
if (dtype == kNumberTypeFloat32) {
|
||||
|
@ -418,18 +465,19 @@ CNodePtr AtomicCleanInsertter::CreateAtomicCleanCompositeNode(const KernelGraphP
|
|||
AnfNodePtrList cast_inputs = {NewValueNode(prim::kPrimCast), value_node};
|
||||
auto cast_node_inner =
|
||||
CreateCNode(cast_inputs, new_sub_graph, {.format = format, .shape = {1}, .type = TypeIdToType(dst_type)});
|
||||
SetNodeAttrSafely(kAttrDstType, kFloat32, cast_node_inner);
|
||||
SetNodeAttrSafely("dst_type", MakeValue("float32"), cast_node_inner);
|
||||
broadcast_input_node = cast_node_inner;
|
||||
} else {
|
||||
broadcast_input_node = value_node;
|
||||
}
|
||||
|
||||
// Create broadcast basic op.
|
||||
auto dst_shape_vec = GetShape(atomic_add_node_);
|
||||
auto dst_shape_vec = GetShape(atomic_add_info.atomic_add_node);
|
||||
AnfNodePtrList atomic_clean_inputs = {NewValueNode(prim::kPrimBroadcastTo), broadcast_input_node};
|
||||
auto broadcast_to_node_inner = CreateCNode(
|
||||
atomic_clean_inputs, new_sub_graph, {.format = format, .shape = dst_shape_vec, .type = GetType(atomic_add_node_)});
|
||||
SetNodeAttrSafely("shape", MakeValue(GetDeviceShape(atomic_add_node_)), broadcast_to_node_inner);
|
||||
auto broadcast_to_node_inner =
|
||||
CreateCNode(atomic_clean_inputs, new_sub_graph,
|
||||
{.format = format, .shape = dst_shape_vec, .type = GetType(atomic_add_info.atomic_add_node)});
|
||||
SetNodeAttrSafely("shape", MakeValue(GetDeviceShape(atomic_add_info.atomic_add_node)), broadcast_to_node_inner);
|
||||
|
||||
// Makeup sub-graph.
|
||||
new_sub_graph->set_output(broadcast_to_node_inner);
|
||||
|
@ -443,17 +491,26 @@ CNodePtr AtomicCleanInsertter::CreateAtomicCleanCompositeNode(const KernelGraphP
|
|||
return broadcast_to_composite_node;
|
||||
}
|
||||
|
||||
std::vector<std::pair<AnfNodePtr, int> > AtomicCleanInsertter::FindOriginCNodeUsers(const KernelGraphPtr &main_graph,
|
||||
const AnfNodePtr &composite_node,
|
||||
const FuncGraphManagerPtr &mng,
|
||||
bool correct_index) const {
|
||||
std::vector<std::pair<AnfNodePtr, int> > reduce_user_nodes;
|
||||
if (real_output_num_ <= 1) {
|
||||
std::vector<std::tuple<AnfNodePtr, int, AnfNodePtr>> AtomicCleanInsertter::FindOriginCNodeUsers(
|
||||
const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node,
|
||||
const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &info_and_broadcast_to_nodes, const FuncGraphManagerPtr &mng,
|
||||
bool correct_index) const {
|
||||
std::vector<std::tuple<AnfNodePtr, int, AnfNodePtr>> reduce_user_nodes;
|
||||
|
||||
std::set<int64_t> real_indices;
|
||||
std::map<size_t, AnfNodePtr> real_indices_and_clean_node;
|
||||
for (auto &[info, clean] : info_and_broadcast_to_nodes) {
|
||||
real_indices_and_clean_node.insert({info.reduce_real_output_index, clean});
|
||||
real_indices.insert(SizeToLong(info.reduce_real_output_index));
|
||||
}
|
||||
|
||||
if (info_and_broadcast_to_nodes[0].first.real_output_num <= 1) {
|
||||
auto users = mng->node_users()[composite_node];
|
||||
(void)std::transform(users.cbegin(), users.cend(), std::back_inserter(reduce_user_nodes),
|
||||
[](const std::pair<AnfNodePtr, int> &pair) { return pair; });
|
||||
for (const auto &[user, index] : users) {
|
||||
reduce_user_nodes.emplace_back(user, index, info_and_broadcast_to_nodes[0].second);
|
||||
}
|
||||
} else {
|
||||
std::vector<std::pair<AnfNodePtr, int> > getitem_user_nodes;
|
||||
std::vector<std::tuple<AnfNodePtr, AnfNodePtr>> getitem_user_nodes;
|
||||
auto users = mng->node_users()[composite_node];
|
||||
for (const auto &node_index : users) {
|
||||
const auto &user_node = node_index.first;
|
||||
|
@ -466,47 +523,43 @@ std::vector<std::pair<AnfNodePtr, int> > AtomicCleanInsertter::FindOriginCNodeUs
|
|||
auto value_node = value_input->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
auto item_idx = GetValue<int64_t>(value_node->value());
|
||||
if (item_idx == static_cast<int64_t>(reduce_real_output_index_)) {
|
||||
getitem_user_nodes.push_back(node_index);
|
||||
auto iter = real_indices_and_clean_node.find(IntToSize(item_idx));
|
||||
if (iter != real_indices_and_clean_node.end()) {
|
||||
getitem_user_nodes.push_back({node_index.first, iter->second});
|
||||
} else if (correct_index) {
|
||||
if (real_output_num_ > 2) {
|
||||
// Recorrect other getitem index.
|
||||
int64_t new_item_idx = CalNewIndex(item_idx, SizeToLong(reduce_real_output_index_));
|
||||
AnfNodePtrList new_inputs = {NewValueNode(prim::kPrimTupleGetItem), composite_node,
|
||||
NewValueNode(new_item_idx)};
|
||||
auto new_out = main_graph->NewCNode(new_inputs);
|
||||
new_out->set_abstract(get_item_cnode->abstract());
|
||||
for (const auto &[user, index] : mng->node_users()[get_item_cnode]) {
|
||||
auto user_cnode = user->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(user_cnode);
|
||||
user_cnode->set_input(IntToSize(index), new_out);
|
||||
}
|
||||
} else {
|
||||
for (const auto &[user, index] : mng->node_users()[node_index.first]) {
|
||||
auto user_cnode = user->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(user_cnode);
|
||||
user_cnode->set_input(IntToSize(index), composite_node);
|
||||
}
|
||||
// Recorrect other getitem index.
|
||||
int64_t new_item_idx = CalNewIndex(item_idx, real_indices);
|
||||
AnfNodePtrList new_inputs = {NewValueNode(prim::kPrimTupleGetItem), composite_node, NewValueNode(new_item_idx)};
|
||||
auto new_out = main_graph->NewCNode(new_inputs);
|
||||
new_out->set_abstract(get_item_cnode->abstract());
|
||||
for (const auto &[user, index] : mng->node_users()[get_item_cnode]) {
|
||||
auto user_cnode = user->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(user_cnode);
|
||||
user_cnode->set_input(IntToSize(index), new_out);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (auto &pair : getitem_user_nodes) {
|
||||
for (auto &[getitem_node, broadcast_to_node] : getitem_user_nodes) {
|
||||
// Directory to find real user.
|
||||
auto real_users = mng->node_users()[pair.first];
|
||||
(void)reduce_user_nodes.insert(reduce_user_nodes.end(), real_users.begin(), real_users.end());
|
||||
auto real_users = mng->node_users()[getitem_node];
|
||||
std::transform(real_users.cbegin(), real_users.cend(), std::back_inserter(reduce_user_nodes),
|
||||
[&broadcast_to_node](const std::pair<AnfNodePtr, int> &pair) {
|
||||
return std::make_tuple(pair.first, pair.second, broadcast_to_node);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return reduce_user_nodes;
|
||||
}
|
||||
|
||||
void AtomicCleanInsertter::ProcessOriginCNodeUser(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node,
|
||||
const AnfNodePtr &broadcast_to_node,
|
||||
const AnfNodePtr &update_state_node, const FuncGraphManagerPtr &mng) {
|
||||
void AtomicCleanInsertter::ProcessOriginCNodeUser(
|
||||
const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node,
|
||||
const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &info_and_broadcast_to_nodes,
|
||||
const AnfNodePtr &update_state_node, const FuncGraphManagerPtr &mng) {
|
||||
// 1. find users, change getitem index if needed.
|
||||
std::vector<std::pair<AnfNodePtr, int> > reduce_user_nodes =
|
||||
FindOriginCNodeUsers(main_graph, composite_node, mng, true);
|
||||
for (const auto &[user_node, index] : reduce_user_nodes) {
|
||||
std::vector<std::tuple<AnfNodePtr, int, AnfNodePtr>> reduce_user_nodes =
|
||||
FindOriginCNodeUsers(main_graph, composite_node, info_and_broadcast_to_nodes, mng, true);
|
||||
for (const auto &[user_node, index, broadcast_to_node] : reduce_user_nodes) {
|
||||
// 2. Make sure modified composite node running first, So firstly, create load_node, then add edge to connect
|
||||
// update_state_node, broadcat_node and load_node to keep order.
|
||||
AnfNodePtrList load_inputs = {NewValueNode(prim::kPrimLoad), broadcast_to_node, update_state_node};
|
||||
|
@ -519,50 +572,38 @@ void AtomicCleanInsertter::ProcessOriginCNodeUser(const KernelGraphPtr &main_gra
|
|||
}
|
||||
}
|
||||
|
||||
void AtomicCleanInsertter::UpdateAtomicAddInfo(const AtomicAddInfo &atomic_add_info) {
|
||||
atomic_add_node_ = atomic_add_info.atomic_add_node;
|
||||
reduce_real_output_index_ = atomic_add_info.reduce_real_output_index;
|
||||
real_output_num_ = atomic_add_info.real_output_num;
|
||||
}
|
||||
|
||||
void AtomicCleanInsertter::InsertAtomicClean(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node,
|
||||
const std::vector<AtomicAddInfo> &atomic_add_infos,
|
||||
const FuncGraphManagerPtr &mng) {
|
||||
auto origin_composite_node = anf_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(origin_composite_node);
|
||||
|
||||
// Create broadcst node.
|
||||
auto out_type = GetType(atomic_add_node_)->cast<TensorTypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(out_type);
|
||||
auto broadcast_to_node = CreateAtomicCleanCompositeNode(main_graph, out_type->element()->type_id());
|
||||
std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> info_and_broadcast_to_nodes;
|
||||
for (auto atomic_add_info : atomic_add_infos) {
|
||||
auto out_type = GetType(atomic_add_info.atomic_add_node)->cast<TensorTypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(out_type);
|
||||
auto broadcast_to_node =
|
||||
CreateAtomicCleanCompositeNode(atomic_add_info, main_graph, out_type->element()->type_id());
|
||||
info_and_broadcast_to_nodes.emplace_back(atomic_add_info, broadcast_to_node);
|
||||
}
|
||||
|
||||
// Insert extra input(broadcast node output) to composite node, and make Reducesum inplaceassign to it.
|
||||
// Note: if it's single output, this will increase total memory because of a fake out.
|
||||
ProcessOriginCNode(origin_composite_node, broadcast_to_node);
|
||||
ProcessOriginCNode(origin_composite_node, info_and_broadcast_to_nodes);
|
||||
|
||||
// Insert update_state_node to keep execution order.
|
||||
auto update_state_node = InsertUpdateState(main_graph, origin_composite_node);
|
||||
|
||||
// Replace origin ReduceSum's user with atomic clean output
|
||||
ProcessOriginCNodeUser(main_graph, origin_composite_node, broadcast_to_node, update_state_node, mng);
|
||||
MS_LOG(INFO) << "Target node: " << origin_composite_node->fullname_with_scope()
|
||||
<< ", clean node: " << broadcast_to_node->fullname_with_scope();
|
||||
}
|
||||
ProcessOriginCNodeUser(main_graph, origin_composite_node, info_and_broadcast_to_nodes, update_state_node, mng);
|
||||
std::stringstream ss;
|
||||
ss << "Target node: " << origin_composite_node->fullname_with_scope() << ", clean nodes: ";
|
||||
for (auto iter : info_and_broadcast_to_nodes) {
|
||||
ss << iter.second->fullname_with_scope() << ", ";
|
||||
}
|
||||
|
||||
bool AtomicCleanInsertter::IsExistStructuralObstacle(const KernelGraphPtr &main_graph, const AnfNodePtr &node,
|
||||
const FuncGraphManagerPtr &mng) {
|
||||
auto reduce_users = FindOriginCNodeUsers(main_graph, node, mng, false);
|
||||
// If reduce user is MakeTuple and not last node, there is no cheap method to set right running order between reduce
|
||||
// node and user node. If reduce is Depend node, the origin node may be wrong!
|
||||
return std::all_of(
|
||||
reduce_users.cbegin(), reduce_users.cend(), [&main_graph](const std::pair<AnfNodePtr, int> &user_info) -> bool {
|
||||
auto &user = user_info.first;
|
||||
if ((IsPrimitiveCNode(user, prim::kPrimMakeTuple) || IsPrimitiveCNode(user, prim::kPrimDepend)) &&
|
||||
!(IsPrimitiveCNode(user, prim::kPrimReturn) || user == main_graph->output())) {
|
||||
return false;
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
});
|
||||
MS_LOG(INFO) << ss.str();
|
||||
}
|
||||
|
||||
bool AtomicCleanInsertter::Run(const FuncGraphPtr &func_graph) {
|
||||
|
@ -582,13 +623,12 @@ bool AtomicCleanInsertter::Run(const FuncGraphPtr &func_graph) {
|
|||
|
||||
auto topo_nodes = TopoSort(kernel_graph->get_return());
|
||||
for (const auto &node : topo_nodes) {
|
||||
if (!atomic_add_checker->Check(node) || !IsExistStructuralObstacle(kernel_graph, node, mng)) {
|
||||
if (!atomic_add_checker->Check(node)) {
|
||||
continue;
|
||||
}
|
||||
auto atomic_add_infos = atomic_add_checker->GetAtomicAddInfo();
|
||||
InsertAtomicClean(kernel_graph, node, atomic_add_infos, mng);
|
||||
changed = true;
|
||||
auto atomic_add_info = atomic_add_checker->GetAtomicAddInfo();
|
||||
UpdateAtomicAddInfo(atomic_add_info);
|
||||
InsertAtomicClean(kernel_graph, node, mng);
|
||||
}
|
||||
|
||||
if (changed) {
|
||||
|
|
|
@ -39,13 +39,13 @@ class AtomicAddChecker {
|
|||
static std::shared_ptr<AtomicAddChecker> Init();
|
||||
|
||||
bool Check(const AnfNodePtr &node);
|
||||
AtomicAddInfo GetAtomicAddInfo() { return atomic_add_info_; }
|
||||
std::vector<AtomicAddInfo> GetAtomicAddInfo() { return atomic_add_infos_; }
|
||||
|
||||
protected:
|
||||
virtual bool SuitableForAtomicAdd(const AnfNodePtr &node) { return false; }
|
||||
virtual bool FindCandidate(const AnfNodePtr &anf_node);
|
||||
virtual bool CanActivateAtomicAdd(const AnfNodePtr &anf_node);
|
||||
AtomicAddInfo atomic_add_info_;
|
||||
std::vector<AtomicAddInfo> atomic_add_infos_;
|
||||
PrimitivePtr target_type_{prim::kPrimReduceSum};
|
||||
};
|
||||
|
||||
|
@ -74,31 +74,28 @@ class AtomicCleanInsertter : public opt::Pass {
|
|||
bool Run(const FuncGraphPtr &func_graph) override;
|
||||
|
||||
protected:
|
||||
virtual void CorrectKernelBuildInfo(const AnfNodePtr &composite_node, const AnfNodePtr &new_input,
|
||||
bool bypass = true);
|
||||
virtual void ProcessOriginCNode(const AnfNodePtr &composite_node, const AnfNodePtr &new_input);
|
||||
virtual CNodePtr CreateAtomicCleanCompositeNode(const KernelGraphPtr &main_graph, TypeId dst_type);
|
||||
void AddDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &clean_node, const AnfNodePtr &composite_node,
|
||||
const AnfNodePtr &user_node, int index) const;
|
||||
void InsertAtomicClean(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node, const FuncGraphManagerPtr &mng);
|
||||
virtual void CorrectKernelBuildInfo(const AnfNodePtr &composite_node,
|
||||
const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &clean_infos);
|
||||
virtual void ProcessOriginCNode(const AnfNodePtr &composite_node,
|
||||
const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &info_and_broadcast_to_nodes);
|
||||
virtual CNodePtr CreateAtomicCleanCompositeNode(const AtomicAddInfo &atomic_add_info,
|
||||
const KernelGraphPtr &main_graph, TypeId dst_type);
|
||||
void InsertAtomicClean(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node,
|
||||
const std::vector<AtomicAddInfo> &atomic_add_infos, const FuncGraphManagerPtr &mng);
|
||||
CNodePtr InsertUpdateState(const KernelGraphPtr &main_graph, const CNodePtr &composite_node) const;
|
||||
void CorrectAbstract(const AnfNodePtr &composite_node) const;
|
||||
void CreateInplaceAssignNodeAndCorrectReturn(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter);
|
||||
void CorrectAbstract(const AnfNodePtr &composite_node,
|
||||
const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &process_infos) const;
|
||||
void CreateInplaceAssignNodeAndCorrectReturn(
|
||||
const FuncGraphPtr &sub_graph, const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> ¶meters_infos);
|
||||
void ProcessOriginCNodeUser(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node,
|
||||
const AnfNodePtr &broadcast_to_node, const AnfNodePtr &update_state_node,
|
||||
const FuncGraphManagerPtr &mng);
|
||||
void UpdateAtomicAddInfo(const AtomicAddInfo &info);
|
||||
CNodePtr atomic_add_node_{nullptr};
|
||||
size_t reduce_real_output_index_{0};
|
||||
size_t real_output_num_{0};
|
||||
const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &info_and_broadcast_to_nodes,
|
||||
const AnfNodePtr &update_state_node, const FuncGraphManagerPtr &mng);
|
||||
|
||||
private:
|
||||
std::vector<std::pair<AnfNodePtr, int>> FindOriginCNodeUsers(const KernelGraphPtr &main_graph,
|
||||
const AnfNodePtr &composite_node,
|
||||
const FuncGraphManagerPtr &mng,
|
||||
bool correct_index) const;
|
||||
bool IsExistStructuralObstacle(const KernelGraphPtr &main_graph, const AnfNodePtr &node,
|
||||
const FuncGraphManagerPtr &mng);
|
||||
std::vector<std::tuple<AnfNodePtr, int, AnfNodePtr>> FindOriginCNodeUsers(
|
||||
const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node,
|
||||
const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &info_and_broadcast_to_nodes,
|
||||
const FuncGraphManagerPtr &mng, bool correct_index) const;
|
||||
};
|
||||
using AtomicCleanInsertterPtr = std::shared_ptr<AtomicCleanInsertter>;
|
||||
} // namespace mindspore::graphkernel
|
||||
|
|
|
@ -30,20 +30,74 @@
|
|||
#include "backend/session/kernel_graph.h"
|
||||
|
||||
namespace mindspore::graphkernel {
|
||||
void StitchAtomicCleanInsertter::CorrectKernelBuildInfo(
|
||||
const AnfNodePtr &composite_node, const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &clean_infos) {
|
||||
// Change kernel build info.
|
||||
auto kernel_info = dynamic_cast<device::KernelInfo *>(composite_node->kernel_info());
|
||||
MS_EXCEPTION_IF_NULL(kernel_info);
|
||||
const auto &origin_kernel_build_info = kernel_info->GetMutableSelectKernelBuildInfo();
|
||||
auto origin_inputs_format = origin_kernel_build_info->GetAllInputFormats();
|
||||
auto origin_outputs_format = origin_kernel_build_info->GetAllOutputFormats();
|
||||
auto origin_inputs_type = origin_kernel_build_info->GetAllInputDeviceTypes();
|
||||
auto origin_outputs_type = origin_kernel_build_info->GetAllOutputDeviceTypes();
|
||||
auto origin_processor = origin_kernel_build_info->processor();
|
||||
|
||||
std::vector<std::string> &new_inputs_format = origin_inputs_format;
|
||||
std::vector<TypeId> &new_inputs_type = origin_inputs_type;
|
||||
std::vector<std::string> new_outputs_format;
|
||||
std::vector<TypeId> new_outputs_type;
|
||||
for (size_t i = 0; i < origin_outputs_format.size(); ++i) {
|
||||
new_outputs_format.push_back(origin_outputs_format[i]);
|
||||
new_outputs_type.push_back(origin_outputs_type[i]);
|
||||
}
|
||||
|
||||
auto kernel_with_index = AnfAlgo::VisitKernel(clean_infos[0].second, 0);
|
||||
new_inputs_format.push_back(AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second));
|
||||
new_inputs_type.push_back(AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second));
|
||||
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder new_info_builder;
|
||||
new_info_builder.SetInputsFormat(new_inputs_format);
|
||||
new_info_builder.SetInputsDeviceType(new_inputs_type);
|
||||
new_info_builder.SetOutputsFormat(new_outputs_format);
|
||||
new_info_builder.SetOutputsDeviceType(new_outputs_type);
|
||||
new_info_builder.SetProcessor(origin_processor);
|
||||
new_info_builder.SetKernelType(KernelType::AKG_KERNEL);
|
||||
new_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
|
||||
auto new_selected_info = new_info_builder.Build();
|
||||
AnfAlgo::SetSelectKernelBuildInfo(new_selected_info, composite_node.get());
|
||||
}
|
||||
|
||||
void StitchAtomicCleanInsertter::AddDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &clean_node,
|
||||
const AnfNodePtr &composite_node, const AnfNodePtr &user_node,
|
||||
int index) const {
|
||||
// Create depend node to hold execution order.
|
||||
AnfNodePtrList d_inputs = {NewValueNode(prim::kPrimDepend), clean_node, composite_node};
|
||||
auto depend_cnode = main_graph->NewCNode(d_inputs);
|
||||
depend_cnode->set_abstract(clean_node->abstract());
|
||||
main_graph->AddNode(depend_cnode);
|
||||
|
||||
auto user_cnode = user_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(user_cnode);
|
||||
user_cnode->set_input(IntToSize(index), depend_cnode);
|
||||
}
|
||||
|
||||
CNodePtr StitchAtomicCleanInsertter::CreateInplaceAssignNode(const FuncGraphPtr &sub_graph,
|
||||
const AnfNodePtr &new_parameter) const {
|
||||
const AnfNodePtr &new_parameter,
|
||||
const AtomicAddInfo &info) const {
|
||||
// add inplaceassign
|
||||
AnfNodePtr out_node = atomic_add_node_; // Use result data itself, and set attr "fake_out" true.
|
||||
AnfNodePtr out_node = info.atomic_add_node; // Use result data itself, and set attr "fake_out" true.
|
||||
auto inplace_assign_node =
|
||||
CreateCNode({NewValueNode(prim::kPrimInplaceAssign), new_parameter, atomic_add_node_, out_node}, sub_graph,
|
||||
CreateCNode({NewValueNode(prim::kPrimInplaceAssign), new_parameter, out_node, out_node}, sub_graph,
|
||||
{.format = GetFormat(out_node), .shape = GetShape(out_node), .type = GetType(out_node)});
|
||||
SetNodeAttrSafely("fake_output", MakeValue(true), inplace_assign_node);
|
||||
AnfAlgo::EraseNodeAttr(kAttrStitch, atomic_add_node_);
|
||||
AnfAlgo::EraseNodeAttr(kAttrStitch, out_node);
|
||||
SetNodeAttrSafely(kAttrStitch, MakeValue("common"), inplace_assign_node);
|
||||
return inplace_assign_node;
|
||||
}
|
||||
|
||||
void StitchAtomicCleanInsertter::ProcessOriginCNode(const AnfNodePtr &composite_node, const AnfNodePtr &new_input) {
|
||||
void StitchAtomicCleanInsertter::ProcessOriginCNode(
|
||||
const AnfNodePtr &composite_node,
|
||||
const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &info_and_broadcast_to_nodes) {
|
||||
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(composite_node);
|
||||
auto mng_sub = sub_graph->manager();
|
||||
if (mng_sub == nullptr) {
|
||||
|
@ -51,6 +105,8 @@ void StitchAtomicCleanInsertter::ProcessOriginCNode(const AnfNodePtr &composite_
|
|||
sub_graph->set_manager(mng_sub);
|
||||
}
|
||||
|
||||
auto [atomic_add_info, new_input] = info_and_broadcast_to_nodes[0];
|
||||
|
||||
// add input
|
||||
auto inputs = composite_node->cast<CNodePtr>()->inputs();
|
||||
inputs.push_back(new_input);
|
||||
|
@ -61,11 +117,12 @@ void StitchAtomicCleanInsertter::ProcessOriginCNode(const AnfNodePtr &composite_
|
|||
parameter->set_abstract(new_input->abstract());
|
||||
parameter->set_kernel_info(new_input->kernel_info_ptr());
|
||||
|
||||
auto inplace_assign = CreateInplaceAssignNode(sub_graph, parameter);
|
||||
auto inplace_assign = CreateInplaceAssignNode(sub_graph, parameter, atomic_add_info);
|
||||
|
||||
// Replace atomic ReduceSum's user with atomic clean output, and add depend op after inplaceassign to avoid
|
||||
// elimination.
|
||||
std::vector<std::pair<AnfNodePtr, int>> reduce_user_nodes = FindInnerCNodeUsers(stitch_node_, atomic_add_node_);
|
||||
std::vector<std::pair<AnfNodePtr, int>> reduce_user_nodes =
|
||||
FindInnerCNodeUsers(stitch_node_, atomic_add_info.atomic_add_node);
|
||||
bool connected = false;
|
||||
for (const auto &[user_node, index] : reduce_user_nodes) {
|
||||
auto user_cnode = user_node->cast<CNodePtr>();
|
||||
|
@ -79,7 +136,7 @@ void StitchAtomicCleanInsertter::ProcessOriginCNode(const AnfNodePtr &composite_
|
|||
}
|
||||
connected = true;
|
||||
}
|
||||
CorrectKernelBuildInfo(composite_node, new_input, false);
|
||||
CorrectKernelBuildInfo(composite_node, info_and_broadcast_to_nodes);
|
||||
}
|
||||
|
||||
auto old_graph_name = GetValue<std::string>(sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
|
||||
|
@ -105,8 +162,8 @@ std::vector<std::pair<AnfNodePtr, int>> StitchAtomicCleanInsertter::FindInnerCNo
|
|||
return inner_user_nodes;
|
||||
}
|
||||
|
||||
bool StitchAtomicCleanInsertter::IsStitchWithAtomic(const AnfNodePtr &anf_node) {
|
||||
if (!AnfAlgo::IsGraphKernel(anf_node)) return false;
|
||||
std::pair<bool, AtomicAddInfo> StitchAtomicCleanInsertter::IsStitchWithAtomic(const AnfNodePtr &anf_node) {
|
||||
if (!AnfAlgo::IsGraphKernel(anf_node)) return {false, AtomicAddInfo()};
|
||||
auto node = anf_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
||||
|
@ -116,12 +173,13 @@ bool StitchAtomicCleanInsertter::IsStitchWithAtomic(const AnfNodePtr &anf_node)
|
|||
if (AnfAlgo::HasNodeAttr(kAttrStitch, n->cast<CNodePtr>()) &&
|
||||
AnfAlgo::GetNodeAttr<std::string>(n, kAttrStitch) == "atomic" && IsPrimitiveCNode(n, prim::kPrimReduceSum)) {
|
||||
MS_LOG(INFO) << "GOT STITCH WITH ATOMIC!!!";
|
||||
atomic_add_node_ = n->cast<CNodePtr>();
|
||||
AtomicAddInfo info;
|
||||
info.atomic_add_node = n->cast<CNodePtr>();
|
||||
stitch_node_ = anf_node;
|
||||
return true;
|
||||
return {true, info};
|
||||
}
|
||||
}
|
||||
return false;
|
||||
return {false, AtomicAddInfo()};
|
||||
}
|
||||
|
||||
bool StitchAtomicCleanInsertter::Run(const FuncGraphPtr &func_graph) {
|
||||
|
@ -137,8 +195,9 @@ bool StitchAtomicCleanInsertter::Run(const FuncGraphPtr &func_graph) {
|
|||
auto topo_nodes = TopoSort(kernel_graph->get_return());
|
||||
for (const auto &node : topo_nodes) {
|
||||
// if stitch attr exists, add atomic clean op depends on the attr
|
||||
if (IsStitchWithAtomic(node)) {
|
||||
InsertAtomicClean(kernel_graph, node, mng);
|
||||
auto [is_stitch, atomic_add_info] = IsStitchWithAtomic(node);
|
||||
if (is_stitch) {
|
||||
InsertAtomicClean(kernel_graph, node, {atomic_add_info}, mng);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -33,11 +33,19 @@ class StitchAtomicCleanInsertter : public AtomicCleanInsertter {
|
|||
bool Run(const FuncGraphPtr &func_graph) override;
|
||||
|
||||
private:
|
||||
CNodePtr CreateInplaceAssignNode(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter) const;
|
||||
void CorrectKernelBuildInfo(const AnfNodePtr &composite_node,
|
||||
const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &clean_infos) override;
|
||||
CNodePtr CreateInplaceAssignNode(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter,
|
||||
const AtomicAddInfo &info) const;
|
||||
std::vector<std::pair<AnfNodePtr, int>> FindInnerCNodeUsers(const AnfNodePtr &inner_node,
|
||||
const CNodePtr &target) const;
|
||||
void ProcessOriginCNode(const AnfNodePtr &composite_node, const AnfNodePtr &new_input) override;
|
||||
bool IsStitchWithAtomic(const AnfNodePtr &anf_node);
|
||||
void ProcessOriginCNode(
|
||||
const AnfNodePtr &composite_node,
|
||||
const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &info_and_broadcast_to_nodes) override;
|
||||
std::pair<bool, AtomicAddInfo> IsStitchWithAtomic(const AnfNodePtr &anf_node);
|
||||
|
||||
void AddDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &clean_node, const AnfNodePtr &composite_node,
|
||||
const AnfNodePtr &user_node, int index) const;
|
||||
|
||||
AnfNodePtr stitch_node_{nullptr};
|
||||
};
|
||||
|
|
|
@ -49,6 +49,8 @@
|
|||
#include "backend/optimizer/graph_kernel/graph_kernel_pass_manager.h"
|
||||
#include "backend/optimizer/graph_kernel/transform_op_optimizer.h"
|
||||
#include "backend/optimizer/graph_kernel/rewrite_output_shape.h"
|
||||
#include "backend/optimizer/graph_kernel/graph_kernel_recompute.h"
|
||||
#include "backend/optimizer/graph_kernel/reduce_fake_out_mem.h"
|
||||
|
||||
namespace mindspore::graphkernel {
|
||||
using opt::CommonSubexpressionElimination;
|
||||
|
@ -154,6 +156,14 @@ PassManagerPtr GraphKernelOptimizer::Split() const {
|
|||
|
||||
PassManagerPtr GraphKernelOptimizer::HighLevelOpt2() const {
|
||||
auto pm = std::make_shared<GraphKernelPassManager>(4, "highlevelopt2");
|
||||
|
||||
auto &flags = GraphKernelFlags::GetInstance();
|
||||
// Auto recompute according to local memory burst.
|
||||
auto recompute_lv = GetPassLevelByFlag(flags.recompute_increment_threshold > 0 || flags.recompute_peak_threshold > 0);
|
||||
pm->AddPass(std::make_shared<GraphKernelRecompute>(), recompute_lv);
|
||||
pm->AddPass(std::make_shared<ExtendOutputForUpdateState>(), recompute_lv);
|
||||
pm->AddPass(std::make_shared<MergeOutputForUpdateState>(), recompute_lv);
|
||||
|
||||
// Enable atomic add
|
||||
pm->AddPass(std::make_shared<AtomicCleanInsertter>(), OptLevel_2, is_gpu || is_ascend);
|
||||
|
||||
|
@ -197,6 +207,9 @@ PassManagerPtr GraphKernelOptimizer::PostProcess() const {
|
|||
pm->AddPass(std::make_shared<GetitemTuple>(), OptLevel_1);
|
||||
pm->AddPass(std::make_shared<RewriteOutputShape>(), OptLevel_1);
|
||||
|
||||
// Reduce fake output memory.
|
||||
pm->AddPass(std::make_shared<ReduceFakeOutMem>(), OptLevel_1);
|
||||
|
||||
// Add the new tensors to the kernel_graph
|
||||
pm->AddPass(std::make_shared<BindValueToGraph>(), OptLevel_1);
|
||||
return pm;
|
||||
|
|
|
@ -0,0 +1,636 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/optimizer/graph_kernel/graph_kernel_recompute.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
#include <stack>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "backend/kernel_compiler/common_utils.h"
|
||||
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
|
||||
#include "backend/optimizer/graph_kernel/core/graph_builder.h"
|
||||
|
||||
namespace mindspore::graphkernel {
|
||||
namespace {
|
||||
int64_t GetGetitemIndex(const AnfNodePtr &getitem) {
|
||||
auto vnode = GetValueNode(getitem->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem));
|
||||
return GetValue<int64_t>(vnode);
|
||||
}
|
||||
|
||||
AnfNodePtr GetOutput(const FuncGraphPtr &func_graph, size_t i) {
|
||||
auto output_node = func_graph->output()->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(output_node);
|
||||
if (IsPrimitiveCNode(output_node, prim::kPrimMakeTuple)) {
|
||||
if (i + 1 >= output_node->size()) {
|
||||
MS_LOG(EXCEPTION) << i << " is out of range of MakeTuple's size " << output_node->size();
|
||||
}
|
||||
return output_node->input(i + 1);
|
||||
} else {
|
||||
if (i > 0) {
|
||||
MS_LOG(EXCEPTION) << "the graph is single output but i is not 0. it's " << i;
|
||||
}
|
||||
return output_node->cast<AnfNodePtr>();
|
||||
}
|
||||
}
|
||||
|
||||
bool IsExclude(const AnfNodePtr &node) {
|
||||
static std::vector<PrimitivePtr> excludes = {prim::kPrimReturn, prim::kPrimUpdateState, prim::kPrimLoad,
|
||||
prim::kPrimMakeTuple, prim::kPrimDepend};
|
||||
return std::any_of(excludes.begin(), excludes.end(),
|
||||
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
|
||||
}
|
||||
|
||||
enum class VisitType : char { FOLLOW, STOP };
|
||||
using VisitFunc = std::function<VisitType(const AnfNodePtr &)>;
|
||||
using NextFunc = std::function<AnfNodePtrList(const AnfNodePtr &)>;
|
||||
using ProcessFunc = std::function<void(const AnfNodePtr &)>;
|
||||
|
||||
void Dfs(const AnfNodePtr ¤t, const VisitFunc &visit_func, const NextFunc &next_func,
|
||||
const ProcessFunc &before_func, const ProcessFunc &after_func, std::set<AnfNodePtr> *visited) {
|
||||
if (visited->count(current) > 0) {
|
||||
return;
|
||||
}
|
||||
visited->insert(current);
|
||||
if (visit_func(current) != VisitType::FOLLOW) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (const auto &next : next_func(current)) {
|
||||
before_func(next);
|
||||
Dfs(next, visit_func, next_func, before_func, after_func, visited);
|
||||
after_func(next);
|
||||
}
|
||||
}
|
||||
|
||||
OrderedMap<AnfNodePtr, AnfNodePtrList> CollectLinkPaths(const std::map<AnfNodePtr, MemorySize> &topo_indice,
|
||||
const OrderedSet<AnfNodePtr> &direct_users,
|
||||
MemorySize max_topo_user_index,
|
||||
const FuncGraphManagerPtr &mng) {
|
||||
std::stack<AnfNodePtr> cur_stack;
|
||||
OrderedMap<AnfNodePtr, AnfNodePtrList> link_paths;
|
||||
auto TmpVisitFunc = [&topo_indice, max_topo_user_index](const AnfNodePtr &n) -> VisitType {
|
||||
if (IsExclude(n)) {
|
||||
return VisitType::STOP;
|
||||
}
|
||||
|
||||
auto iter = topo_indice.find(n);
|
||||
if (iter == topo_indice.end()) {
|
||||
MS_LOG(EXCEPTION) << "Cannot find " << n->fullname_with_scope() << " in topo indices!";
|
||||
}
|
||||
if (iter->second > max_topo_user_index) {
|
||||
return VisitType::STOP;
|
||||
}
|
||||
return VisitType::FOLLOW;
|
||||
};
|
||||
|
||||
auto TmpNextFunc = [&mng](const AnfNodePtr &n) -> AnfNodePtrList {
|
||||
auto users = mng->node_users()[n];
|
||||
AnfNodePtrList nexts;
|
||||
std::transform(users.cbegin(), users.cend(), std::back_inserter(nexts),
|
||||
[](const std::pair<AnfNodePtr, int> &user) { return user.first; });
|
||||
return nexts;
|
||||
};
|
||||
|
||||
auto TmpBeforeFunc = [&link_paths, &cur_stack, &direct_users](const AnfNodePtr &next) -> void {
|
||||
if (direct_users.count(next) == 0) {
|
||||
return;
|
||||
}
|
||||
auto cur_node = cur_stack.top();
|
||||
if (link_paths.find(cur_node) == link_paths.end()) {
|
||||
link_paths.insert({cur_node, AnfNodePtrList()});
|
||||
}
|
||||
link_paths[cur_node].push_back(next);
|
||||
cur_stack.push(next);
|
||||
};
|
||||
|
||||
auto TmpAfterFunc = [&cur_stack, &direct_users](const AnfNodePtr &next) -> void {
|
||||
if (direct_users.count(next) == 0) {
|
||||
return;
|
||||
}
|
||||
// Pop the direct user pushed by TmpBeforeFunc once its subtree has been visited.
cur_stack.pop();
|
||||
};
|
||||
|
||||
std::set<AnfNodePtr> visited;
|
||||
for (auto user : direct_users) {
|
||||
cur_stack.push(user);
|
||||
Dfs(user, TmpVisitFunc, TmpNextFunc, TmpBeforeFunc, TmpAfterFunc, &visited);
|
||||
cur_stack.pop();
|
||||
}
|
||||
|
||||
return link_paths;
|
||||
}
|
||||
|
||||
OrderedSet<AnfNodePtr> GetLongTermNodes(const AnfNodePtrList &nodes, const AnfNodePtr &end_node,
|
||||
const std::map<AnfNodePtr, MemorySize> &topo_indices,
|
||||
const FuncGraphManagerPtr &mng) {
|
||||
OrderedSet<AnfNodePtr> long_term_nodes;
|
||||
for (auto node : nodes) {
|
||||
auto real_node = AnfAlgo::VisitKernelWithReturnType(node, 0).first;
|
||||
// Parameters or values hold long-term tensors.
|
||||
if (!utils::isa<CNodePtr>(real_node)) {
|
||||
long_term_nodes.insert(node);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto users = mng->node_users()[node];
|
||||
if (std::any_of(users.cbegin(), users.cend(), [&topo_indices, &end_node](const std::pair<AnfNodePtr, int> &user) {
|
||||
auto user_topo = topo_indices.find(user.first);
|
||||
auto end_topo = topo_indices.find(end_node);
|
||||
return user_topo->second >= end_topo->second;
|
||||
})) {
|
||||
long_term_nodes.insert(node);
|
||||
}
|
||||
}
|
||||
return long_term_nodes;
|
||||
}
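// Illustration with made-up topo indices: if end_node sits at topo index 9 and some user of the
// input node sits at index 12 (>= 9), the input's tensor is still alive after end_node executes,
// so it is collected as long-term; if all users sit before index 9, it is treated as short-term.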
|
||||
|
||||
/**
|
||||
* @brief Remove real inputs which are not used and change the related graph parameters.
|
||||
*
|
||||
* @param func_graph Graph.
|
||||
* @param inputs Real inputs for graph cnode.
|
||||
*/
|
||||
void ElimRedundantInputsAndGraphParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs) {
|
||||
MS_EXCEPTION_IF_NULL(inputs);
|
||||
const auto &ori_parameter = func_graph->parameters();
|
||||
auto nodes = TopoSort(func_graph->get_return());
|
||||
std::set<AnfNodePtr> used_param;
|
||||
for (auto node : nodes) {
|
||||
if (node->isa<Parameter>()) {
|
||||
(void)used_param.insert(node);
|
||||
}
|
||||
}
|
||||
if (used_param.size() == ori_parameter.size()) {
|
||||
return;
|
||||
}
|
||||
AnfNodePtrList new_parameter, new_inputs;
|
||||
for (size_t i = 0; i < ori_parameter.size(); ++i) {
|
||||
if (used_param.count(ori_parameter[i])) {
|
||||
new_parameter.push_back(ori_parameter[i]);
|
||||
new_inputs.push_back((*inputs)[i]);
|
||||
}
|
||||
}
|
||||
func_graph->set_parameters(new_parameter);
|
||||
*inputs = std::move(new_inputs);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
/**
|
||||
* @brief Filter out the input tensors that live longer than the end node and return valid inputs for memory calculation. \n
|
||||
* If the topo index of an input's user is at least one greater than end_node's, \n
|
||||
* the input will be retained after end_node's execution.
|
||||
*
|
||||
* @param source_node
|
||||
* @param end_node
|
||||
* @param edge_pos
|
||||
* @param mng
|
||||
* @return AnfNodePtrList
|
||||
*/
|
||||
AnfNodePtrList AutoRecompute::Filter(const AnfNodePtr &source_node, const AnfNodePtr &end_node, int edge_pos,
|
||||
const FuncGraphManagerPtr &mng) {
|
||||
auto source_cnode = source_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(source_cnode);
|
||||
AnfNodePtrList node_inputs(source_cnode->inputs().begin() + 1, source_cnode->inputs().end());
|
||||
OrderedSet<AnfNodePtr> long_term_inputs = GetLongTermNodes(node_inputs, end_node, topo_indice_, mng);
|
||||
|
||||
AnfNodePtrList check_inputs;
|
||||
if (IsPrimitiveCNode(end_node->cast<CNodePtr>()->input(edge_pos), prim::kPrimTupleGetItem)) {
|
||||
auto out_index = GetSourceLinkOutPos(end_node, edge_pos);
|
||||
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(source_node);
|
||||
auto out = sub_graph->output();
|
||||
if (!IsPrimitiveCNode(out, prim::kPrimMakeTuple)) {
|
||||
MS_LOG(EXCEPTION) << "Filte input tensor error";
|
||||
}
|
||||
|
||||
// Find subgraph's input according to edge node.
|
||||
auto start_node = out->cast<CNodePtr>()->input(out_index + 1);
|
||||
AnfNodePtrList sub_input_parameters;
|
||||
std::queue<AnfNodePtr> node_q;
|
||||
node_q.push(start_node);
|
||||
while (!node_q.empty()) {
|
||||
auto cur = node_q.front();
|
||||
node_q.pop();
|
||||
if (utils::isa<ParameterPtr>(cur)) {
|
||||
sub_input_parameters.push_back(cur);
|
||||
}
|
||||
auto cur_cnode = cur->cast<CNodePtr>();
|
||||
if (cur_cnode) {
|
||||
for (size_t i = 1; i < cur_cnode->inputs().size(); ++i) {
|
||||
node_q.push(cur_cnode->input(i));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Filter out inputs whose users' topo indices are greater than the source graph's.
|
||||
for (auto para : sub_input_parameters) {
|
||||
for (size_t i = 0; i < sub_graph->parameters().size(); ++i) {
|
||||
if (para == sub_graph->parameters()[i]) {
|
||||
check_inputs.push_back(node_inputs[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
check_inputs = node_inputs;
|
||||
}
|
||||
|
||||
AnfNodePtrList res;
|
||||
for (auto input : check_inputs) {
|
||||
if (long_term_inputs.count(input) == 0) {
|
||||
res.push_back(input);
|
||||
}
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get valid user information for the given node, excluding TupleGetItem, Load and so on.
|
||||
*/
|
||||
std::tuple<OrderedSet<AnfNodePtr>, OutPosLinkMap, MemorySize> AutoRecompute::GetValidUsers(
|
||||
const AnfNodePtr &node, const FuncGraphManagerPtr &mng) {
|
||||
auto &user_map = mng->node_users();
|
||||
auto users = user_map[node];
|
||||
MemorySize max_topo_user_index = 0;
|
||||
std::queue<std::pair<AnfNodePtr, int>> users_queue;
|
||||
for (auto user_index : users) {
|
||||
users_queue.push(user_index);
|
||||
}
|
||||
OrderedSet<AnfNodePtr> direct_users;
|
||||
OutPosLinkMap user_edge_pos;
|
||||
while (!users_queue.empty()) {
|
||||
auto [user, index] = users_queue.front();
|
||||
users_queue.pop();
|
||||
if (IsPrimitiveCNode(user, prim::kPrimTupleGetItem)) {
|
||||
for (auto get_item_user : user_map[user]) {
|
||||
users_queue.push(get_item_user);
|
||||
}
|
||||
continue;
|
||||
} else if (IsExclude(user)) {
|
||||
continue;
|
||||
}
|
||||
user_edge_pos[user].push_back(index);
|
||||
direct_users.insert(user);
|
||||
// Update maximum topo value.
|
||||
if (topo_indice_[user] > max_topo_user_index) {
|
||||
max_topo_user_index = topo_indice_[user];
|
||||
}
|
||||
}
|
||||
|
||||
return {direct_users, user_edge_pos, max_topo_user_index};
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Judge the target node for recompute according to the current node, and capture source node information when the \n
|
||||
* target is found. There are two types of tensors on the edge between the source node and the target node, for example: \n
|
||||
* source ──[Short-Term]── A ── other \n
|
||||
* │ │ \n
|
||||
* └───────[Long-Term]────── target \n
|
||||
* For this example, \n
|
||||
* 1. There are two paths from the source node to the target node, and the target is a direct user of the source node, \n
|
||||
* so the tensor of their edge is a long-term tensor. \n
|
||||
* 2. From the source node to A, there is only one path, and A is a direct user of the source node, \n
|
||||
* so the tensor of their edge is a short-term tensor.
|
||||
*
|
||||
* @param node Source node.
|
||||
* @param mng Graph manager.
|
||||
* @return OutPosLinkList Vector[Tuple(target node, input positions of target node for edge, edge type)].
|
||||
*/
|
||||
OutPosLinkList AutoRecompute::JudegeTargetAndCaptureSource(const AnfNodePtr &node, const FuncGraphManagerPtr &mng) {
|
||||
auto [direct_users, user_edge_pos, max_topo_user_index] = GetValidUsers(node, mng);
|
||||
OutPosLinkList target_link_infos;
|
||||
OrderedSet<AnfNodePtr> long_term_users;
|
||||
// If the number of direct users is less than 2, there will be no side path to its users.
|
||||
if (direct_users.size() >= 2) {
|
||||
OrderedMap<AnfNodePtr, AnfNodePtrList> link_paths =
|
||||
CollectLinkPaths(topo_indice_, direct_users, max_topo_user_index, mng);
|
||||
for (const auto &[source, paths] : link_paths) {
|
||||
for (auto target : paths) {
|
||||
if (target != source) {
|
||||
target_link_infos.emplace_back(target, user_edge_pos[target], EdgeLifeTimeType::LongTerm);
|
||||
long_term_users.insert(target);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Direct users include long term users and short term users.
|
||||
// If the short-term user is a graph kernel composite node, it may be absorbed to reduce the local peak memory.
|
||||
for (auto user : direct_users) {
|
||||
if (long_term_users.count(user) == 0 && AnfAlgo::IsGraphKernel(user)) {
|
||||
target_link_infos.emplace_back(user, user_edge_pos[user], EdgeLifeTimeType::ShortTerm);
|
||||
}
|
||||
}
|
||||
return target_link_infos;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get position of edge tensor between source node and target node. \n
|
||||
* For example, giving target node and edge position 0, will return 1: \n
|
||||
* source node \n
|
||||
* [0] [1] [2] <- output position \n
|
||||
* | \n
|
||||
* | \n
|
||||
* / \n
|
||||
* [0] [1] <- input position \n
|
||||
* target node
|
||||
*
|
||||
* @param target Target node.
|
||||
* @param pos The input position of target node for edge.
|
||||
* @return int The output position of source node for edge.
|
||||
*/
|
||||
int AutoRecompute::GetSourceLinkOutPos(const AnfNodePtr &target, int pos) {
|
||||
// If the input is a get-item, then use the get-item's index, otherwise zero.
|
||||
auto cnode = target->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto prenode = cnode->input(pos);
|
||||
if (!IsPrimitiveCNode(prenode, prim::kPrimTupleGetItem)) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
auto get_item_cnode = prenode->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(get_item_cnode);
|
||||
auto value_input = get_item_cnode->input(kInputNodeOutputIndexInTupleGetItem);
|
||||
MS_EXCEPTION_IF_NULL(value_input);
|
||||
auto value_node = value_input->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
return static_cast<int>(GetValue<int64_t>(value_node->value()));
|
||||
}
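// Illustration with hypothetical nodes: if the target's input at `pos` is TupleGetItem(source, 2),
// this returns 2, i.e. the edge comes from the source's third output; for a non-getitem input the
// source is treated as single-output and 0 is returned.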
|
||||
|
||||
MemorySize AutoRecompute::SelectThreshold(EdgeLifeTimeType type) {
|
||||
MemorySize threshold = 0;
|
||||
switch (type) {
|
||||
case EdgeLifeTimeType::ShortTerm:
|
||||
threshold = local_peak_threshold_;
|
||||
break;
|
||||
case EdgeLifeTimeType::LongTerm:
|
||||
threshold =
|
||||
local_peak_threshold_ == 0 ? lifetime_threshold_ : std::min(local_peak_threshold_, lifetime_threshold_);
|
||||
break;
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "Unknown edge type!";
|
||||
}
|
||||
|
||||
return threshold;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Find recompute candidates (source node, target node, edge and its type) in func_graph. \n
|
||||
* Results will be added to candidates_.
|
||||
*
|
||||
* @param func_graph
|
||||
*/
|
||||
void AutoRecompute::FindCandidates(const FuncGraphPtr &func_graph) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
candidates_.clear();
|
||||
|
||||
auto mng = func_graph->manager();
|
||||
if (mng == nullptr) {
|
||||
mng = Manage(func_graph, true);
|
||||
func_graph->set_manager(mng);
|
||||
}
|
||||
|
||||
// Do nothing when both thresholds are the default value 0.
|
||||
if (SelectThreshold(EdgeLifeTimeType::ShortTerm) == 0 && SelectThreshold(EdgeLifeTimeType::LongTerm) == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto topo_nodes = TopoSort(func_graph->get_return());
|
||||
// Topo indices are used for early stopping in the predecessor check.
|
||||
for (size_t i = 0; i < topo_nodes.size(); ++i) {
|
||||
topo_indice_.insert({topo_nodes[i], i});
|
||||
}
|
||||
|
||||
// Candidate condition:
|
||||
// 1. Judge whether the current node can see its graph_kernel input through another input's backward path.
|
||||
// 2. The memory variation between the split-out result and the origin exceeds the threshold:
|
||||
// `Size(gs_direct_outs_to_gt) - filter(gs_inputs, its) > threshold`.
|
||||
for (auto node : topo_nodes) {
|
||||
if (!AnfAlgo::IsGraphKernel(node)) {
|
||||
continue;
|
||||
}
|
||||
auto target_graphs = JudegeTargetAndCaptureSource(node, mng);
|
||||
if (target_graphs.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
OrderedMap<AnfNodePtr, OrderedMap<AnfNodePtr, std::pair<EdgeLifeTimeType, AnfNodePtrList>>> tmp_candidates;
|
||||
for (auto [gt, gt_in_pos_vec, edge_life_time_type] : target_graphs) {
|
||||
MemorySize threshold = SelectThreshold(edge_life_time_type);
|
||||
for (auto gt_in_pos : gt_in_pos_vec) {
|
||||
MemorySize out_tensor_size =
|
||||
static_cast<MemorySize>(AnfAlgo::GetOutputTensorMemSize(node, GetSourceLinkOutPos(gt, gt_in_pos)));
|
||||
MemorySize absorb_input_tensor_size = 0;
|
||||
for (auto input : Filter(node, gt, gt_in_pos, mng)) {
|
||||
absorb_input_tensor_size += static_cast<MemorySize>(AnfAlgo::GetOutputTensorMemSize(input, 0));
|
||||
}
|
||||
auto gt_cnode = gt->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(gt_cnode);
|
||||
auto edge = gt_cnode->input(gt_in_pos);
|
||||
if (out_tensor_size < absorb_input_tensor_size) {
|
||||
continue;
|
||||
}
|
||||
if (out_tensor_size - absorb_input_tensor_size > threshold) {
|
||||
if (tmp_candidates[node].find(gt) == tmp_candidates[node].end()) {
|
||||
tmp_candidates[node][gt] = {edge_life_time_type, AnfNodePtrList{}};
|
||||
}
|
||||
// Only add getitem nodes as edges; if GS is a single-output node, there will be no edges.
|
||||
if (IsPrimitiveCNode(edge, prim::kPrimTupleGetItem)) {
|
||||
tmp_candidates[node][gt].second.push_back(edge);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Delete duplicated link.
|
||||
for (const auto &[source, target_and_link] : tmp_candidates) {
|
||||
for (const auto &[target, link] : target_and_link) {
|
||||
candidates_.push_back({source, target, link.first, link.second});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
RecomputeCandidatesLog(candidates_);
|
||||
}
|
||||
|
||||
void AutoRecompute::RecomputeCandidatesLog(const std::vector<Candidate> &candidates) {
|
||||
MS_LOG(INFO) << "Recompute candidates: ";
|
||||
for (auto candidate : candidates) {
|
||||
MS_LOG(INFO) << " └─ GS: " << candidate.source_graph->fullname_with_scope();
|
||||
MS_LOG(INFO) << " └─ GT: " << candidate.target_graph->fullname_with_scope();
|
||||
for (auto edge : candidate.recompute_edges) {
|
||||
MS_LOG(INFO) << " └─[Edge]─> " << edge->fullname_with_scope();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::pair<FuncGraphPtr, AnfNodePtrList> GraphKernelRecompute::CloneGraph(const CNodePtr &source_graph,
|
||||
const AnfNodePtrList &recompute_edges) {
|
||||
MS_EXCEPTION_IF_NULL(source_graph);
|
||||
auto gs = AnfAlgo::GetCNodeFuncGraphPtr(source_graph);
|
||||
MS_EXCEPTION_IF_NULL(gs);
|
||||
AnfNodePtrList inputs(source_graph->inputs().begin() + 1, source_graph->inputs().end());
|
||||
auto new_funcgraph = BasicClone(gs);
|
||||
auto output_node = new_funcgraph->output()->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(output_node);
|
||||
if (!IsPrimitiveCNode(output_node, prim::kPrimMakeTuple)) {
|
||||
return {new_funcgraph, inputs};
|
||||
}
|
||||
// Remove outputs that are not in the recompute edges.
|
||||
AnfNodePtrList new_outputs;
|
||||
for (auto &edge : recompute_edges) {
|
||||
auto idx = GetGetitemIndex(edge);
|
||||
new_outputs.push_back(GetOutput(new_funcgraph, idx));
|
||||
}
|
||||
if (new_outputs.size() + 1 == output_node->size()) {
|
||||
return {new_funcgraph, inputs};
|
||||
}
|
||||
new_outputs.insert(new_outputs.begin(), output_node->input(0));
|
||||
auto new_output_node = new_funcgraph->NewCNode(new_outputs);
|
||||
// Use the old abstract, since the new_funcgraph will be deleted in a later pass.
|
||||
new_output_node->set_abstract(output_node->abstract());
|
||||
new_output_node->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||
new_funcgraph->set_output(new_output_node);
|
||||
ElimRedundantInputsAndGraphParameters(new_funcgraph, &inputs);
|
||||
return {new_funcgraph, inputs};
|
||||
}
|
||||
|
||||
void GraphKernelRecompute::LinkIntoTargetFuncGraph(const Candidate &candidate, const FuncGraphPtr &cloned_func,
|
||||
const AnfNodePtrList &cloned_inputs) {
|
||||
auto cloned_nodes = TopoSort(cloned_func->get_return());
|
||||
auto gt = AnfAlgo::GetCNodeFuncGraphPtr(candidate.target_graph);
|
||||
MS_EXCEPTION_IF_NULL(gt);
|
||||
auto mng = gt->manager();
|
||||
if (mng == nullptr) {
|
||||
mng = Manage(gt, true);
|
||||
gt->set_manager(mng);
|
||||
}
|
||||
|
||||
// link the outputs to gt
|
||||
auto gt_node = candidate.target_graph->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(gt_node);
|
||||
AnfNodePtrList new_parameters;
|
||||
AnfNodePtrList new_inputs;
|
||||
auto ¶ms = gt->parameters();
|
||||
for (size_t i = 0; i < params.size(); i++) {
|
||||
// If the parameter is a recompute edge, then link the param to the cloned_func's output.
|
||||
auto iter = std::find(candidate.recompute_edges.begin(), candidate.recompute_edges.end(), gt_node->input(i + 1));
|
||||
if (iter != candidate.recompute_edges.end()) {
|
||||
auto out_index = iter - candidate.recompute_edges.begin();
|
||||
mng->Replace(params[i], GetOutput(cloned_func, out_index));
|
||||
} else {
|
||||
new_parameters.push_back(params[i]);
|
||||
new_inputs.push_back(gt_node->input(i + 1));
|
||||
}
|
||||
}
|
||||
|
||||
// add new parameters
|
||||
auto &cloned_func_params = cloned_func->parameters();
|
||||
for (size_t i = 0; i < cloned_func_params.size(); i++) {
|
||||
auto iter = std::find(new_inputs.begin(), new_inputs.end(), cloned_inputs[i]);
|
||||
if (iter != new_inputs.end()) {
|
||||
auto idx = iter - new_inputs.begin();
|
||||
cloned_func->manager()->Replace(cloned_func_params[i], new_parameters[idx]);
|
||||
} else {
|
||||
new_parameters.push_back(gt->add_parameter());
|
||||
new_inputs.push_back(cloned_inputs[i]);
|
||||
cloned_func->manager()->Replace(cloned_func_params[i], new_parameters.back());
|
||||
}
|
||||
}
|
||||
|
||||
// reset the func_graph for cloned_nodes.
|
||||
for (auto &node : cloned_nodes) {
|
||||
if (node->isa<CNode>()) {
|
||||
node->set_func_graph(gt);
|
||||
}
|
||||
}
|
||||
AnfNodePtrList new_node_inputs = {gt_node->input(0)};
|
||||
new_node_inputs.insert(new_node_inputs.end(), new_inputs.begin(), new_inputs.end());
|
||||
gt->set_parameters(new_parameters);
|
||||
gt_node->set_inputs(new_node_inputs);
|
||||
AnfNodePtrList outputs;
|
||||
kernel::GetFuncGraphOutputNodes(gt, &outputs);
|
||||
gt_node->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||
SetNewKernelInfo(gt_node, gt, new_inputs, outputs);
|
||||
mng->RemoveRoots();
|
||||
mng->KeepRoots({gt});
|
||||
}
|
||||
|
||||
void GraphKernelRecompute::Process(const Candidate &candidate) {
|
||||
auto gs = AnfAlgo::GetCNodeFuncGraphPtr(candidate.source_graph);
|
||||
MS_EXCEPTION_IF_NULL(gs);
|
||||
if (candidate.recompute_edges.empty()) {
|
||||
// single output, clone the whole source_graph.
|
||||
return;
|
||||
}
|
||||
|
||||
// AnfNodePtrList outputs;
|
||||
auto [new_funcgraph, inputs] = CloneGraph(candidate.source_graph->cast<CNodePtr>(), candidate.recompute_edges);
|
||||
|
||||
auto mng = new_funcgraph->manager();
|
||||
if (mng == nullptr) {
|
||||
mng = Manage(new_funcgraph, true);
|
||||
new_funcgraph->set_manager(mng);
|
||||
}
|
||||
|
||||
if (AnfAlgo::IsGraphKernel(candidate.target_graph)) {
|
||||
// the target graph is a GraphKernel, push the new_funcgraph into the target graph.
|
||||
LinkIntoTargetFuncGraph(candidate, new_funcgraph, inputs);
|
||||
} else {
|
||||
// The target graph is not a GraphKernel; building the new_funcgraph into a standalone CNode is not supported here.
|
||||
MS_LOG(WARNING) << "Target node " << candidate.target_graph->fullname_with_scope()
|
||||
<< " is not a graph kernel node, cannot absort the link edge!";
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
bool GraphKernelRecompute::Run(const FuncGraphPtr &func_graph) {
|
||||
int repeat_times = 2;
|
||||
while (repeat_times--) {
|
||||
candidates_ = AutoRecompute().Run(func_graph);
|
||||
if (candidates_.empty()) {
|
||||
return false;
|
||||
}
|
||||
auto mng = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
for (auto &c : candidates_) {
|
||||
if (!AnfAlgo::IsGraphKernel(c.target_graph)) {
|
||||
continue;
|
||||
}
|
||||
std::ostringstream oss;
|
||||
for (auto &e : c.recompute_edges) {
|
||||
if (!IsPrimitiveCNode(e, prim::kPrimTupleGetItem)) {
|
||||
MS_LOG(EXCEPTION) << "The edge should be GetItem but got " << e->fullname_with_scope();
|
||||
}
|
||||
oss << e->fullname_with_scope() << ", ";
|
||||
}
|
||||
MS_LOG(INFO) << "Clone " << c.source_graph->fullname_with_scope() << " to "
|
||||
<< c.target_graph->fullname_with_scope() << ", edges [" << oss.str() << "]";
|
||||
Process(c);
|
||||
}
|
||||
mng->RemoveRoots();
|
||||
mng->KeepRoots({func_graph});
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace mindspore::graphkernel
|
|
@ -0,0 +1,110 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_RECOMPUTE_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_RECOMPUTE_H_
|
||||
|
||||
#include <map>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "utils/context/graph_kernel_flags.h"
|
||||
#include "backend/optimizer/common/pass.h"
|
||||
#include "ir/func_graph.h"
|
||||
|
||||
namespace mindspore::graphkernel {
|
||||
/*
|
||||
* Recompute some operator to reduce temporary memory peak.
|
||||
*
|
||||
* (a) (b) (a) (b)
|
||||
* \ / | |
|
||||
* Gs Gs1 |
|
||||
* (c)/ | (c)| |
|
||||
* / | Go |
|
||||
* Go |(d) =========> │└─depend
|
||||
* \ | │ │
|
||||
* (e)\ | (e)│ Gs2
|
||||
* \| │ │
|
||||
* Gt ├────(d)
|
||||
* Gt
|
||||
*
|
||||
* Here, Gs is split into Gs1 and Gs2, and (x) denotes a temporary tensor.
|
||||
* For the left graph, the memory is (a+b) -> (c+d) -> (d+e).
|
||||
* For the right graph, the memory is (a+b) -> (b+c) -> (b+e) -> (d+e).
|
||||
* If (c+d) reaches the memory threshold, and (b+c) or (b+e) is less than it,
|
||||
* it may ease the memory burden.
|
||||
*/
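/*
 * A hedged worked example of the trade-off above, with made-up tensor sizes
 * a = b = 16 MB, c = d = 32 MB, e = 8 MB:
 *   left graph : (a+b)=32 -> (c+d)=64 -> (d+e)=40, peak 64 MB
 *   right graph: (a+b)=32 -> (b+c)=48 -> (b+e)=24 -> (d+e)=40, peak 48 MB
 * Recomputing Gs as Gs2 lowers the peak from 64 MB to 48 MB, at the cost of executing Gs twice.
 */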
|
||||
enum class EdgeLifeTimeType : char { ShortTerm, LongTerm };
|
||||
inline std::ostream &operator<<(std::ostream &os, EdgeLifeTimeType type) {
|
||||
std::map<EdgeLifeTimeType, std::string> out_str = {{EdgeLifeTimeType::ShortTerm, "[ShortTerm]"},
|
||||
{EdgeLifeTimeType::LongTerm, "[LongTerm]"}};
|
||||
return os << out_str[type];
|
||||
}
|
||||
using OutPosLinkList = std::vector<std::tuple<AnfNodePtr, std::vector<int>, EdgeLifeTimeType>>;
|
||||
using OutPosLinkMap = std::map<AnfNodePtr, std::vector<int>>;
|
||||
using MemorySize = int64_t;
|
||||
struct Candidate {
|
||||
AnfNodePtr source_graph;
|
||||
AnfNodePtr target_graph;
|
||||
EdgeLifeTimeType type;
|
||||
AnfNodePtrList recompute_edges; // getitem list for recompute edges.
|
||||
};
|
||||
|
||||
class AutoRecompute {
|
||||
public:
|
||||
std::vector<Candidate> Run(const FuncGraphPtr &func_graph) {
|
||||
lifetime_threshold_ = GraphKernelFlags::GetInstance().recompute_increment_threshold;
|
||||
local_peak_threshold_ = GraphKernelFlags::GetInstance().recompute_peak_threshold;
|
||||
FindCandidates(func_graph);
|
||||
return candidates_;
|
||||
}
|
||||
|
||||
private:
|
||||
OutPosLinkList JudegeTargetAndCaptureSource(const AnfNodePtr &node, const FuncGraphManagerPtr &mng);
|
||||
AnfNodePtrList Filter(const AnfNodePtr &source_node, const AnfNodePtr &end_node, int edge_pos,
|
||||
const FuncGraphManagerPtr &mng);
|
||||
void FindCandidates(const FuncGraphPtr &func_graph);
|
||||
int GetSourceLinkOutPos(const AnfNodePtr &target, int pos);
|
||||
std::tuple<OrderedSet<AnfNodePtr>, OutPosLinkMap, MemorySize> GetValidUsers(const AnfNodePtr &node,
|
||||
const FuncGraphManagerPtr &mng);
|
||||
MemorySize SelectThreshold(EdgeLifeTimeType type);
|
||||
|
||||
std::map<AnfNodePtr, MemorySize> topo_indice_;
|
||||
std::vector<Candidate> candidates_;
|
||||
MemorySize lifetime_threshold_{0};
|
||||
MemorySize local_peak_threshold_{0};
|
||||
|
||||
void RecomputeCandidatesLog(const std::vector<Candidate> &candidates);
|
||||
};
|
||||
|
||||
class GraphKernelRecompute : public opt::Pass {
|
||||
public:
|
||||
GraphKernelRecompute() : Pass("graph_kernel_recompute") {}
|
||||
~GraphKernelRecompute() override = default;
|
||||
bool Run(const FuncGraphPtr &func_graph) override;
|
||||
|
||||
private:
|
||||
void Process(const Candidate &candidate);
|
||||
std::pair<FuncGraphPtr, AnfNodePtrList> CloneGraph(const CNodePtr &source_graph,
|
||||
const AnfNodePtrList &recompute_edge);
|
||||
void LinkIntoTargetFuncGraph(const Candidate &candidate, const FuncGraphPtr &cloned_func,
|
||||
const AnfNodePtrList &cloned_inputs);
|
||||
|
||||
std::vector<Candidate> candidates_;
|
||||
};
|
||||
} // namespace mindspore::graphkernel
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_RECOMPUTE_H_
|
|
@ -0,0 +1,90 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/optimizer/graph_kernel/reduce_fake_out_mem.h"
|
||||
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include "backend/kernel_compiler/common_utils.h"
|
||||
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
|
||||
|
||||
namespace mindspore::graphkernel {
|
||||
constexpr auto kFakeOut = "fake_output";
|
||||
void ReduceFakeOutMem::ModifyAbstract(const AnfNodePtr &composite_node, const std::set<size_t> &fake_real_indices,
|
||||
const AnfNodePtrList &output_list) {
|
||||
if (fake_real_indices.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (output_list.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Output size should not be zero while there is at least one fake output!";
|
||||
}
|
||||
|
||||
std::vector<AbstractBasePtr> out_specs;
|
||||
for (size_t i = 0; i < output_list.size(); ++i) {
|
||||
if (fake_real_indices.count(i)) {
|
||||
std::vector<int64_t> shape_vec_shape = {1};
|
||||
AbstractBasePtr abstract = std::make_shared<abstract::AbstractTensor>(kInt64, shape_vec_shape);
|
||||
out_specs.push_back(abstract);
|
||||
} else {
|
||||
out_specs.push_back(output_list[i]->abstract());
|
||||
}
|
||||
}
|
||||
AbstractBasePtr out_spec;
|
||||
if (output_list.size() > 1) {
|
||||
out_spec = std::make_shared<abstract::AbstractTuple>(out_specs);
|
||||
} else {
|
||||
out_spec = output_list[0]->abstract();
|
||||
}
|
||||
composite_node->set_abstract(out_spec);
|
||||
}
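// Illustration of the effect of ModifyAbstract above (hypothetical shapes): a fake output whose
// original abstract is a float32 tensor of shape [1024, 1024] gets an int64 abstract of shape [1]
// instead, so only a single-element buffer needs to be allocated for it.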
|
||||
|
||||
bool ReduceFakeOutMem::Run(const FuncGraphPtr &func_graph) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
bool changed = false;
|
||||
|
||||
auto todos = TopoSort(func_graph->get_return());
|
||||
for (auto node : todos) {
|
||||
if (!AnfAlgo::IsGraphKernel(node)) {
|
||||
continue;
|
||||
}
|
||||
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
||||
MS_EXCEPTION_IF_NULL(sub_graph);
|
||||
|
||||
AnfNodePtrList output_list;
|
||||
kernel::GetFuncGraphOutputNodes(sub_graph, &output_list);
|
||||
std::set<size_t> fake_real_indices;
|
||||
for (size_t i = 0; i < output_list.size(); ++i) {
|
||||
auto &out = output_list[i];
|
||||
auto out_cnode = out->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(out_cnode);
|
||||
if (AnfAlgo::HasNodeAttr(kFakeOut, out_cnode) && AnfAlgo::GetNodeAttr<bool>(out_cnode, kFakeOut)) {
|
||||
fake_real_indices.insert(i);
|
||||
}
|
||||
}
|
||||
|
||||
if (fake_real_indices.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
ModifyAbstract(node, fake_real_indices, output_list);
|
||||
changed = true;
|
||||
}
|
||||
|
||||
return changed;
|
||||
}
|
||||
} // namespace mindspore::graphkernel
|
|
@ -0,0 +1,42 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_REDUCE_FAKE_OUT_MEM_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_REDUCE_FAKE_OUT_MEM_H_
|
||||
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include "utils/context/graph_kernel_flags.h"
|
||||
#include "backend/optimizer/common/pass.h"
|
||||
#include "backend/optimizer/graph_kernel/add_atomic_clean.h"
|
||||
#include "ir/func_graph.h"
|
||||
|
||||
namespace mindspore::graphkernel {
|
||||
/**
|
||||
* @brief Reduce a fake output's memory from its original size to 1.
|
||||
*/
|
||||
class ReduceFakeOutMem : public opt::Pass {
|
||||
public:
|
||||
ReduceFakeOutMem() : Pass("reduce_fake_output_memory") {}
|
||||
~ReduceFakeOutMem() override = default;
|
||||
bool Run(const FuncGraphPtr &func_graph) override;
|
||||
|
||||
private:
|
||||
void ModifyAbstract(const AnfNodePtr &composite_node, const std::set<size_t> &fake_real_indices,
|
||||
const AnfNodePtrList &output_list);
|
||||
};
|
||||
using ReduceFakeOutMemPtr = std::shared_ptr<ReduceFakeOutMem>;
|
||||
} // namespace mindspore::graphkernel
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_REDUCE_FAKE_OUT_MEM_H_
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
#include "backend/optimizer/graph_kernel/tsa_atomic_add_to_first_tensor.h"
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "base/core_ops.h"
|
||||
|
@ -41,16 +42,20 @@ class TsaChecker : public AtomicAddChecker {
|
|||
return false;
|
||||
}
|
||||
|
||||
auto tsa_cnode = atomic_add_info_.atomic_add_node;
|
||||
if (!utils::isa<ParameterPtr>(tsa_cnode->input(1))) {
|
||||
return false;
|
||||
for (auto atomic_add_info : atomic_add_infos_) {
|
||||
auto tsa_cnode = atomic_add_info.atomic_add_node;
|
||||
if (!utils::isa<ParameterPtr>(tsa_cnode->input(1))) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
AnfNodePtr TsaAtomicAddToFirstTensor::FindTsaFirstRealInputInGraph(const KernelGraphPtr &, const AnfNodePtr &node) {
|
||||
std::pair<AnfNodePtr, size_t> TsaAtomicAddToFirstTensor::FindTsaFirstRealInputInGraph(const KernelGraphPtr &main_graph,
|
||||
const CNodePtr &tsa_node,
|
||||
const AnfNodePtr &node) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
|
||||
|
@ -60,13 +65,14 @@ AnfNodePtr TsaAtomicAddToFirstTensor::FindTsaFirstRealInputInGraph(const KernelG
|
|||
sub_graph->set_manager(mng_sub);
|
||||
}
|
||||
|
||||
auto first_input = atomic_add_node_->input(1)->cast<ParameterPtr>();
|
||||
auto first_input = tsa_node->input(1)->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(first_input);
|
||||
auto parameters = sub_graph->parameters();
|
||||
bool hit = false;
|
||||
size_t tsa_first_input_index = 0;
|
||||
for (size_t i = 0; i < parameters.size(); ++i) {
|
||||
if (parameters[i] == first_input) {
|
||||
tsa_first_input_index_ = i;
|
||||
tsa_first_input_index = i;
|
||||
hit = true;
|
||||
break;
|
||||
}
|
||||
|
@ -75,25 +81,31 @@ AnfNodePtr TsaAtomicAddToFirstTensor::FindTsaFirstRealInputInGraph(const KernelG
|
|||
MS_LOG(EXCEPTION) << "Cannot find tensor scatter add first input in sub-graph parameters!";
|
||||
}
|
||||
|
||||
return cnode->input(tsa_first_input_index_ + 1); // CNode input have a primitive, so add 1.
|
||||
return {cnode->input(tsa_first_input_index + 1), tsa_first_input_index};  // CNode inputs have a primitive at index 0, so add 1.
|
||||
}
|
||||
|
||||
AnfNodePtr TsaAtomicAddToFirstTensor::ProcessTsaFirstNode(const KernelGraphPtr &main_graph, const AnfNodePtr &node) {
|
||||
std::pair<AnfNodePtr, size_t> TsaAtomicAddToFirstTensor::GetOrCreateNewTsaFirstNode(
|
||||
const KernelGraphPtr &main_graph, const AtomicAddInfo &atomic_add_info, const AnfNodePtr &node) {
|
||||
auto mng = main_graph->manager();
|
||||
if (mng == nullptr) {
|
||||
mng = Manage(main_graph, true);
|
||||
main_graph->set_manager(mng);
|
||||
}
|
||||
// find first input of tsa
|
||||
auto tsa_first_input = FindTsaFirstRealInputInGraph(main_graph, node);
|
||||
auto users = mng->node_users()[tsa_first_input];
|
||||
if (users.size() == 1 && !(utils::isa<ValueNodePtr>(tsa_first_input) || utils::isa<ParameterPtr>(tsa_first_input))) {
|
||||
|
||||
// Find first input of tsa
|
||||
auto tsa_first_input = FindTsaFirstRealInputInGraph(main_graph, atomic_add_info.atomic_add_node, node);
|
||||
auto users = mng->node_users()[tsa_first_input.first];
|
||||
if (users.size() == 1 &&
|
||||
!(utils::isa<ValueNodePtr>(tsa_first_input.first) || utils::isa<ParameterPtr>(tsa_first_input.first))) {
|
||||
// If the current composite node is the only user, and the first input is not a Parameter or a Tensor value, then use it directly.
|
||||
return tsa_first_input;
|
||||
}
|
||||
|
||||
// Create a copy of the first input, so that the atomic add can write to the copy.
|
||||
// Create composite op's sub-graph.
|
||||
auto new_sub_graph = std::make_shared<FuncGraph>();
|
||||
auto parameter = new_sub_graph->add_parameter();
|
||||
auto kernel_with_index = AnfAlgo::VisitKernel(tsa_first_input, 0);
|
||||
auto kernel_with_index = AnfAlgo::VisitKernel(tsa_first_input.first, 0);
|
||||
parameter->set_abstract(GetOutputAbstract(kernel_with_index.first, kernel_with_index.second));
|
||||
parameter->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||
std::string parameter_format;
|
||||
|
@ -123,18 +135,18 @@ AnfNodePtr TsaAtomicAddToFirstTensor::ProcessTsaFirstNode(const KernelGraphPtr &
|
|||
|
||||
// Makeup sub-graph.
|
||||
new_sub_graph->set_output(identity_node);
|
||||
auto new_composite_node = main_graph->NewCNode({NewValueNode(new_sub_graph), tsa_first_input});
|
||||
new_composite_node->set_abstract(identity_node->abstract());
|
||||
SetNewKernelInfo(new_composite_node, new_sub_graph, {tsa_first_input}, {identity_node});
|
||||
auto new_copy_composite_node = main_graph->NewCNode({NewValueNode(new_sub_graph), tsa_first_input.first});
|
||||
new_copy_composite_node->set_abstract(identity_node->abstract());
|
||||
SetNewKernelInfo(new_copy_composite_node, new_sub_graph, {tsa_first_input.first}, {identity_node});
|
||||
auto graph_attr = GkUtils::ExtractGraphKernelName(TopoSort(new_sub_graph->get_return()), "", "tsa_identity");
|
||||
new_sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(graph_attr));
|
||||
new_sub_graph->set_attr("composite_type", MakeValue("tsa_identity"));
|
||||
|
||||
return new_composite_node;
|
||||
return {new_copy_composite_node, tsa_first_input.second};
|
||||
}
|
||||
|
||||
void TsaAtomicAddToFirstTensor::CorrectKernelBuildInfo(const AnfNodePtr &composite_node,
|
||||
const AnfNodePtr &modified_input, bool) {
|
||||
void TsaAtomicAddToFirstTensor::CorrectKernelBuildInfo(
|
||||
const AnfNodePtr &composite_node, const std::vector<std::tuple<AtomicAddInfo, AnfNodePtr, size_t>> &outer_infos) {
|
||||
// Change kernel build info with modify input
|
||||
auto kernel_info = static_cast<device::KernelInfo *>(composite_node->kernel_info());
|
||||
MS_EXCEPTION_IF_NULL(kernel_info);
|
||||
|
@ -149,19 +161,29 @@ void TsaAtomicAddToFirstTensor::CorrectKernelBuildInfo(const AnfNodePtr &composi
|
|||
std::vector<TypeId> &modified_inputs_type = origin_inputs_type;
|
||||
std::vector<std::string> new_outputs_format;
|
||||
std::vector<TypeId> new_outputs_type;
|
||||
|
||||
std::set<size_t> reduce_real_indices;
|
||||
for (auto &info : outer_infos) {
|
||||
reduce_real_indices.insert(std::get<0>(info).reduce_real_output_index);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < origin_outputs_format.size(); ++i) {
|
||||
if (real_output_num_ > 1 && i == reduce_real_output_index_) {
|
||||
if (std::get<0>(outer_infos[0]).real_output_num > 1 && reduce_real_indices.count(i) > 0) {
|
||||
continue;
|
||||
}
|
||||
new_outputs_format.push_back(origin_outputs_format[i]);
|
||||
new_outputs_type.push_back(origin_outputs_type[i]);
|
||||
}
|
||||
|
||||
auto kernel_with_index = AnfAlgo::VisitKernel(modified_input, 0);
|
||||
modified_inputs_format[tsa_first_input_index_] =
|
||||
AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
|
||||
modified_inputs_type[tsa_first_input_index_] =
|
||||
AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
|
||||
for (const auto &outer_info : outer_infos) {
|
||||
auto &modified_input = std::get<1>(outer_info);
|
||||
auto tsa_first_input_index = std::get<2>(outer_info);
|
||||
auto kernel_with_index = AnfAlgo::VisitKernel(modified_input, 0);
|
||||
modified_inputs_format[tsa_first_input_index] =
|
||||
AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
|
||||
modified_inputs_type[tsa_first_input_index] =
|
||||
AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
|
||||
}
|
||||
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder new_info_builder;
|
||||
new_info_builder.SetInputsFormat(modified_inputs_format);
|
||||
|
@ -175,7 +197,8 @@ void TsaAtomicAddToFirstTensor::CorrectKernelBuildInfo(const AnfNodePtr &composi
|
|||
AnfAlgo::SetSelectKernelBuildInfo(new_selected_info, composite_node.get());
|
||||
}
|
||||
|
||||
void TsaAtomicAddToFirstTensor::ProcessOriginCNode(const AnfNodePtr &composite_node, const AnfNodePtr &outter_node) {
|
||||
void TsaAtomicAddToFirstTensor::ProcessOriginCNode(
|
||||
const AnfNodePtr &composite_node, const std::vector<std::tuple<AtomicAddInfo, AnfNodePtr, size_t>> &outer_nodes) {
|
||||
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(composite_node);
|
||||
auto mng_sub = sub_graph->manager();
|
||||
if (mng_sub == nullptr) {
|
||||
|
@ -183,12 +206,20 @@ void TsaAtomicAddToFirstTensor::ProcessOriginCNode(const AnfNodePtr &composite_n
|
|||
sub_graph->set_manager(mng_sub);
|
||||
}
|
||||
|
||||
// modify input
|
||||
composite_node->cast<CNodePtr>()->set_input(tsa_first_input_index_ + 1, outter_node);
|
||||
CreateInplaceAssignNodeAndCorrectReturn(sub_graph, sub_graph->parameters()[tsa_first_input_index_]);
|
||||
// Modify input
|
||||
std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> parameters_infos;
|
||||
std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> info_and_tsa_outers;
|
||||
for (const auto &[atomic_add_info, outer_node, tsa_first_input_index] : outer_nodes) {
|
||||
composite_node->cast<CNodePtr>()->set_input(tsa_first_input_index + 1, outer_node);
|
||||
auto parameter = sub_graph->parameters()[tsa_first_input_index];
|
||||
parameters_infos.emplace_back(atomic_add_info, parameter);
|
||||
info_and_tsa_outers.emplace_back(atomic_add_info, outer_node);
|
||||
}
|
||||
|
||||
CorrectAbstract(composite_node);
|
||||
CorrectKernelBuildInfo(composite_node, outter_node);
|
||||
CreateInplaceAssignNodeAndCorrectReturn(sub_graph, parameters_infos);
|
||||
|
||||
CorrectAbstract(composite_node, info_and_tsa_outers);
|
||||
CorrectKernelBuildInfo(composite_node, outer_nodes);
|
||||
|
||||
auto old_graph_name = GetValue<std::string>(sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
|
||||
auto new_graph_name =
|
||||
|
@ -198,24 +229,34 @@ void TsaAtomicAddToFirstTensor::ProcessOriginCNode(const AnfNodePtr &composite_n
|
|||
}
|
||||
|
||||
void TsaAtomicAddToFirstTensor::ProcessTsa(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node,
|
||||
const std::vector<AtomicAddInfo> &atomic_add_infos,
|
||||
const FuncGraphManagerPtr &mng) {
|
||||
auto origin_composite_node = anf_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(origin_composite_node);
|
||||
|
||||
// Create identity node.
|
||||
auto outter_node = ProcessTsaFirstNode(main_graph, anf_node);
|
||||
// Create identity node. // Create broadcast node.
|
||||
std::vector<std::tuple<AtomicAddInfo, AnfNodePtr, size_t>> info_and_outer_nodes_with_index;
|
||||
std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> info_and_outer_nodes;
|
||||
for (auto atomic_add_info : atomic_add_infos) {
|
||||
auto outer = GetOrCreateNewTsaFirstNode(main_graph, atomic_add_info, anf_node);
|
||||
info_and_outer_nodes_with_index.emplace_back(atomic_add_info, outer.first, outer.second);
|
||||
info_and_outer_nodes.emplace_back(atomic_add_info, outer.first);
|
||||
}
|
||||
|
||||
// Insert an extra input (broadcast node output) to the composite node, and make the origin TensorScatterAdd inplace-assign to it.
|
||||
// Note: if it's single output, this will increase total memory because of a fake out.
|
||||
ProcessOriginCNode(origin_composite_node, outter_node);
|
||||
ProcessOriginCNode(origin_composite_node, info_and_outer_nodes_with_index);
|
||||
|
||||
// Insert update_state_node to keep execution order.
|
||||
auto update_state_node = InsertUpdateState(main_graph, origin_composite_node);
|
||||
|
||||
// Replace origin ReduceSum's user with atomic clean output
|
||||
ProcessOriginCNodeUser(main_graph, origin_composite_node, outter_node, update_state_node, mng);
|
||||
MS_LOG(INFO) << "Target node: " << origin_composite_node->fullname_with_scope()
|
||||
<< ", outer node: " << outter_node->fullname_with_scope();
|
||||
ProcessOriginCNodeUser(main_graph, origin_composite_node, info_and_outer_nodes, update_state_node, mng);
|
||||
std::stringstream ss;
|
||||
ss << "Target node: " << origin_composite_node->fullname_with_scope() << ", outer nodes: ";
|
||||
for (auto iter : info_and_outer_nodes) {
|
||||
ss << iter.second->fullname_with_scope() << ", ";
|
||||
}
|
||||
}
|
||||
|
||||
bool TsaAtomicAddToFirstTensor::Run(const FuncGraphPtr &func_graph) {
|
||||
|
@ -239,11 +280,8 @@ bool TsaAtomicAddToFirstTensor::Run(const FuncGraphPtr &func_graph) {
|
|||
if (!atomic_add_checker->Check(node)) {
|
||||
continue;
|
||||
}
|
||||
auto atomic_add_info = atomic_add_checker->GetAtomicAddInfo();
|
||||
atomic_add_node_ = atomic_add_info.atomic_add_node;
|
||||
reduce_real_output_index_ = atomic_add_info.reduce_real_output_index;
|
||||
real_output_num_ = atomic_add_info.real_output_num;
|
||||
ProcessTsa(kernel_graph, node, mng);
|
||||
auto atomic_add_infos = atomic_add_checker->GetAtomicAddInfo();
|
||||
ProcessTsa(kernel_graph, node, atomic_add_infos, mng);
|
||||
changed = true;
|
||||
}
|
||||
|
||||
|
|
|
@ -48,12 +48,17 @@ class TsaAtomicAddToFirstTensor : public AtomicCleanInsertter {
|
|||
bool Run(const FuncGraphPtr &func_graph) override;
|
||||
|
||||
private:
|
||||
void ProcessOriginCNode(const AnfNodePtr &composite_node, const AnfNodePtr &new_input) override;
|
||||
void CorrectKernelBuildInfo(const AnfNodePtr &composite_node, const AnfNodePtr &new_input,
|
||||
bool bypass = true) override;
|
||||
void ProcessTsa(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node, const FuncGraphManagerPtr &mng);
|
||||
AnfNodePtr ProcessTsaFirstNode(const KernelGraphPtr &main_graph, const AnfNodePtr &node);
|
||||
AnfNodePtr FindTsaFirstRealInputInGraph(const KernelGraphPtr &main_graph, const AnfNodePtr &node);
|
||||
void ProcessOriginCNode(const AnfNodePtr &composite_node,
|
||||
const std::vector<std::tuple<AtomicAddInfo, AnfNodePtr, size_t>> &outer_nodes);
|
||||
void CorrectKernelBuildInfo(const AnfNodePtr &composite_node,
|
||||
const std::vector<std::tuple<AtomicAddInfo, AnfNodePtr, size_t>> &outer_infos);
|
||||
void ProcessTsa(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node,
|
||||
const std::vector<AtomicAddInfo> &atomic_add_infos, const FuncGraphManagerPtr &mng);
|
||||
std::pair<AnfNodePtr, size_t> GetOrCreateNewTsaFirstNode(const KernelGraphPtr &main_graph,
|
||||
const AtomicAddInfo &atomic_add_info,
|
||||
const AnfNodePtr &node);
|
||||
std::pair<AnfNodePtr, size_t> FindTsaFirstRealInputInGraph(const KernelGraphPtr &main_graph, const CNodePtr &tsa_node,
|
||||
const AnfNodePtr &node);
|
||||
|
||||
size_t tsa_first_input_index_{0}; // sub-graph parameter index.
|
||||
};
|
||||
|
|
|
@ -133,6 +133,7 @@ bool ExtendOutputForUpdateState::Run(const FuncGraphPtr &func_graph) {
|
|||
}
|
||||
}
|
||||
if (changed) {
|
||||
UpdateMng(mng, func_graph);
|
||||
std::make_shared<SpreadUpdateState>()->Run(func_graph);
|
||||
std::make_shared<EliminateHangingOutput>()->Run(func_graph);
|
||||
}
|
||||
|
|
|
@ -55,9 +55,9 @@ bool UssAtomicAdd::Run(const FuncGraphPtr &func_graph) {
|
|||
if (!atomic_add_checker->Check(node)) {
|
||||
continue;
|
||||
}
|
||||
auto info = atomic_add_checker->GetAtomicAddInfo();
|
||||
UpdateAtomicAddInfo(info);
|
||||
InsertAtomicClean(kernel_graph, node, mng);
|
||||
|
||||
auto atomic_add_infos = atomic_add_checker->GetAtomicAddInfo();
|
||||
InsertAtomicClean(kernel_graph, node, atomic_add_infos, mng);
|
||||
changed = true;
|
||||
}
|
||||
|
||||
|
|
|
@ -230,6 +230,8 @@ void GraphKernelFlags::RegisterFlags(std::map<std::string, std::string> *flag_ma
|
|||
reg.AddFlag("online_tuning", &online_tuning);
|
||||
reg.AddFlag("fusion_ops_level", &fusion_ops_level, is_ascend ? OpLevel_0 : OpLevel_MAX);
|
||||
reg.AddFlag("parallel_ops_level", ¶llel_ops_level);
|
||||
reg.AddFlag("recompute_increment_threshold", &recompute_increment_threshold);
|
||||
reg.AddFlag("recompute_peak_threshold", &recompute_peak_threshold);
|
||||
|
||||
// String flags
|
||||
reg.AddFlag("repository_path", &repository_path);
|
||||
|
@ -261,6 +263,8 @@ std::string GraphKernelFlags::DumpAllFlags() const {
|
|||
json["fusion_ops_level"] = fusion_ops_level;
|
||||
json["parallel_ops_level"] = parallel_ops_level;
|
||||
json["online_tuning"] = online_tuning;
|
||||
json["recompute_increment_threshold"] = recompute_increment_threshold;
|
||||
json["recompute_peak_threshold"] = recompute_peak_threshold;
|
||||
|
||||
json["repository_path"] = repository_path;
|
||||
|
||||
|
|
|
@ -127,6 +127,16 @@ class GraphKernelFlags {
|
|||
*/
|
||||
unsigned int online_tuning{0};
|
||||
|
||||
/**
|
||||
* Threshold for detection of recompute's memory increment case.
|
||||
*/
|
||||
int64_t recompute_increment_threshold{0};
|
||||
|
||||
/**
|
||||
* Threshold for detection of recompute's memory peak case.
|
||||
*/
|
||||
int64_t recompute_peak_threshold{0};
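// Note: when both recompute thresholds keep the default value 0, the recompute pass finds no
// candidates (FindCandidates returns early). They are compared against tensor sizes taken from
// AnfAlgo::GetOutputTensorMemSize, so the values are presumably given in bytes.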
|
||||
|
||||
/**
|
||||
* AKG's operator repository file path.
|
||||
*/
|
||||
|
|