refactor atomic add

Rename the misspelled *Insertter passes to *Inserter, hoist the shared zero-buffer machinery into a new CleanInserter base class (common/graph_kernel/clean_inserter.{h,cc}), generalize AtomicAddInfo into CleanZeroUserInfo and KernelGraphPtr parameters into FuncGraphPtr, and replace the InplaceAssign fake-output pattern with a plain Assign node.

Yang Jiao 2022-04-16 20:12:30 +08:00
parent 2174f71616
commit 89e6d40524
10 changed files with 359 additions and 292 deletions

View File

@@ -166,19 +166,12 @@ PassManagerPtr GraphKernelOptimizer::HighLevelOpt2() const {
auto recompute_lv = GetPassLevelByFlag(flags.recompute_increment_threshold > 0 || flags.recompute_peak_threshold > 0);
pm->Add(std::make_shared<GraphKernelRecompute>(), recompute_lv);
// Replace Assign with InplaceAssign, and replace original output with overridden parameters
pm->Add(std::make_shared<OptimizeAssign>(), OptLevel_2);
pm->Add(std::make_shared<ExtendOutputForUpdateState>(), std::min(recompute_lv, OptLevel_2));
pm->Add(std::make_shared<MergeOutputForUpdateState>(), std::min(recompute_lv, OptLevel_2));
pm->Add(std::make_shared<EliminateRedundantOutput>(), std::min(recompute_lv, OptLevel_2));
// Enable atomic add
pm->Add(std::make_shared<AtomicCleanInsertter>(), OptLevel_2, is_gpu || is_ascend);
pm->Add(std::make_shared<AtomicCleanInserter>(), OptLevel_2, is_gpu || is_ascend);
// Enable atomic add for stitch nodes.
auto level = GetPassLevelByFlag(GraphKernelFlags::GetInstance().enable_stitch_fusion);
pm->Add(std::make_shared<StitchAtomicCleanInsertter>(), level, is_gpu);
pm->Add(std::make_shared<StitchAtomicCleanInserter>(), level, is_gpu);
// Enable low precision
auto level_low_precision = GetPassLevelByFlag(GraphKernelFlags::GetInstance().enable_low_precision);
@@ -189,6 +182,12 @@ PassManagerPtr GraphKernelOptimizer::HighLevelOpt2() const {
pm->Add(std::make_shared<TsaAtomicAddToFirstTensor>(), OptLevel_1, is_gpu);
pm->Add(std::make_shared<UssAtomicAdd>(), OptLevel_1, is_gpu);
// Replace Assign with InplaceAssign, and replace original output with overridden parameters
pm->Add(std::make_shared<OptimizeAssign>(), OptLevel_2);
pm->Add(std::make_shared<ExtendOutputForUpdateState>(), std::min(recompute_lv, OptLevel_2));
pm->Add(std::make_shared<MergeOutputForUpdateState>(), std::min(recompute_lv, OptLevel_2));
pm->Add(std::make_shared<EliminateRedundantOutput>(), std::min(recompute_lv, OptLevel_2));
return pm;
}
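For orientation, the net effect of the two hunks above is that the Assign/UpdateState cleanup passes now run after all atomic-add passes. A condensed sketch of the resulting registration order in HighLevelOpt2 (reconstructed from this diff, not a verbatim excerpt):

// Sketch: pass order in HighLevelOpt2() after this commit.
pm->Add(std::make_shared<AtomicCleanInserter>(), OptLevel_2, is_gpu || is_ascend);
pm->Add(std::make_shared<StitchAtomicCleanInserter>(), level, is_gpu);
// ... low-precision passes elided ...
pm->Add(std::make_shared<TsaAtomicAddToFirstTensor>(), OptLevel_1, is_gpu);
pm->Add(std::make_shared<UssAtomicAdd>(), OptLevel_1, is_gpu);
// Moved below the atomic passes by this commit:
pm->Add(std::make_shared<OptimizeAssign>(), OptLevel_2);
pm->Add(std::make_shared<ExtendOutputForUpdateState>(), std::min(recompute_lv, OptLevel_2));
pm->Add(std::make_shared<MergeOutputForUpdateState>(), std::min(recompute_lv, OptLevel_2));
pm->Add(std::make_shared<EliminateRedundantOutput>(), std::min(recompute_lv, OptLevel_2));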

View File

@@ -15,6 +15,7 @@
*/
#include "common/graph_kernel/add_atomic_clean.h"
#include <algorithm>
#include <functional>
#include <map>
@@ -37,8 +38,6 @@
namespace mindspore::graphkernel {
namespace {
constexpr auto kAttrFakeOutput = "fake_output";
std::set<int64_t> GetUniqReduceAxes(const AnfNodePtr &node, bool is_ascend = false) {
if (!IsPrimitiveCNode(node, prim::kPrimReduceSum)) {
MS_LOG(EXCEPTION) << "Expect ReduceSum node, but got " << common::AnfAlgo::GetCNodeName(node);
@@ -103,21 +102,6 @@ size_t GetItemIdx(const AnfNodePtr &node) {
auto item_idx = LongToSize(GetValue<int64_t>(value_node->value()));
return item_idx;
}
CNodePtr CreateInplaceAssign(const FuncGraphPtr &sub_graph,
const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &parameters_infos, size_t idx) {
if (idx >= parameters_infos.size()) {
MS_LOG(EXCEPTION) << "idx " << idx << " is out of range [0, " << parameters_infos.size() << ")";
}
MS_EXCEPTION_IF_NULL(sub_graph);
const auto &atomic_add_node = parameters_infos[idx].first.atomic_add_node;
const auto &new_parameter = parameters_infos[idx].second;
auto node = CreateCNode(
{NewValueNode(prim::kPrimInplaceAssign), new_parameter, atomic_add_node, atomic_add_node}, sub_graph,
{.format = GetFormat(atomic_add_node), .shape = GetShape(atomic_add_node), .type = GetType(atomic_add_node)});
SetNodeAttrSafely(kAttrFakeOutput, MakeValue(true), node);
return node;
}
} // namespace
std::shared_ptr<AtomicAddChecker> AtomicAddChecker::Init() {
@@ -141,19 +125,19 @@ bool AtomicAddChecker::FindCandidate(const AnfNodePtr &anf_node) {
sub_graph->set_manager(mng_sub);
}
auto CheckSuitableTarget = [&mng_sub](const AtomicAddInfo &atomic_add_info) {
auto CheckSuitableTarget = [&mng_sub](const CleanZeroUserInfo &atomic_add_info) {
// Target type should not fuse any other ops in out direction, which means it should be in output list.
return mng_sub->node_users()[atomic_add_info.atomic_add_node].size() <= 1;
return mng_sub->node_users()[atomic_add_info.op_node].size() <= 1;
};
auto real_return_node = sub_graph->get_return()->input(kFirstDataInputIndex);
AtomicAddInfo atomic_add_info;
CleanZeroUserInfo atomic_add_info;
if (IsPrimitiveCNode(real_return_node, prim::kPrimMakeTuple)) {
const auto &inputs = real_return_node->cast<CNodePtr>()->inputs();
for (size_t i = 1; i < inputs.size(); ++i) {
if (IsPrimitiveCNode(inputs[i], target_type_)) {
atomic_add_info.atomic_add_node = inputs[i]->cast<CNodePtr>();
atomic_add_info.reduce_real_output_index = i - 1;
atomic_add_info.op_node = inputs[i]->cast<CNodePtr>();
atomic_add_info.real_output_index = i - 1;
atomic_add_info.real_output_num = inputs.size() - 1;
// Target type should not fuse any other ops in out direction, which means it should be in output list.
if (!CheckSuitableTarget(atomic_add_info)) {
@@ -163,7 +147,7 @@ bool AtomicAddChecker::FindCandidate(const AnfNodePtr &anf_node) {
}
}
} else if (IsPrimitiveCNode(real_return_node, target_type_)) {
atomic_add_info.atomic_add_node = real_return_node->cast<CNodePtr>();
atomic_add_info.op_node = real_return_node->cast<CNodePtr>();
atomic_add_info.real_output_num = 1;
if (CheckSuitableTarget(atomic_add_info)) {
atomic_add_infos_.push_back(atomic_add_info);
@@ -191,12 +175,12 @@ bool AtomicAddChecker::CanActivateAtomicAdd(const AnfNodePtr &anf_node) {
}
// Rule 2.
if (!SuitableForAtomicAdd(atomic_add_infos_[0].atomic_add_node)) {
if (!SuitableForAtomicAdd(atomic_add_infos_[0].op_node)) {
return false;
}
// Rule 3.
return !HaveReduceInPredecessors(atomic_add_infos_[0].atomic_add_node);
return !HaveReduceInPredecessors(atomic_add_infos_[0].op_node);
}
bool AtomicAddChecker::Check(const AnfNodePtr &node) {
@@ -264,167 +248,15 @@ bool AtomicAddCheckerAscend::SuitableForAtomicAdd(const AnfNodePtr &node) {
return false;
}
void AtomicCleanInsertter::CorrectKernelBuildInfo(
const AnfNodePtr &composite_node, const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &clean_infos) {
// Change kernel build info.
auto kernel_info = dynamic_cast<device::KernelInfo *>(composite_node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
const auto &origin_kernel_build_info = kernel_info->GetMutableSelectKernelBuildInfo();
MS_EXCEPTION_IF_NULL(origin_kernel_build_info);
auto origin_inputs_format = origin_kernel_build_info->GetAllInputFormats();
auto origin_inputs_type = origin_kernel_build_info->GetAllInputDeviceTypes();
std::vector<std::string> &new_inputs_format = origin_inputs_format;
std::vector<TypeId> &new_inputs_type = origin_inputs_type;
for (const auto &clean_info : clean_infos) {
auto &new_input = clean_info.second;
auto kernel_with_index = common::AnfAlgo::VisitKernel(new_input, 0);
new_inputs_format.push_back(AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second));
new_inputs_type.push_back(AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second));
}
auto new_selected_info = BuildSelectKernelBuildInfo(
new_inputs_format, new_inputs_type, origin_kernel_build_info->GetAllOutputFormats(),
origin_kernel_build_info->GetAllOutputDeviceTypes(), origin_kernel_build_info->processor());
AnfAlgo::SetSelectKernelBuildInfo(new_selected_info, composite_node.get());
}
void AtomicCleanInsertter::CreateInplaceAssignNodeAndCorrectReturn(
const FuncGraphPtr &sub_graph, const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &parameters_infos) {
std::map<size_t, size_t> reduce_indices;
for (size_t i = 0; i < parameters_infos.size(); ++i) {
reduce_indices[parameters_infos[i].first.reduce_real_output_index + 1] = i;
}
// Change atomic add output to InplaceAssign node.
auto output = sub_graph->output();
if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
auto output_cnode = output->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(output_cnode);
for (size_t i = 1; i < output_cnode->inputs().size(); ++i) {
auto iter = reduce_indices.find(i);
if (iter == reduce_indices.end()) continue;
auto inplace = CreateInplaceAssign(sub_graph, parameters_infos, iter->second);
output_cnode->set_input(i, inplace);
}
} else if (parameters_infos.size() == 1) {
auto inplace = CreateInplaceAssign(sub_graph, parameters_infos, 0);
sub_graph->set_output(inplace);
}
}
void AtomicCleanInsertter::ProcessOriginCNode(
const AnfNodePtr &composite_node,
const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &info_and_broadcast_to_nodes) {
auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(composite_node);
auto mng_sub = sub_graph->manager();
if (mng_sub == nullptr) {
mng_sub = Manage(sub_graph, false);
sub_graph->set_manager(mng_sub);
}
// Add input
std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> parameters_infos;
for (const auto &[atomic_add_info, new_input] : info_and_broadcast_to_nodes) {
// Add atomic attribute to ReduceSum node.
SetNodeAttrSafely("enable_atomic_add", MakeValue(true), atomic_add_info.atomic_add_node);
// add parameter
auto parameter = sub_graph->add_parameter();
parameter->set_abstract(new_input->abstract());
parameter->set_kernel_info(new_input->kernel_info_ptr());
(void)parameters_infos.emplace_back(atomic_add_info, parameter);
}
auto inputs = composite_node->cast<CNodePtr>()->inputs();
(void)std::transform(info_and_broadcast_to_nodes.cbegin(), info_and_broadcast_to_nodes.cend(),
std::back_inserter(inputs),
[](const std::pair<AtomicAddInfo, AnfNodePtr> &pair_item) { return pair_item.second; });
composite_node->cast<CNodePtr>()->set_inputs(inputs);
CreateInplaceAssignNodeAndCorrectReturn(sub_graph, parameters_infos);
CorrectKernelBuildInfo(composite_node, info_and_broadcast_to_nodes);
auto old_graph_name = GetValue<std::string>(sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
auto new_graph_name = GkUtils::ExtractGraphKernelName(TopoSort(sub_graph->get_return()), "", "atomic_add");
sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(new_graph_name));
MS_LOG(INFO) << "Convert " << old_graph_name << " to atomic add graph " << new_graph_name;
}
CNodePtr AtomicCleanInsertter::InsertUpdateState(const KernelGraphPtr &main_graph, const AnfNodePtr &node) const {
// Insert update_state_node, which needs a monad node mounted.
auto u = NewValueNode(kUMonad);
u->set_abstract(kUMonad->ToAbstract());
AnfNodePtrList update_state_inputs = {NewValueNode(prim::kPrimUpdateState), u, node};
auto update_state_cnode = main_graph->NewCNode(update_state_inputs);
update_state_cnode->set_abstract(kUMonad->ToAbstract());
main_graph->AddNode(update_state_cnode);
return update_state_cnode;
}
CNodePtr AtomicCleanInsertter::CreateAtomicCleanCompositeNode(const AtomicAddInfo &atomic_add_info,
const KernelGraphPtr &main_graph, TypeId dst_type) {
std::set<TypeId> data_support = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64};
if (!std::any_of(data_support.cbegin(), data_support.cend(), [&dst_type](TypeId type) { return dst_type == type; })) {
MS_LOG(EXCEPTION) << "For AtomicAdd, the data type: " << TypeIdToString(dst_type, true)
<< " is not in supported list: [float16, float32, float64].";
}
// Create zero value which will be broadcast to target shape.
auto format = GetFormat(atomic_add_info.atomic_add_node);
auto dtype = (dst_type == kNumberTypeFloat16) ? kNumberTypeFloat32 : dst_type;
ValueNodePtr value_node;
if (dtype == kNumberTypeFloat32) {
value_node = CreateScalarTensorValueNode<float>({.format = format, .shape = {1}, .type = TypeIdToType(dtype)},
static_cast<float>(0), sizeof(float));
} else {
value_node = CreateScalarTensorValueNode<double>({.format = format, .shape = {1}, .type = TypeIdToType(dtype)},
static_cast<double>(0), sizeof(double));
}
// Create composite op's sub-graph.
auto new_sub_graph = std::make_shared<FuncGraph>();
AnfNodePtr broadcast_input_node;
if (dst_type == kNumberTypeFloat16) {
AnfNodePtrList cast_inputs = {NewValueNode(prim::kPrimCast), value_node};
auto cast_node_inner =
CreateCNode(cast_inputs, new_sub_graph, {.format = format, .shape = {1}, .type = TypeIdToType(dst_type)});
SetNodeAttrSafely("dst_type", MakeValue("float32"), cast_node_inner);
broadcast_input_node = cast_node_inner;
} else {
broadcast_input_node = value_node;
}
// Create broadcast basic op.
auto dst_shape_vec = GetShape(atomic_add_info.atomic_add_node);
AnfNodePtrList atomic_clean_inputs = {NewValueNode(prim::kPrimBroadcastTo), broadcast_input_node};
auto broadcast_to_node_inner =
CreateCNode(atomic_clean_inputs, new_sub_graph,
{.format = format, .shape = dst_shape_vec, .type = GetType(atomic_add_info.atomic_add_node)});
SetNodeAttrSafely("shape", MakeValue(GetDeviceShape(atomic_add_info.atomic_add_node)), broadcast_to_node_inner);
// Assemble the sub-graph.
new_sub_graph->set_output(broadcast_to_node_inner);
auto broadcast_to_composite_node = main_graph->NewCNode({NewValueNode(new_sub_graph)});
broadcast_to_composite_node->set_abstract(broadcast_to_node_inner->abstract());
SetNewKernelInfo(broadcast_to_composite_node, new_sub_graph, {}, {broadcast_to_node_inner});
auto graph_attr = GkUtils::ExtractGraphKernelName(TopoSort(new_sub_graph->get_return()), "", "atomic_clean");
new_sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(graph_attr));
new_sub_graph->set_attr("composite_type", MakeValue("atomic_clean"));
return broadcast_to_composite_node;
}
std::vector<AtomicAddUserInfo> AtomicCleanInsertter::FindOriginCNodeUsers(
const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node,
const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &info_and_broadcast_to_nodes,
std::vector<AtomicAddUserInfo> AtomicCleanInserter::FindOriginCNodeUsers(
const FuncGraphPtr &main_graph, const AnfNodePtr &composite_node,
const std::vector<std::pair<CleanZeroUserInfo, AnfNodePtr>> &info_and_broadcast_to_nodes,
const FuncGraphManagerPtr &mng) const {
std::vector<AtomicAddUserInfo> reduce_user_nodes;
std::map<size_t, AnfNodePtr> real_indices_and_clean_node;
for (auto &[info, clean] : info_and_broadcast_to_nodes) {
(void)real_indices_and_clean_node.emplace(info.reduce_real_output_index, clean);
(void)real_indices_and_clean_node.emplace(info.real_output_index, clean);
}
if (info_and_broadcast_to_nodes[0].first.real_output_num <= 1) {
@@ -462,9 +294,9 @@ std::vector<AtomicAddUserInfo> AtomicCleanInsertter::FindOriginCNodeUsers(
return reduce_user_nodes;
}
void AtomicCleanInsertter::ProcessOriginCNodeUser(
const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node,
const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &info_and_broadcast_to_nodes,
void AtomicCleanInserter::ProcessOriginCNodeUser(
const FuncGraphPtr &main_graph, const AnfNodePtr &composite_node,
const std::vector<std::pair<CleanZeroUserInfo, AnfNodePtr>> &info_and_broadcast_to_nodes,
const FuncGraphManagerPtr &mng) {
// 1. Find users.
auto reduce_user_nodes = FindOriginCNodeUsers(main_graph, composite_node, info_and_broadcast_to_nodes, mng);
@@ -481,24 +313,22 @@ void AtomicCleanInsertter::ProcessOriginCNodeUser(
}
}
void AtomicCleanInsertter::InsertAtomicClean(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node,
const std::vector<AtomicAddInfo> &atomic_add_infos,
const FuncGraphManagerPtr &mng) {
void AtomicCleanInserter::InsertAtomicClean(const FuncGraphPtr &main_graph, const AnfNodePtr &anf_node,
const std::vector<CleanZeroUserInfo> &atomic_add_infos,
const FuncGraphManagerPtr &mng) {
auto origin_composite_node = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(origin_composite_node);
// Create broadcast node.
std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> info_and_broadcast_to_nodes;
std::vector<std::pair<CleanZeroUserInfo, AnfNodePtr>> info_and_broadcast_to_nodes;
for (auto atomic_add_info : atomic_add_infos) {
auto out_type = GetType(atomic_add_info.atomic_add_node)->cast<TensorTypePtr>();
auto out_type = GetType(atomic_add_info.op_node)->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(out_type);
auto broadcast_to_node =
CreateAtomicCleanCompositeNode(atomic_add_info, main_graph, out_type->element()->type_id());
auto broadcast_to_node = CreateCleanCompositeNode(atomic_add_info, main_graph, out_type->element()->type_id());
(void)info_and_broadcast_to_nodes.emplace_back(atomic_add_info, broadcast_to_node);
}
// Insert extra input(broadcast node output) to composite node, and make ReduceSum inplace-assign to it.
// Note: InplaceAssign outputs will increase total memory because of fake out.
ProcessOriginCNode(origin_composite_node, info_and_broadcast_to_nodes);
// Insert UpdateState + Load before origin ReduceSum's user to keep execution order.
@@ -512,7 +342,7 @@ void AtomicCleanInsertter::InsertAtomicClean(const KernelGraphPtr &main_graph, c
MS_LOG(INFO) << ss.str();
}
bool AtomicCleanInsertter::Run(const FuncGraphPtr &func_graph) {
bool AtomicCleanInserter::Run(const FuncGraphPtr &func_graph) {
auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(func_graph);
MS_EXCEPTION_IF_NULL(kernel_graph);
auto mng = kernel_graph->manager();

View File

@@ -24,14 +24,9 @@
#include <string>
#include "backend/common/optimizer/optimizer.h"
#include "backend/common/session/kernel_graph.h"
#include "common/graph_kernel/clean_inserter.h"
namespace mindspore::graphkernel {
struct AtomicAddInfo {
CNodePtr atomic_add_node{nullptr};
size_t reduce_real_output_index{0};
size_t real_output_num{0};
};
struct AtomicAddUserInfo {
AnfNodePtr clean_node{nullptr};
AnfNodePtr update_state_node{nullptr};
@@ -46,13 +41,13 @@ class AtomicAddChecker {
static std::shared_ptr<AtomicAddChecker> Init();
bool Check(const AnfNodePtr &node);
std::vector<AtomicAddInfo> GetAtomicAddInfo() { return atomic_add_infos_; }
std::vector<CleanZeroUserInfo> GetAtomicAddInfo() { return atomic_add_infos_; }
protected:
virtual bool SuitableForAtomicAdd(const AnfNodePtr &) { return false; }
virtual bool FindCandidate(const AnfNodePtr &anf_node);
virtual bool CanActivateAtomicAdd(const AnfNodePtr &anf_node);
std::vector<AtomicAddInfo> atomic_add_infos_;
std::vector<CleanZeroUserInfo> atomic_add_infos_;
PrimitivePtr target_type_{prim::kPrimReduceSum};
};
@@ -74,34 +69,26 @@ class AtomicAddCheckerAscend : public AtomicAddChecker {
bool SuitableForAtomicAdd(const AnfNodePtr &node) override;
};
class AtomicCleanInsertter : public opt::Pass {
class AtomicCleanInserter : public CleanInserter {
public:
explicit AtomicCleanInsertter(const std::string &name = "atomic_clean") : Pass(name) {}
~AtomicCleanInsertter() override = default;
explicit AtomicCleanInserter(const std::string &name = "atomic_clean") : CleanInserter(name) {}
~AtomicCleanInserter() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
protected:
virtual void CorrectKernelBuildInfo(const AnfNodePtr &composite_node,
const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &clean_infos);
virtual void ProcessOriginCNode(const AnfNodePtr &composite_node,
const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &info_and_broadcast_to_nodes);
virtual CNodePtr CreateAtomicCleanCompositeNode(const AtomicAddInfo &atomic_add_info,
const KernelGraphPtr &main_graph, TypeId dst_type);
void InsertAtomicClean(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node,
const std::vector<AtomicAddInfo> &atomic_add_infos, const FuncGraphManagerPtr &mng);
CNodePtr InsertUpdateState(const KernelGraphPtr &main_graph, const AnfNodePtr &node) const;
void CreateInplaceAssignNodeAndCorrectReturn(
const FuncGraphPtr &sub_graph, const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &parameters_infos);
void ProcessOriginCNodeUser(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node,
const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &info_and_broadcast_to_nodes,
void InsertAtomicClean(const FuncGraphPtr &main_graph, const AnfNodePtr &anf_node,
const std::vector<CleanZeroUserInfo> &atomic_add_infos, const FuncGraphManagerPtr &mng);
void ProcessOriginCNodeUser(const FuncGraphPtr &main_graph, const AnfNodePtr &composite_node,
const std::vector<std::pair<CleanZeroUserInfo, AnfNodePtr>> &info_and_broadcast_to_nodes,
const FuncGraphManagerPtr &mng);
private:
std::vector<AtomicAddUserInfo> FindOriginCNodeUsers(
const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node,
const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &info_and_broadcast_to_nodes,
const FuncGraphPtr &main_graph, const AnfNodePtr &composite_node,
const std::vector<std::pair<CleanZeroUserInfo, AnfNodePtr>> &info_and_broadcast_to_nodes,
const FuncGraphManagerPtr &mng) const;
};
using AtomicCleanInsertterPtr = std::shared_ptr<AtomicCleanInsertter>;
using AtomicCleanInserterPtr = std::shared_ptr<AtomicCleanInserter>;
} // namespace mindspore::graphkernel
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADD_ATOMIC_CLEAN_H_
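For orientation, the class hierarchy these headers establish after the refactor (CleanInserter and the stitch/TSA/USS headers appear later in this diff):

opt::Pass
  └── CleanInserter                 (clean_inserter.h, new in this commit)
        └── AtomicCleanInserter     (add_atomic_clean.h)
              ├── StitchAtomicCleanInserter
              ├── TsaAtomicAddToFirstTensor
              └── UssAtomicAdd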

View File

@@ -28,8 +28,8 @@
#include "backend/common/session/kernel_graph.h"
namespace mindspore::graphkernel {
void StitchAtomicCleanInsertter::CorrectKernelBuildInfo(
const AnfNodePtr &composite_node, const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &clean_infos) {
void StitchAtomicCleanInserter::CorrectKernelBuildInfo(
const AnfNodePtr &composite_node, const std::vector<std::pair<CleanZeroUserInfo, AnfNodePtr>> &clean_infos) {
// Change kernel build info.
auto kernel_info = dynamic_cast<device::KernelInfo *>(composite_node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
@@ -56,9 +56,9 @@ void StitchAtomicCleanInsertter::CorrectKernelBuildInfo(
AnfAlgo::SetSelectKernelBuildInfo(new_selected_info, composite_node.get());
}
void StitchAtomicCleanInsertter::AddDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &clean_node,
const AnfNodePtr &composite_node, const AnfNodePtr &user_node,
int index) const {
void StitchAtomicCleanInserter::AddDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &clean_node,
const AnfNodePtr &composite_node, const AnfNodePtr &user_node,
int index) const {
// Create depend node to hold execution order.
AnfNodePtrList d_inputs = {NewValueNode(prim::kPrimDepend), clean_node, composite_node};
auto depend_cnode = main_graph->NewCNode(d_inputs);
@@ -70,23 +70,22 @@ void StitchAtomicCleanInsertter::AddDepend(const FuncGraphPtr &main_graph, const
user_cnode->set_input(IntToSize(index), depend_cnode);
}
CNodePtr StitchAtomicCleanInsertter::CreateInplaceAssignNode(const FuncGraphPtr &sub_graph,
const AnfNodePtr &new_parameter,
const AtomicAddInfo &info) const {
// add inplaceassign
AnfNodePtr out_node = info.atomic_add_node; // Use result data itself, and set attr "fake_out" true.
auto inplace_assign_node =
CreateCNode({NewValueNode(prim::kPrimInplaceAssign), new_parameter, out_node, out_node}, sub_graph,
CNodePtr StitchAtomicCleanInserter::CreateAssignNode(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter,
const CleanZeroUserInfo &info) const {
// add assign
AnfNodePtr out_node = info.op_node; // Use result data itself
auto assign_node =
CreateCNode({NewValueNode(prim::kPrimAssign), new_parameter, out_node}, sub_graph,
{.format = GetFormat(out_node), .shape = GetShape(out_node), .type = GetType(out_node)});
SetNodeAttrSafely("fake_output", MakeValue(true), inplace_assign_node);
common::AnfAlgo::EraseNodeAttr(kAttrStitch, out_node);
SetNodeAttrSafely(kAttrStitch, MakeValue("common"), inplace_assign_node);
return inplace_assign_node;
SetNodeAttrSafely(kAttrStitch, MakeValue("common"), assign_node);
return assign_node;
}
void StitchAtomicCleanInsertter::ProcessOriginCNode(
void StitchAtomicCleanInserter::ProcessOriginCNode(
const AnfNodePtr &composite_node,
const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &info_and_broadcast_to_nodes) {
const std::vector<std::pair<CleanZeroUserInfo, AnfNodePtr>> &info_and_broadcast_to_nodes, bool atomic_add_attr) {
auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(composite_node);
auto mng_sub = sub_graph->manager();
if (mng_sub == nullptr) {
@@ -106,12 +105,12 @@ void StitchAtomicCleanInsertter::ProcessOriginCNode(
parameter->set_abstract(new_input->abstract());
parameter->set_kernel_info(new_input->kernel_info_ptr());
auto inplace_assign = CreateInplaceAssignNode(sub_graph, parameter, atomic_add_info);
auto assign = CreateAssignNode(sub_graph, parameter, atomic_add_info);
// Replace atomic ReduceSum's user with atomic clean output, and add depend op after inplaceassign to avoid
// Replace atomic ReduceSum's user with atomic clean output, and add depend op after assign to avoid
// elimination.
std::vector<std::pair<AnfNodePtr, int>> reduce_user_nodes =
FindInnerCNodeUsers(stitch_node_, atomic_add_info.atomic_add_node);
FindInnerCNodeUsers(stitch_node_, atomic_add_info.op_node);
bool connected = false;
for (const auto &[user_node, index] : reduce_user_nodes) {
auto user_cnode = user_node->cast<CNodePtr>();
@@ -121,7 +120,7 @@ void StitchAtomicCleanInsertter::ProcessOriginCNode(
std::vector<std::pair<AnfNodePtr, int>> user_user = FindInnerCNodeUsers(stitch_node_, user_cnode);
if (!user_user.empty()) {
auto pair = user_user[0];
AddDepend(sub_graph, user_cnode, inplace_assign, pair.first, pair.second);
AddDepend(sub_graph, user_cnode, assign, pair.first, pair.second);
}
connected = true;
}
@@ -134,8 +133,8 @@ void StitchAtomicCleanInsertter::ProcessOriginCNode(
MS_LOG(INFO) << "Convert " << old_graph_name << " to atomic add graph " << new_graph_name;
}
std::vector<std::pair<AnfNodePtr, int>> StitchAtomicCleanInsertter::FindInnerCNodeUsers(const AnfNodePtr &inner_node,
const CNodePtr &target) const {
std::vector<std::pair<AnfNodePtr, int>> StitchAtomicCleanInserter::FindInnerCNodeUsers(const AnfNodePtr &inner_node,
const CNodePtr &target) const {
auto node = inner_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(node);
auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
@@ -151,8 +150,8 @@ std::vector<std::pair<AnfNodePtr, int>> StitchAtomicCleanInsertter::FindInnerCNo
return inner_user_nodes;
}
std::pair<bool, AtomicAddInfo> StitchAtomicCleanInsertter::IsStitchWithAtomic(const AnfNodePtr &anf_node) {
if (!common::AnfAlgo::IsGraphKernel(anf_node)) return {false, AtomicAddInfo()};
std::pair<bool, CleanZeroUserInfo> StitchAtomicCleanInserter::IsStitchWithAtomic(const AnfNodePtr &anf_node) {
if (!common::AnfAlgo::IsGraphKernel(anf_node)) return {false, CleanZeroUserInfo()};
auto node = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(node);
auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
@@ -163,16 +162,16 @@ std::pair<bool, AtomicAddInfo> StitchAtomicCleanInsertter::IsStitchWithAtomic(co
common::AnfAlgo::GetNodeAttr<std::string>(n, kAttrStitch) == "atomic" &&
IsPrimitiveCNode(n, prim::kPrimReduceSum)) {
MS_LOG(INFO) << "GOT STITCH WITH ATOMIC!!!";
AtomicAddInfo info;
info.atomic_add_node = n->cast<CNodePtr>();
CleanZeroUserInfo info;
info.op_node = n->cast<CNodePtr>();
stitch_node_ = anf_node;
return {true, info};
}
}
return {false, AtomicAddInfo()};
return {false, CleanZeroUserInfo()};
}
bool StitchAtomicCleanInsertter::Run(const FuncGraphPtr &func_graph) {
bool StitchAtomicCleanInserter::Run(const FuncGraphPtr &func_graph) {
auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(func_graph);
MS_EXCEPTION_IF_NULL(kernel_graph);
auto mng = kernel_graph->manager();

View File

@@ -26,31 +26,31 @@
#include "backend/common/session/kernel_graph.h"
namespace mindspore::graphkernel {
class StitchAtomicCleanInsertter : public AtomicCleanInsertter {
class StitchAtomicCleanInserter : public AtomicCleanInserter {
public:
StitchAtomicCleanInsertter() : AtomicCleanInsertter("stitch_atomic_clean") {}
~StitchAtomicCleanInsertter() override = default;
StitchAtomicCleanInserter() : AtomicCleanInserter("stitch_atomic_clean") {}
~StitchAtomicCleanInserter() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
protected:
void CorrectKernelBuildInfo(const AnfNodePtr &composite_node,
const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &clean_infos) override;
void ProcessOriginCNode(
const AnfNodePtr &composite_node,
const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &info_and_broadcast_to_nodes) override;
const std::vector<std::pair<CleanZeroUserInfo, AnfNodePtr>> &clean_infos) override;
void ProcessOriginCNode(const AnfNodePtr &composite_node,
const std::vector<std::pair<CleanZeroUserInfo, AnfNodePtr>> &info_and_broadcast_to_nodes,
bool atomic_add_attr = true) override;
private:
CNodePtr CreateInplaceAssignNode(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter,
const AtomicAddInfo &info) const;
CNodePtr CreateAssignNode(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter,
const CleanZeroUserInfo &info) const;
std::vector<std::pair<AnfNodePtr, int>> FindInnerCNodeUsers(const AnfNodePtr &inner_node,
const CNodePtr &target) const;
std::pair<bool, AtomicAddInfo> IsStitchWithAtomic(const AnfNodePtr &anf_node);
std::pair<bool, CleanZeroUserInfo> IsStitchWithAtomic(const AnfNodePtr &anf_node);
void AddDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &clean_node, const AnfNodePtr &composite_node,
const AnfNodePtr &user_node, int index) const;
AnfNodePtr stitch_node_{nullptr};
};
using StitchAtomicCleanInsertterPtr = std::shared_ptr<StitchAtomicCleanInsertter>;
using StitchAtomicCleanInserterPtr = std::shared_ptr<StitchAtomicCleanInserter>;
} // namespace mindspore::graphkernel
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADD_STITCH_ATOMIC_CLEAN_GPU_H_
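In short, the stitch variant now writes the stitched ReduceSum result back through a plain Assign and pins it with a Depend edge. A minimal sketch of the nodes it emits, condensed from CreateAssignNode and AddDepend above (variable names are illustrative; graph plumbing omitted):

// Assign the computed result into the zero-initialized parameter.
auto assign_node =
    CreateCNode({NewValueNode(prim::kPrimAssign), new_parameter, out_node}, sub_graph,
                {.format = GetFormat(out_node), .shape = GetShape(out_node), .type = GetType(out_node)});
SetNodeAttrSafely(kAttrStitch, MakeValue("common"), assign_node);
// Route a Depend edge into the user so the Assign is neither eliminated
// nor reordered past the node that consumes the result.
AnfNodePtrList d_inputs = {NewValueNode(prim::kPrimDepend), reduce_user, assign_node};
auto depend_cnode = sub_graph->NewCNode(d_inputs);
user_of_user->set_input(IntToSize(index), depend_cnode);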

View File

@@ -0,0 +1,202 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/graph_kernel/clean_inserter.h"
#include <algorithm>
#include <memory>
#include <utility>
#include <set>
#include <string>
#include <map>
#include <vector>
#include "mindspore/core/ops/core_ops.h"
#include "ir/tensor.h"
#include "include/common/utils/utils.h"
#include "include/common/debug/anf_ir_dump.h"
#include "utils/log_adapter.h"
#include "kernel/kernel.h"
#include "kernel/common_utils.h"
#include "backend/common/session/kernel_graph.h"
#include "common/graph_kernel/graph_kernel_helper.h"
#include "common/graph_kernel/core/graph_kernel_utils.h"
namespace mindspore::graphkernel {
namespace {
CNodePtr CreateAssign(const FuncGraphPtr &sub_graph,
const std::vector<std::pair<CleanZeroUserInfo, AnfNodePtr>> &parameters_infos, size_t idx) {
if (idx >= parameters_infos.size()) {
MS_LOG(EXCEPTION) << "idx " << idx << " is out of range [0, " << parameters_infos.size() << ")";
}
MS_EXCEPTION_IF_NULL(sub_graph);
const auto &target_node = parameters_infos[idx].first.op_node;
const auto &new_parameter = parameters_infos[idx].second;
auto node =
CreateCNode({NewValueNode(prim::kPrimAssign), new_parameter, target_node}, sub_graph,
{.format = GetFormat(target_node), .shape = GetShape(target_node), .type = GetType(target_node)});
return node;
}
} // namespace
void CleanInserter::CorrectKernelBuildInfo(const AnfNodePtr &composite_node,
const std::vector<std::pair<CleanZeroUserInfo, AnfNodePtr>> &clean_infos) {
// Change kernel build info.
auto kernel_info = dynamic_cast<device::KernelInfo *>(composite_node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
const auto &origin_kernel_build_info = kernel_info->GetMutableSelectKernelBuildInfo();
MS_EXCEPTION_IF_NULL(origin_kernel_build_info);
auto origin_inputs_format = origin_kernel_build_info->GetAllInputFormats();
auto origin_inputs_type = origin_kernel_build_info->GetAllInputDeviceTypes();
std::vector<std::string> &new_inputs_format = origin_inputs_format;
std::vector<TypeId> &new_inputs_type = origin_inputs_type;
for (const auto &clean_info : clean_infos) {
auto &new_input = clean_info.second;
auto kernel_with_index = common::AnfAlgo::VisitKernel(new_input, 0);
new_inputs_format.push_back(AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second));
new_inputs_type.push_back(AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second));
}
auto new_selected_info = BuildSelectKernelBuildInfo(
new_inputs_format, new_inputs_type, origin_kernel_build_info->GetAllOutputFormats(),
origin_kernel_build_info->GetAllOutputDeviceTypes(), origin_kernel_build_info->processor());
AnfAlgo::SetSelectKernelBuildInfo(new_selected_info, composite_node.get());
}
void CleanInserter::CreateAssignNodeAndCorrectReturn(
const FuncGraphPtr &sub_graph, const std::vector<std::pair<CleanZeroUserInfo, AnfNodePtr>> &parameters_infos) {
std::map<size_t, size_t> target_indices;
for (size_t i = 0; i < parameters_infos.size(); ++i) {
target_indices[parameters_infos[i].first.real_output_index + 1] = i;
}
// Change output to Assign node.
auto output = sub_graph->output();
if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
auto output_cnode = output->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(output_cnode);
for (size_t i = 1; i < output_cnode->inputs().size(); ++i) {
auto iter = target_indices.find(i);
if (iter == target_indices.end()) continue;
auto inplace = CreateAssign(sub_graph, parameters_infos, iter->second);
output_cnode->set_input(i, inplace);
}
} else if (parameters_infos.size() == 1) {
auto inplace = CreateAssign(sub_graph, parameters_infos, 0);
sub_graph->set_output(inplace);
}
}
CNodePtr CleanInserter::InsertUpdateState(const FuncGraphPtr &main_graph, const AnfNodePtr &node) const {
// Insert update_state_node, which needs a monad node mounted.
auto u = NewValueNode(kUMonad);
u->set_abstract(kUMonad->ToAbstract());
AnfNodePtrList update_state_inputs = {NewValueNode(prim::kPrimUpdateState), u, node};
auto update_state_cnode = main_graph->NewCNode(update_state_inputs);
update_state_cnode->set_abstract(kUMonad->ToAbstract());
main_graph->AddNode(update_state_cnode);
return update_state_cnode;
}
CNodePtr CleanInserter::CreateCleanCompositeNode(const CleanZeroUserInfo &op_info, const FuncGraphPtr &main_graph,
TypeId dst_type) {
std::set<TypeId> data_support = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64};
if (!std::any_of(data_support.cbegin(), data_support.cend(), [&dst_type](TypeId type) { return dst_type == type; })) {
MS_LOG(EXCEPTION) << "For CreateCleanCompositeNode, the data type: " << TypeIdToString(dst_type, true)
<< " is not in supported list: [float16, float32, float64].";
}
// Create zero value which will be broadcast to target shape.
auto format = GetFormat(op_info.op_node);
auto dtype = (dst_type == kNumberTypeFloat16) ? kNumberTypeFloat32 : dst_type;
ValueNodePtr value_node;
if (dtype == kNumberTypeFloat32) {
value_node = CreateScalarTensorValueNode<float>({.format = format, .shape = {1}, .type = TypeIdToType(dtype)},
static_cast<float>(0), sizeof(float));
} else {
value_node = CreateScalarTensorValueNode<double>({.format = format, .shape = {1}, .type = TypeIdToType(dtype)},
static_cast<double>(0), sizeof(double));
}
// Create composite op's sub-graph.
auto new_sub_graph = std::make_shared<FuncGraph>();
AnfNodePtr broadcast_input_node;
if (dst_type == kNumberTypeFloat16) {
AnfNodePtrList cast_inputs = {NewValueNode(prim::kPrimCast), value_node};
auto cast_node_inner =
CreateCNode(cast_inputs, new_sub_graph, {.format = format, .shape = {1}, .type = TypeIdToType(dst_type)});
SetNodeAttrSafely("dst_type", MakeValue("float32"), cast_node_inner);
broadcast_input_node = cast_node_inner;
} else {
broadcast_input_node = value_node;
}
// Create broadcast basic op.
auto dst_shape_vec = GetShape(op_info.op_node);
AnfNodePtrList clean_inputs = {NewValueNode(prim::kPrimBroadcastTo), broadcast_input_node};
auto broadcast_to_node_inner = CreateCNode(
clean_inputs, new_sub_graph, {.format = format, .shape = dst_shape_vec, .type = GetType(op_info.op_node)});
SetNodeAttrSafely("shape", MakeValue(GetDeviceShape(op_info.op_node)), broadcast_to_node_inner);
// Assemble the sub-graph.
new_sub_graph->set_output(broadcast_to_node_inner);
auto broadcast_to_composite_node = main_graph->NewCNode({NewValueNode(new_sub_graph)});
broadcast_to_composite_node->set_abstract(broadcast_to_node_inner->abstract());
SetNewKernelInfo(broadcast_to_composite_node, new_sub_graph, {}, {broadcast_to_node_inner});
auto graph_attr = GkUtils::ExtractGraphKernelName(TopoSort(new_sub_graph->get_return()), "", "atomic_clean");
new_sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(graph_attr));
new_sub_graph->set_attr("composite_type", MakeValue("atomic_clean"));
return broadcast_to_composite_node;
}
void CleanInserter::ProcessOriginCNode(
const AnfNodePtr &composite_node,
const std::vector<std::pair<CleanZeroUserInfo, AnfNodePtr>> &info_and_broadcast_to_nodes, bool atomic_add_attr) {
auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(composite_node);
auto mng_sub = sub_graph->manager();
if (mng_sub == nullptr) {
mng_sub = Manage(sub_graph, false);
sub_graph->set_manager(mng_sub);
}
// Add input
std::vector<std::pair<CleanZeroUserInfo, AnfNodePtr>> parameters_infos;
for (const auto &[atomic_add_info, new_input] : info_and_broadcast_to_nodes) {
// Add atomic attribute to target node.
if (atomic_add_attr) SetNodeAttrSafely("enable_atomic_add", MakeValue(true), atomic_add_info.op_node);
// add parameter
auto parameter = sub_graph->add_parameter();
parameter->set_abstract(new_input->abstract());
parameter->set_kernel_info(new_input->kernel_info_ptr());
(void)parameters_infos.emplace_back(atomic_add_info, parameter);
}
auto inputs = composite_node->cast<CNodePtr>()->inputs();
(void)std::transform(info_and_broadcast_to_nodes.cbegin(), info_and_broadcast_to_nodes.cend(),
std::back_inserter(inputs),
[](const std::pair<CleanZeroUserInfo, AnfNodePtr> &pair_item) { return pair_item.second; });
composite_node->cast<CNodePtr>()->set_inputs(inputs);
CreateAssignNodeAndCorrectReturn(sub_graph, parameters_infos);
CorrectKernelBuildInfo(composite_node, info_and_broadcast_to_nodes);
}
} // namespace mindspore::graphkernel
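For reference, the sub-graph assembled by CreateCleanCompositeNode, sketched in the IR-comment style the headers in this repo use (float16 case shown; for float32/float64 targets the Cast is skipped):

// %zero = scalar tensor 0.0 (float32, shape {1})
// %0 = Cast(%zero)      // only inserted for float16 targets
// %1 = BroadcastTo(%0)  // attrs{"shape": device shape of op_node}
// return %1             // wrapped as a composite with composite_type "atomic_clean"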

View File

@@ -0,0 +1,51 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CLEAN_INSERTER_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CLEAN_INSERTER_H_
#include <utility>
#include <vector>
#include <string>
#include "backend/common/optimizer/optimizer.h"
namespace mindspore::graphkernel {
struct CleanZeroUserInfo {
CNodePtr op_node{nullptr};
size_t real_output_index{0};
size_t real_output_num{0};
};
class CleanInserter : public opt::Pass {
public:
explicit CleanInserter(const std::string &name = "clean_inserter") : Pass(name) {}
~CleanInserter() override = default;
protected:
virtual void CorrectKernelBuildInfo(const AnfNodePtr &composite_node,
const std::vector<std::pair<CleanZeroUserInfo, AnfNodePtr>> &clean_infos);
virtual CNodePtr CreateCleanCompositeNode(const CleanZeroUserInfo &op_info, const FuncGraphPtr &main_graph,
TypeId dst_type);
CNodePtr InsertUpdateState(const FuncGraphPtr &main_graph, const AnfNodePtr &node) const;
void CreateAssignNodeAndCorrectReturn(const FuncGraphPtr &sub_graph,
const std::vector<std::pair<CleanZeroUserInfo, AnfNodePtr>> &parameters_infos);
virtual void ProcessOriginCNode(
const AnfNodePtr &composite_node,
const std::vector<std::pair<CleanZeroUserInfo, AnfNodePtr>> &info_and_broadcast_to_nodes,
bool atomic_add_attr = true);
};
} // namespace mindspore::graphkernel
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CLEAN_INSERTER_H_
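A condensed view of how AtomicCleanInserter::InsertAtomicClean (in add_atomic_clean.cc above) drives these helpers (sketch only; logging and error handling omitted):

std::vector<std::pair<CleanZeroUserInfo, AnfNodePtr>> info_and_broadcast_to_nodes;
for (auto info : atomic_add_infos) {
  auto out_type = GetType(info.op_node)->cast<TensorTypePtr>();
  // Build the zero-filled buffer as a separate composite kernel.
  auto clean = CreateCleanCompositeNode(info, main_graph, out_type->element()->type_id());
  (void)info_and_broadcast_to_nodes.emplace_back(info, clean);
}
// Feed the buffers in as extra inputs, rewrite outputs to Assign nodes, and
// then insert UpdateState + Load before the users to keep execution order.
ProcessOriginCNode(origin_composite_node, info_and_broadcast_to_nodes);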

View File

@@ -44,7 +44,7 @@ class TsaChecker : public AtomicAddChecker {
}
for (auto atomic_add_info : atomic_add_infos_) {
auto tsa_cnode = atomic_add_info.atomic_add_node;
auto tsa_cnode = atomic_add_info.op_node;
if (!utils::isa<ParameterPtr>(tsa_cnode->input(1))) {
return false;
}
@@ -86,7 +86,7 @@ std::pair<AnfNodePtr, size_t> TsaAtomicAddToFirstTensor::FindTsaFirstRealInputIn
}
std::pair<AnfNodePtr, size_t> TsaAtomicAddToFirstTensor::GetOrCreateNewTsaFirstNode(
const KernelGraphPtr &main_graph, const AtomicAddInfo &atomic_add_info, const AnfNodePtr &node) {
const KernelGraphPtr &main_graph, const CleanZeroUserInfo &atomic_add_info, const AnfNodePtr &node) {
auto mng = main_graph->manager();
if (mng == nullptr) {
mng = Manage(main_graph, true);
@@ -94,7 +94,7 @@ std::pair<AnfNodePtr, size_t> TsaAtomicAddToFirstTensor::GetOrCreateNewTsaFirstN
}
// Find first input of tsa
auto tsa_first_input = FindTsaFirstRealInputInGraph(main_graph, atomic_add_info.atomic_add_node, node);
auto tsa_first_input = FindTsaFirstRealInputInGraph(main_graph, atomic_add_info.op_node, node);
auto users = mng->node_users()[tsa_first_input.first];
if (users.size() == 1 &&
!(utils::isa<ValueNodePtr>(tsa_first_input.first) || utils::isa<ParameterPtr>(tsa_first_input.first))) {
@@ -147,7 +147,7 @@ std::pair<AnfNodePtr, size_t> TsaAtomicAddToFirstTensor::GetOrCreateNewTsaFirstN
}
void TsaAtomicAddToFirstTensor::ChangeKernelBuildInfo(
const AnfNodePtr &composite_node, const std::vector<std::tuple<AtomicAddInfo, AnfNodePtr, size_t>> &outer_infos) {
const AnfNodePtr &composite_node, const std::vector<std::tuple<CleanZeroUserInfo, AnfNodePtr, size_t>> &outer_infos) {
// Change kernel build info with modify input
auto kernel_info = static_cast<device::KernelInfo *>(composite_node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
@@ -176,7 +176,7 @@ void TsaAtomicAddToFirstTensor::ChangeKernelBuildInfo(
}
void TsaAtomicAddToFirstTensor::ProcessOriginalCNode(
const AnfNodePtr &composite_node, const std::vector<std::tuple<AtomicAddInfo, AnfNodePtr, size_t>> &outer_nodes) {
const AnfNodePtr &composite_node, const std::vector<std::tuple<CleanZeroUserInfo, AnfNodePtr, size_t>> &outer_nodes) {
auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(composite_node);
auto mng_sub = sub_graph->manager();
if (mng_sub == nullptr) {
@@ -185,8 +185,8 @@ void TsaAtomicAddToFirstTensor::ProcessOriginalCNode(
}
// Modify input
std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> parameters_infos;
std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> info_and_tsa_outers;
std::vector<std::pair<CleanZeroUserInfo, AnfNodePtr>> parameters_infos;
std::vector<std::pair<CleanZeroUserInfo, AnfNodePtr>> info_and_tsa_outers;
for (const auto &[atomic_add_info, outer_node, tsa_first_input_index] : outer_nodes) {
composite_node->cast<CNodePtr>()->set_input(tsa_first_input_index + 1, outer_node);
auto parameter = sub_graph->parameters()[tsa_first_input_index];
@@ -194,7 +194,7 @@ void TsaAtomicAddToFirstTensor::ProcessOriginalCNode(
(void)info_and_tsa_outers.emplace_back(atomic_add_info, outer_node);
}
CreateInplaceAssignNodeAndCorrectReturn(sub_graph, parameters_infos);
CreateAssignNodeAndCorrectReturn(sub_graph, parameters_infos);
ChangeKernelBuildInfo(composite_node, outer_nodes);
auto old_graph_name = GetValue<std::string>(sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
@@ -205,14 +205,14 @@ void TsaAtomicAddToFirstTensor::ProcessOriginalCNode(
}
void TsaAtomicAddToFirstTensor::ProcessTsa(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node,
const std::vector<AtomicAddInfo> &atomic_add_infos,
const std::vector<CleanZeroUserInfo> &atomic_add_infos,
const FuncGraphManagerPtr &mng) {
auto origin_composite_node = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(origin_composite_node);
// Create identity node.
std::vector<std::tuple<AtomicAddInfo, AnfNodePtr, size_t>> info_and_outer_nodes_with_index;
std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> info_and_outer_nodes;
std::vector<std::tuple<CleanZeroUserInfo, AnfNodePtr, size_t>> info_and_outer_nodes_with_index;
std::vector<std::pair<CleanZeroUserInfo, AnfNodePtr>> info_and_outer_nodes;
for (auto atomic_add_info : atomic_add_infos) {
auto outer = GetOrCreateNewTsaFirstNode(main_graph, atomic_add_info, anf_node);
(void)info_and_outer_nodes_with_index.emplace_back(atomic_add_info, outer.first, outer.second);
@@ -220,7 +220,6 @@ void TsaAtomicAddToFirstTensor::ProcessTsa(const KernelGraphPtr &main_graph, con
}
// Insert extra input(broadcast node output) to composite node, and make origin TensorScatterAdd inplace-assign to it.
// Note: InplaceAssign outputs will increase total memory because of fake out.
ProcessOriginalCNode(origin_composite_node, info_and_outer_nodes_with_index);
// Insert UpdateState + Load before origin TensorScatterAdd's user to keep execution order.

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -36,26 +36,26 @@ namespace mindspore::graphkernel {
* output = Reshape(input_x)
* fake_out = SubGraph'(output, indices, update) {
* %0 = TensorScatterAdd(%para1, %para2, %para3)
* %1 = InplaceAssign(%para1, %0, %0) // attrs{"fake_output":true}
* %1 = Assign(%para1, %0, umonad)
* return %1
* }
*/
class TsaAtomicAddToFirstTensor : public AtomicCleanInsertter {
class TsaAtomicAddToFirstTensor : public AtomicCleanInserter {
public:
TsaAtomicAddToFirstTensor() : AtomicCleanInsertter("tensor_scatter_add_atomic_add_to_first_tensor") {}
TsaAtomicAddToFirstTensor() : AtomicCleanInserter("tensor_scatter_add_atomic_add_to_first_tensor") {}
~TsaAtomicAddToFirstTensor() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
private:
void ProcessOriginalCNode(const AnfNodePtr &composite_node,
const std::vector<std::tuple<AtomicAddInfo, AnfNodePtr, size_t>> &outer_nodes);
const std::vector<std::tuple<CleanZeroUserInfo, AnfNodePtr, size_t>> &outer_nodes);
void ChangeKernelBuildInfo(const AnfNodePtr &composite_node,
const std::vector<std::tuple<AtomicAddInfo, AnfNodePtr, size_t>> &outer_infos);
const std::vector<std::tuple<CleanZeroUserInfo, AnfNodePtr, size_t>> &outer_infos);
void ProcessTsa(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node,
const std::vector<AtomicAddInfo> &atomic_add_infos, const FuncGraphManagerPtr &mng);
const std::vector<CleanZeroUserInfo> &atomic_add_infos, const FuncGraphManagerPtr &mng);
std::pair<AnfNodePtr, size_t> GetOrCreateNewTsaFirstNode(const KernelGraphPtr &main_graph,
const AtomicAddInfo &atomic_add_info,
const CleanZeroUserInfo &atomic_add_info,
const AnfNodePtr &node);
std::pair<AnfNodePtr, size_t> FindTsaFirstRealInputInGraph(const KernelGraphPtr &, const CNodePtr &tsa_node,
const AnfNodePtr &node);

View File

@@ -36,13 +36,13 @@ namespace mindspore::graphkernel {
* output = broadcast_to(0.0) // attrs{"shape": [shape of origin output.]}
* fake_out = SubGraph'(input_x, segment_ids, output) {
* %0 = UnsortedSegmentSum(%para1, %para2)
* %1 = InplaceAssign(%para3, %0, %0) // attrs{"fake_output":true}
* %1 = Assign(%para3, %0, umonad)
* return %1
* }
*/
class UssAtomicAdd : public AtomicCleanInsertter {
class UssAtomicAdd : public AtomicCleanInserter {
public:
UssAtomicAdd() : AtomicCleanInsertter("unsorted_segment_sum_atomic_add_process") {}
UssAtomicAdd() : AtomicCleanInserter("unsorted_segment_sum_atomic_add_process") {}
~UssAtomicAdd() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
};