From 89e6d405245da7db16fc7db4959eb4d97a6b0e66 Mon Sep 17 00:00:00 2001 From: Yang Jiao Date: Sat, 16 Apr 2022 20:12:30 +0800 Subject: [PATCH] refactor atomic add --- .../adapter/graph_kernel_optimization.cc | 17 +- .../common/graph_kernel/add_atomic_clean.cc | 216 ++---------------- .../common/graph_kernel/add_atomic_clean.h | 41 ++-- .../add_stitch_atomic_clean_gpu.cc | 57 +++-- .../add_stitch_atomic_clean_gpu.h | 22 +- .../common/graph_kernel/clean_inserter.cc | 202 ++++++++++++++++ .../common/graph_kernel/clean_inserter.h | 51 +++++ .../tsa_atomic_add_to_first_tensor.cc | 23 +- .../tsa_atomic_add_to_first_tensor.h | 16 +- .../common/graph_kernel/uss_atomic_add.h | 6 +- 10 files changed, 359 insertions(+), 292 deletions(-) create mode 100644 mindspore/ccsrc/common/graph_kernel/clean_inserter.cc create mode 100644 mindspore/ccsrc/common/graph_kernel/clean_inserter.h diff --git a/mindspore/ccsrc/common/graph_kernel/adapter/graph_kernel_optimization.cc b/mindspore/ccsrc/common/graph_kernel/adapter/graph_kernel_optimization.cc index 50e7200fb5a..f517170dca1 100644 --- a/mindspore/ccsrc/common/graph_kernel/adapter/graph_kernel_optimization.cc +++ b/mindspore/ccsrc/common/graph_kernel/adapter/graph_kernel_optimization.cc @@ -166,19 +166,12 @@ PassManagerPtr GraphKernelOptimizer::HighLevelOpt2() const { auto recompute_lv = GetPassLevelByFlag(flags.recompute_increment_threshold > 0 || flags.recompute_peak_threshold > 0); pm->Add(std::make_shared(), recompute_lv); - // Replace Assign with InplaceAssign, and replace original output with overridden parameters - pm->Add(std::make_shared(), OptLevel_2); - - pm->Add(std::make_shared(), std::min(recompute_lv, OptLevel_2)); - pm->Add(std::make_shared(), std::min(recompute_lv, OptLevel_2)); - pm->Add(std::make_shared(), std::min(recompute_lv, OptLevel_2)); - // Enable atomic add - pm->Add(std::make_shared(), OptLevel_2, is_gpu || is_ascend); + pm->Add(std::make_shared(), OptLevel_2, is_gpu || is_ascend); // Enable atomic add for stitch nodes. 
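// Why the block above moved: a pass manager executes passes strictly in the
// order they are Add()ed, so the Assign/InplaceAssign replacement passes now
// run after the atomic-add passes instead of before them. A minimal
// self-contained sketch of that ordering contract (toy type and names, not
// the real PassManager API):
#include <functional>
#include <utility>
#include <vector>
struct ToyPassManager {
  std::vector<std::function<bool()>> passes_;
  void Add(std::function<bool()> pass) { passes_.push_back(std::move(pass)); }
  bool Run() {
    bool changed = false;
    for (auto &pass : passes_) changed = pass() || changed;  // strict Add() order
    return changed;
  }
};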
auto level = GetPassLevelByFlag(GraphKernelFlags::GetInstance().enable_stitch_fusion); - pm->Add(std::make_shared(), level, is_gpu); + pm->Add(std::make_shared(), level, is_gpu); // Enable low precision auto level_low_precision = GetPassLevelByFlag(GraphKernelFlags::GetInstance().enable_low_precision); @@ -189,6 +182,12 @@ PassManagerPtr GraphKernelOptimizer::HighLevelOpt2() const { pm->Add(std::make_shared(), OptLevel_1, is_gpu); pm->Add(std::make_shared(), OptLevel_1, is_gpu); + // Replace Assign with InplaceAssign, and replace original output with overridden parameters + pm->Add(std::make_shared(), OptLevel_2); + pm->Add(std::make_shared(), std::min(recompute_lv, OptLevel_2)); + pm->Add(std::make_shared(), std::min(recompute_lv, OptLevel_2)); + pm->Add(std::make_shared(), std::min(recompute_lv, OptLevel_2)); + return pm; } diff --git a/mindspore/ccsrc/common/graph_kernel/add_atomic_clean.cc b/mindspore/ccsrc/common/graph_kernel/add_atomic_clean.cc index 01e883962ed..8a9c31a7bd5 100644 --- a/mindspore/ccsrc/common/graph_kernel/add_atomic_clean.cc +++ b/mindspore/ccsrc/common/graph_kernel/add_atomic_clean.cc @@ -15,6 +15,7 @@ */ #include "common/graph_kernel/add_atomic_clean.h" + #include #include #include @@ -37,8 +38,6 @@ namespace mindspore::graphkernel { namespace { -constexpr auto kAttrFakeOutput = "fake_output"; - std::set GetUniqReduceAxes(const AnfNodePtr &node, bool is_ascend = false) { if (!IsPrimitiveCNode(node, prim::kPrimReduceSum)) { MS_LOG(EXCEPTION) << "Expect ReduceSum node, but got " << common::AnfAlgo::GetCNodeName(node); @@ -103,21 +102,6 @@ size_t GetItemIdx(const AnfNodePtr &node) { auto item_idx = LongToSize(GetValue(value_node->value())); return item_idx; } - -CNodePtr CreateInplaceAssign(const FuncGraphPtr &sub_graph, - const std::vector> ¶meters_infos, size_t idx) { - if (idx >= parameters_infos.size()) { - MS_LOG(EXCEPTION) << "idx " << idx << " is out of range [0, " << parameters_infos.size() << ")"; - } - MS_EXCEPTION_IF_NULL(sub_graph); - const auto &atomic_add_node = parameters_infos[idx].first.atomic_add_node; - const auto &new_parameter = parameters_infos[idx].second; - auto node = CreateCNode( - {NewValueNode(prim::kPrimInplaceAssign), new_parameter, atomic_add_node, atomic_add_node}, sub_graph, - {.format = GetFormat(atomic_add_node), .shape = GetShape(atomic_add_node), .type = GetType(atomic_add_node)}); - SetNodeAttrSafely(kAttrFakeOutput, MakeValue(true), node); - return node; -} } // namespace std::shared_ptr AtomicAddChecker::Init() { @@ -141,19 +125,19 @@ bool AtomicAddChecker::FindCandidate(const AnfNodePtr &anf_node) { sub_graph->set_manager(mng_sub); } - auto CheckSuitableTarget = [&mng_sub](const AtomicAddInfo &atomic_add_info) { + auto CheckSuitableTarget = [&mng_sub](const CleanZeroUserInfo &atomic_add_info) { // Target type should not fuse any other ops in out direction, which means it should be in output list. 
- return mng_sub->node_users()[atomic_add_info.atomic_add_node].size() <= 1; + return mng_sub->node_users()[atomic_add_info.op_node].size() <= 1; }; auto real_return_node = sub_graph->get_return()->input(kFirstDataInputIndex); - AtomicAddInfo atomic_add_info; + CleanZeroUserInfo atomic_add_info; if (IsPrimitiveCNode(real_return_node, prim::kPrimMakeTuple)) { const auto &inputs = real_return_node->cast()->inputs(); for (size_t i = 1; i < inputs.size(); ++i) { if (IsPrimitiveCNode(inputs[i], target_type_)) { - atomic_add_info.atomic_add_node = inputs[i]->cast(); - atomic_add_info.reduce_real_output_index = i - 1; + atomic_add_info.op_node = inputs[i]->cast(); + atomic_add_info.real_output_index = i - 1; atomic_add_info.real_output_num = inputs.size() - 1; // Target type should not fuse any other ops in out direction, which means it should be in output list. if (!CheckSuitableTarget(atomic_add_info)) { @@ -163,7 +147,7 @@ bool AtomicAddChecker::FindCandidate(const AnfNodePtr &anf_node) { } } } else if (IsPrimitiveCNode(real_return_node, target_type_)) { - atomic_add_info.atomic_add_node = real_return_node->cast(); + atomic_add_info.op_node = real_return_node->cast(); atomic_add_info.real_output_num = 1; if (CheckSuitableTarget(atomic_add_info)) { atomic_add_infos_.push_back(atomic_add_info); @@ -191,12 +175,12 @@ bool AtomicAddChecker::CanActivateAtomicAdd(const AnfNodePtr &anf_node) { } // Rule 2. - if (!SuitableForAtomicAdd(atomic_add_infos_[0].atomic_add_node)) { + if (!SuitableForAtomicAdd(atomic_add_infos_[0].op_node)) { return false; } // Rule 3. - return !HaveReduceInPredecessors(atomic_add_infos_[0].atomic_add_node); + return !HaveReduceInPredecessors(atomic_add_infos_[0].op_node); } bool AtomicAddChecker::Check(const AnfNodePtr &node) { @@ -264,167 +248,15 @@ bool AtomicAddCheckerAscend::SuitableForAtomicAdd(const AnfNodePtr &node) { return false; } -void AtomicCleanInsertter::CorrectKernelBuildInfo( - const AnfNodePtr &composite_node, const std::vector> &clean_infos) { - // Change kernel build info. 
- auto kernel_info = dynamic_cast(composite_node->kernel_info()); - MS_EXCEPTION_IF_NULL(kernel_info); - const auto &origin_kernel_build_info = kernel_info->GetMutableSelectKernelBuildInfo(); - MS_EXCEPTION_IF_NULL(origin_kernel_build_info); - auto origin_inputs_format = origin_kernel_build_info->GetAllInputFormats(); - auto origin_inputs_type = origin_kernel_build_info->GetAllInputDeviceTypes(); - - std::vector &new_inputs_format = origin_inputs_format; - std::vector &new_inputs_type = origin_inputs_type; - for (const auto &clean_info : clean_infos) { - auto &new_input = clean_info.second; - auto kernel_with_index = common::AnfAlgo::VisitKernel(new_input, 0); - new_inputs_format.push_back(AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second)); - new_inputs_type.push_back(AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second)); - } - - auto new_selected_info = BuildSelectKernelBuildInfo( - new_inputs_format, new_inputs_type, origin_kernel_build_info->GetAllOutputFormats(), - origin_kernel_build_info->GetAllOutputDeviceTypes(), origin_kernel_build_info->processor()); - AnfAlgo::SetSelectKernelBuildInfo(new_selected_info, composite_node.get()); -} - -void AtomicCleanInsertter::CreateInplaceAssignNodeAndCorrectReturn( - const FuncGraphPtr &sub_graph, const std::vector> ¶meters_infos) { - std::map reduce_indices; - for (size_t i = 0; i < parameters_infos.size(); ++i) { - reduce_indices[parameters_infos[i].first.reduce_real_output_index + 1] = i; - } - - // Change atomic add output to InplaceAssign node. - auto output = sub_graph->output(); - if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) { - auto output_cnode = output->cast(); - MS_EXCEPTION_IF_NULL(output_cnode); - for (size_t i = 1; i < output_cnode->inputs().size(); ++i) { - auto iter = reduce_indices.find(i); - if (iter == reduce_indices.end()) continue; - auto inplace = CreateInplaceAssign(sub_graph, parameters_infos, iter->second); - output_cnode->set_input(i, inplace); - } - } else if (parameters_infos.size() == 1) { - auto inplace = CreateInplaceAssign(sub_graph, parameters_infos, 0); - sub_graph->set_output(inplace); - } -} - -void AtomicCleanInsertter::ProcessOriginCNode( - const AnfNodePtr &composite_node, - const std::vector> &info_and_broadcast_to_nodes) { - auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(composite_node); - auto mng_sub = sub_graph->manager(); - if (mng_sub == nullptr) { - mng_sub = Manage(sub_graph, false); - sub_graph->set_manager(mng_sub); - } - - // Add input - std::vector> parameters_infos; - for (const auto &[atomic_add_info, new_input] : info_and_broadcast_to_nodes) { - // Add atomic attribute to ReduceSum node. 
- SetNodeAttrSafely("enable_atomic_add", MakeValue(true), atomic_add_info.atomic_add_node); - // add parameter - auto parameter = sub_graph->add_parameter(); - parameter->set_abstract(new_input->abstract()); - parameter->set_kernel_info(new_input->kernel_info_ptr()); - (void)parameters_infos.emplace_back(atomic_add_info, parameter); - } - - auto inputs = composite_node->cast()->inputs(); - (void)std::transform(info_and_broadcast_to_nodes.cbegin(), info_and_broadcast_to_nodes.cend(), - std::back_inserter(inputs), - [](const std::pair &pair_item) { return pair_item.second; }); - composite_node->cast()->set_inputs(inputs); - - CreateInplaceAssignNodeAndCorrectReturn(sub_graph, parameters_infos); - CorrectKernelBuildInfo(composite_node, info_and_broadcast_to_nodes); - - auto old_graph_name = GetValue(sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); - auto new_graph_name = GkUtils::ExtractGraphKernelName(TopoSort(sub_graph->get_return()), "", "atomic_add"); - sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(new_graph_name)); - MS_LOG(INFO) << "Convert " << old_graph_name << " to atomic add graph " << new_graph_name; -} - -CNodePtr AtomicCleanInsertter::InsertUpdateState(const KernelGraphPtr &main_graph, const AnfNodePtr &node) const { - // Insert update_state_node, need mount a monad node. - auto u = NewValueNode(kUMonad); - u->set_abstract(kUMonad->ToAbstract()); - AnfNodePtrList update_state_inputs = {NewValueNode(prim::kPrimUpdateState), u, node}; - auto update_state_cnode = main_graph->NewCNode(update_state_inputs); - update_state_cnode->set_abstract(kUMonad->ToAbstract()); - main_graph->AddNode(update_state_cnode); - return update_state_cnode; -} - -CNodePtr AtomicCleanInsertter::CreateAtomicCleanCompositeNode(const AtomicAddInfo &atomic_add_info, - const KernelGraphPtr &main_graph, TypeId dst_type) { - std::set data_support = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64}; - - if (!std::any_of(data_support.cbegin(), data_support.cend(), [&dst_type](TypeId type) { return dst_type == type; })) { - MS_LOG(EXCEPTION) << "For AtomicAdd, the data type: " << TypeIdToString(dst_type, true) - << " is not in supported list: [float16, float32, float64]."; - } - - // Create zero value which will be broadcast to target shape. - auto format = GetFormat(atomic_add_info.atomic_add_node); - auto dtype = (dst_type == kNumberTypeFloat16) ? kNumberTypeFloat32 : dst_type; - ValueNodePtr value_node; - if (dtype == kNumberTypeFloat32) { - value_node = CreateScalarTensorValueNode({.format = format, .shape = {1}, .type = TypeIdToType(dtype)}, - static_cast(0), sizeof(float)); - } else { - value_node = CreateScalarTensorValueNode({.format = format, .shape = {1}, .type = TypeIdToType(dtype)}, - static_cast(0), sizeof(double)); - } - - // Create composite op's sub-graph. - auto new_sub_graph = std::make_shared(); - - AnfNodePtr broadcast_input_node; - if (dst_type == kNumberTypeFloat16) { - AnfNodePtrList cast_inputs = {NewValueNode(prim::kPrimCast), value_node}; - auto cast_node_inner = - CreateCNode(cast_inputs, new_sub_graph, {.format = format, .shape = {1}, .type = TypeIdToType(dst_type)}); - SetNodeAttrSafely("dst_type", MakeValue("float32"), cast_node_inner); - broadcast_input_node = cast_node_inner; - } else { - broadcast_input_node = value_node; - } - - // Create broadcast basic op. 
- auto dst_shape_vec = GetShape(atomic_add_info.atomic_add_node); - AnfNodePtrList atomic_clean_inputs = {NewValueNode(prim::kPrimBroadcastTo), broadcast_input_node}; - auto broadcast_to_node_inner = - CreateCNode(atomic_clean_inputs, new_sub_graph, - {.format = format, .shape = dst_shape_vec, .type = GetType(atomic_add_info.atomic_add_node)}); - SetNodeAttrSafely("shape", MakeValue(GetDeviceShape(atomic_add_info.atomic_add_node)), broadcast_to_node_inner); - - // Makeup sub-graph. - new_sub_graph->set_output(broadcast_to_node_inner); - auto broadcast_to_composite_node = main_graph->NewCNode({NewValueNode(new_sub_graph)}); - broadcast_to_composite_node->set_abstract(broadcast_to_node_inner->abstract()); - SetNewKernelInfo(broadcast_to_composite_node, new_sub_graph, {}, {broadcast_to_node_inner}); - auto graph_attr = GkUtils::ExtractGraphKernelName(TopoSort(new_sub_graph->get_return()), "", "atomic_clean"); - new_sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(graph_attr)); - new_sub_graph->set_attr("composite_type", MakeValue("atomic_clean")); - - return broadcast_to_composite_node; -} - -std::vector AtomicCleanInsertter::FindOriginCNodeUsers( - const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node, - const std::vector> &info_and_broadcast_to_nodes, +std::vector AtomicCleanInserter::FindOriginCNodeUsers( + const FuncGraphPtr &main_graph, const AnfNodePtr &composite_node, + const std::vector> &info_and_broadcast_to_nodes, const FuncGraphManagerPtr &mng) const { std::vector reduce_user_nodes; std::map real_indices_and_clean_node; for (auto &[info, clean] : info_and_broadcast_to_nodes) { - (void)real_indices_and_clean_node.emplace(info.reduce_real_output_index, clean); + (void)real_indices_and_clean_node.emplace(info.real_output_index, clean); } if (info_and_broadcast_to_nodes[0].first.real_output_num <= 1) { @@ -462,9 +294,9 @@ std::vector AtomicCleanInsertter::FindOriginCNodeUsers( return reduce_user_nodes; } -void AtomicCleanInsertter::ProcessOriginCNodeUser( - const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node, - const std::vector> &info_and_broadcast_to_nodes, +void AtomicCleanInserter::ProcessOriginCNodeUser( + const FuncGraphPtr &main_graph, const AnfNodePtr &composite_node, + const std::vector> &info_and_broadcast_to_nodes, const FuncGraphManagerPtr &mng) { // 1. Find users. auto reduce_user_nodes = FindOriginCNodeUsers(main_graph, composite_node, info_and_broadcast_to_nodes, mng); @@ -481,24 +313,22 @@ void AtomicCleanInsertter::ProcessOriginCNodeUser( } } -void AtomicCleanInsertter::InsertAtomicClean(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node, - const std::vector &atomic_add_infos, - const FuncGraphManagerPtr &mng) { +void AtomicCleanInserter::InsertAtomicClean(const FuncGraphPtr &main_graph, const AnfNodePtr &anf_node, + const std::vector &atomic_add_infos, + const FuncGraphManagerPtr &mng) { auto origin_composite_node = anf_node->cast(); MS_EXCEPTION_IF_NULL(origin_composite_node); // Create broadcast node. 
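  // Doc-style sketch of the "broadcast node" built for each atomic output
  // (hedged, single-output case; float16 targets additionally go through an
  // extra Cast of the float32 zero):
  //   %clean = SubGraph_clean() {
  //     %0 = BroadcastTo(0.0)   // attrs{"shape": [device shape of reduce output]}
  //     return %0
  //   }
  // Each reduction is paired with one such clean composite below, and the pair
  // is later appended to the composite's inputs as the accumulation buffer.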
- std::vector> info_and_broadcast_to_nodes; + std::vector> info_and_broadcast_to_nodes; for (auto atomic_add_info : atomic_add_infos) { - auto out_type = GetType(atomic_add_info.atomic_add_node)->cast(); + auto out_type = GetType(atomic_add_info.op_node)->cast(); MS_EXCEPTION_IF_NULL(out_type); - auto broadcast_to_node = - CreateAtomicCleanCompositeNode(atomic_add_info, main_graph, out_type->element()->type_id()); + auto broadcast_to_node = CreateCleanCompositeNode(atomic_add_info, main_graph, out_type->element()->type_id()); (void)info_and_broadcast_to_nodes.emplace_back(atomic_add_info, broadcast_to_node); } // Insert extra input(broadcast node output) to composite node, and make ReduceSum inplace-assign to it. - // Note: InplaceAssign outputs will increase total memory because of fake out. ProcessOriginCNode(origin_composite_node, info_and_broadcast_to_nodes); // Insert UpdateState + Load before origin ReduceSum's user to keep execution order. @@ -512,7 +342,7 @@ void AtomicCleanInsertter::InsertAtomicClean(const KernelGraphPtr &main_graph, c MS_LOG(INFO) << ss.str(); } -bool AtomicCleanInsertter::Run(const FuncGraphPtr &func_graph) { +bool AtomicCleanInserter::Run(const FuncGraphPtr &func_graph) { auto kernel_graph = std::dynamic_pointer_cast(func_graph); MS_EXCEPTION_IF_NULL(kernel_graph); auto mng = kernel_graph->manager(); diff --git a/mindspore/ccsrc/common/graph_kernel/add_atomic_clean.h b/mindspore/ccsrc/common/graph_kernel/add_atomic_clean.h index d971e323ae4..b71c7a7225f 100644 --- a/mindspore/ccsrc/common/graph_kernel/add_atomic_clean.h +++ b/mindspore/ccsrc/common/graph_kernel/add_atomic_clean.h @@ -24,14 +24,9 @@ #include #include "backend/common/optimizer/optimizer.h" #include "backend/common/session/kernel_graph.h" +#include "common/graph_kernel/clean_inserter.h" namespace mindspore::graphkernel { -struct AtomicAddInfo { - CNodePtr atomic_add_node{nullptr}; - size_t reduce_real_output_index{0}; - size_t real_output_num{0}; -}; - struct AtomicAddUserInfo { AnfNodePtr clean_node{nullptr}; AnfNodePtr update_state_node{nullptr}; @@ -46,13 +41,13 @@ class AtomicAddChecker { static std::shared_ptr Init(); bool Check(const AnfNodePtr &node); - std::vector GetAtomicAddInfo() { return atomic_add_infos_; } + std::vector GetAtomicAddInfo() { return atomic_add_infos_; } protected: virtual bool SuitableForAtomicAdd(const AnfNodePtr &) { return false; } virtual bool FindCandidate(const AnfNodePtr &anf_node); virtual bool CanActivateAtomicAdd(const AnfNodePtr &anf_node); - std::vector atomic_add_infos_; + std::vector atomic_add_infos_; PrimitivePtr target_type_{prim::kPrimReduceSum}; }; @@ -74,34 +69,26 @@ class AtomicAddCheckerAscend : public AtomicAddChecker { bool SuitableForAtomicAdd(const AnfNodePtr &node) override; }; -class AtomicCleanInsertter : public opt::Pass { +class AtomicCleanInserter : public CleanInserter { public: - explicit AtomicCleanInsertter(const std::string &name = "atomic_clean") : Pass(name) {} - ~AtomicCleanInsertter() override = default; + explicit AtomicCleanInserter(const std::string &name = "atomic_clean") : CleanInserter(name) {} + ~AtomicCleanInserter() override = default; bool Run(const FuncGraphPtr &func_graph) override; protected: - virtual void CorrectKernelBuildInfo(const AnfNodePtr &composite_node, - const std::vector> &clean_infos); - virtual void ProcessOriginCNode(const AnfNodePtr &composite_node, - const std::vector> &info_and_broadcast_to_nodes); - virtual CNodePtr CreateAtomicCleanCompositeNode(const AtomicAddInfo &atomic_add_info, - const 
KernelGraphPtr &main_graph, TypeId dst_type); - void InsertAtomicClean(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node, - const std::vector &atomic_add_infos, const FuncGraphManagerPtr &mng); - CNodePtr InsertUpdateState(const KernelGraphPtr &main_graph, const AnfNodePtr &node) const; - void CreateInplaceAssignNodeAndCorrectReturn( - const FuncGraphPtr &sub_graph, const std::vector> ¶meters_infos); - void ProcessOriginCNodeUser(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node, - const std::vector> &info_and_broadcast_to_nodes, + void InsertAtomicClean(const FuncGraphPtr &main_graph, const AnfNodePtr &anf_node, + const std::vector &atomic_add_infos, const FuncGraphManagerPtr &mng); + + void ProcessOriginCNodeUser(const FuncGraphPtr &main_graph, const AnfNodePtr &composite_node, + const std::vector> &info_and_broadcast_to_nodes, const FuncGraphManagerPtr &mng); private: std::vector FindOriginCNodeUsers( - const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node, - const std::vector> &info_and_broadcast_to_nodes, + const FuncGraphPtr &main_graph, const AnfNodePtr &composite_node, + const std::vector> &info_and_broadcast_to_nodes, const FuncGraphManagerPtr &mng) const; }; -using AtomicCleanInsertterPtr = std::shared_ptr; +using AtomicCleanInserterPtr = std::shared_ptr; } // namespace mindspore::graphkernel #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADD_ATOMIC_CLEAN_H_ diff --git a/mindspore/ccsrc/common/graph_kernel/add_stitch_atomic_clean_gpu.cc b/mindspore/ccsrc/common/graph_kernel/add_stitch_atomic_clean_gpu.cc index dc98a3af049..481c6df054e 100644 --- a/mindspore/ccsrc/common/graph_kernel/add_stitch_atomic_clean_gpu.cc +++ b/mindspore/ccsrc/common/graph_kernel/add_stitch_atomic_clean_gpu.cc @@ -28,8 +28,8 @@ #include "backend/common/session/kernel_graph.h" namespace mindspore::graphkernel { -void StitchAtomicCleanInsertter::CorrectKernelBuildInfo( - const AnfNodePtr &composite_node, const std::vector> &clean_infos) { +void StitchAtomicCleanInserter::CorrectKernelBuildInfo( + const AnfNodePtr &composite_node, const std::vector> &clean_infos) { // Change kernel build info. auto kernel_info = dynamic_cast(composite_node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); @@ -56,9 +56,9 @@ void StitchAtomicCleanInsertter::CorrectKernelBuildInfo( AnfAlgo::SetSelectKernelBuildInfo(new_selected_info, composite_node.get()); } -void StitchAtomicCleanInsertter::AddDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &clean_node, - const AnfNodePtr &composite_node, const AnfNodePtr &user_node, - int index) const { +void StitchAtomicCleanInserter::AddDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &clean_node, + const AnfNodePtr &composite_node, const AnfNodePtr &user_node, + int index) const { // Create depend node to hold execution order. AnfNodePtrList d_inputs = {NewValueNode(prim::kPrimDepend), clean_node, composite_node}; auto depend_cnode = main_graph->NewCNode(d_inputs); @@ -70,23 +70,22 @@ void StitchAtomicCleanInsertter::AddDepend(const FuncGraphPtr &main_graph, const user_cnode->set_input(IntToSize(index), depend_cnode); } -CNodePtr StitchAtomicCleanInsertter::CreateInplaceAssignNode(const FuncGraphPtr &sub_graph, - const AnfNodePtr &new_parameter, - const AtomicAddInfo &info) const { - // add inplaceassign - AnfNodePtr out_node = info.atomic_add_node; // Use result data itself, and set attr "fake_out" true. 
- auto inplace_assign_node = - CreateCNode({NewValueNode(prim::kPrimInplaceAssign), new_parameter, out_node, out_node}, sub_graph, +CNodePtr StitchAtomicCleanInserter::CreateAssignNode(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter, + const CleanZeroUserInfo &info) const { + // add assign + AnfNodePtr out_node = info.op_node; // Use result data itself + + auto assign_node = + CreateCNode({NewValueNode(prim::kPrimAssign), new_parameter, out_node}, sub_graph, {.format = GetFormat(out_node), .shape = GetShape(out_node), .type = GetType(out_node)}); - SetNodeAttrSafely("fake_output", MakeValue(true), inplace_assign_node); common::AnfAlgo::EraseNodeAttr(kAttrStitch, out_node); - SetNodeAttrSafely(kAttrStitch, MakeValue("common"), inplace_assign_node); - return inplace_assign_node; + SetNodeAttrSafely(kAttrStitch, MakeValue("common"), assign_node); + return assign_node; } -void StitchAtomicCleanInsertter::ProcessOriginCNode( +void StitchAtomicCleanInserter::ProcessOriginCNode( const AnfNodePtr &composite_node, - const std::vector> &info_and_broadcast_to_nodes) { + const std::vector> &info_and_broadcast_to_nodes, bool atomic_add_attr) { auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(composite_node); auto mng_sub = sub_graph->manager(); if (mng_sub == nullptr) { @@ -106,12 +105,12 @@ void StitchAtomicCleanInsertter::ProcessOriginCNode( parameter->set_abstract(new_input->abstract()); parameter->set_kernel_info(new_input->kernel_info_ptr()); - auto inplace_assign = CreateInplaceAssignNode(sub_graph, parameter, atomic_add_info); + auto assign = CreateAssignNode(sub_graph, parameter, atomic_add_info); - // Replace atomic ReduceSum's user with atomic clean output, and add depend op after inplaceassign to avoid + // Replace atomic ReduceSum's user with atomic clean output, and add depend op after assign to avoid // elimination. 
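  // Doc-style sketch of the stitch rewrite this implements (hedged, names
  // illustrative): the three-input InplaceAssign with its "fake_output" attr
  // becomes a plain two-input Assign, and a Depend pins the write before any
  // later reader inside the stitched sub-graph:
  //   %sum = ReduceSum(...)                // attrs{"stitch": "atomic"}
  //   %asg = Assign(%clean_para, %sum)     // attrs{"stitch": "common"}
  //   %use = consumer(%clean_para, ...)    // user now reads the clean buffer
  //   %dep = Depend(%use, %asg)            // keeps %asg alive and ordered
  //   next_consumer(%dep, ...)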
std::vector> reduce_user_nodes = - FindInnerCNodeUsers(stitch_node_, atomic_add_info.atomic_add_node); + FindInnerCNodeUsers(stitch_node_, atomic_add_info.op_node); bool connected = false; for (const auto &[user_node, index] : reduce_user_nodes) { auto user_cnode = user_node->cast(); @@ -121,7 +120,7 @@ void StitchAtomicCleanInsertter::ProcessOriginCNode( std::vector> user_user = FindInnerCNodeUsers(stitch_node_, user_cnode); if (!user_user.empty()) { auto pair = user_user[0]; - AddDepend(sub_graph, user_cnode, inplace_assign, pair.first, pair.second); + AddDepend(sub_graph, user_cnode, assign, pair.first, pair.second); } connected = true; } @@ -134,8 +133,8 @@ void StitchAtomicCleanInsertter::ProcessOriginCNode( MS_LOG(INFO) << "Convert " << old_graph_name << " to atomic add graph " << new_graph_name; } -std::vector> StitchAtomicCleanInsertter::FindInnerCNodeUsers(const AnfNodePtr &inner_node, - const CNodePtr &target) const { +std::vector> StitchAtomicCleanInserter::FindInnerCNodeUsers(const AnfNodePtr &inner_node, + const CNodePtr &target) const { auto node = inner_node->cast(); MS_EXCEPTION_IF_NULL(node); auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(node); @@ -151,8 +150,8 @@ std::vector> StitchAtomicCleanInsertter::FindInnerCNo return inner_user_nodes; } -std::pair StitchAtomicCleanInsertter::IsStitchWithAtomic(const AnfNodePtr &anf_node) { - if (!common::AnfAlgo::IsGraphKernel(anf_node)) return {false, AtomicAddInfo()}; +std::pair StitchAtomicCleanInserter::IsStitchWithAtomic(const AnfNodePtr &anf_node) { + if (!common::AnfAlgo::IsGraphKernel(anf_node)) return {false, CleanZeroUserInfo()}; auto node = anf_node->cast(); MS_EXCEPTION_IF_NULL(node); auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(node); @@ -163,16 +162,16 @@ std::pair StitchAtomicCleanInsertter::IsStitchWithAtomic(co common::AnfAlgo::GetNodeAttr(n, kAttrStitch) == "atomic" && IsPrimitiveCNode(n, prim::kPrimReduceSum)) { MS_LOG(INFO) << "GOT STITCH WITH ATOMIC!!!"; - AtomicAddInfo info; - info.atomic_add_node = n->cast(); + CleanZeroUserInfo info; + info.op_node = n->cast(); stitch_node_ = anf_node; return {true, info}; } } - return {false, AtomicAddInfo()}; + return {false, CleanZeroUserInfo()}; } -bool StitchAtomicCleanInsertter::Run(const FuncGraphPtr &func_graph) { +bool StitchAtomicCleanInserter::Run(const FuncGraphPtr &func_graph) { auto kernel_graph = std::dynamic_pointer_cast(func_graph); MS_EXCEPTION_IF_NULL(kernel_graph); auto mng = kernel_graph->manager(); diff --git a/mindspore/ccsrc/common/graph_kernel/add_stitch_atomic_clean_gpu.h b/mindspore/ccsrc/common/graph_kernel/add_stitch_atomic_clean_gpu.h index cb2c260e4e0..538adf0422f 100644 --- a/mindspore/ccsrc/common/graph_kernel/add_stitch_atomic_clean_gpu.h +++ b/mindspore/ccsrc/common/graph_kernel/add_stitch_atomic_clean_gpu.h @@ -26,31 +26,31 @@ #include "backend/common/session/kernel_graph.h" namespace mindspore::graphkernel { -class StitchAtomicCleanInsertter : public AtomicCleanInsertter { +class StitchAtomicCleanInserter : public AtomicCleanInserter { public: - StitchAtomicCleanInsertter() : AtomicCleanInsertter("stitch_atomic_clean") {} - ~StitchAtomicCleanInsertter() override = default; + StitchAtomicCleanInserter() : AtomicCleanInserter("stitch_atomic_clean") {} + ~StitchAtomicCleanInserter() override = default; bool Run(const FuncGraphPtr &func_graph) override; protected: void CorrectKernelBuildInfo(const AnfNodePtr &composite_node, - const std::vector> &clean_infos) override; - void ProcessOriginCNode( - const AnfNodePtr 
&composite_node, - const std::vector> &info_and_broadcast_to_nodes) override; + const std::vector> &clean_infos) override; + void ProcessOriginCNode(const AnfNodePtr &composite_node, + const std::vector> &info_and_broadcast_to_nodes, + bool atomic_add_attr = true) override; private: - CNodePtr CreateInplaceAssignNode(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter, - const AtomicAddInfo &info) const; + CNodePtr CreateAssignNode(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter, + const CleanZeroUserInfo &info) const; std::vector> FindInnerCNodeUsers(const AnfNodePtr &inner_node, const CNodePtr &target) const; - std::pair IsStitchWithAtomic(const AnfNodePtr &anf_node); + std::pair IsStitchWithAtomic(const AnfNodePtr &anf_node); void AddDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &clean_node, const AnfNodePtr &composite_node, const AnfNodePtr &user_node, int index) const; AnfNodePtr stitch_node_{nullptr}; }; -using StitchAtomicCleanInsertterPtr = std::shared_ptr; +using StitchAtomicCleanInserterPtr = std::shared_ptr; } // namespace mindspore::graphkernel #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADD_STITCH_ATOMIC_CLEAN_GPU_H_ diff --git a/mindspore/ccsrc/common/graph_kernel/clean_inserter.cc b/mindspore/ccsrc/common/graph_kernel/clean_inserter.cc new file mode 100644 index 00000000000..84fdf36ea91 --- /dev/null +++ b/mindspore/ccsrc/common/graph_kernel/clean_inserter.cc @@ -0,0 +1,202 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/graph_kernel/clean_inserter.h" + +#include +#include +#include +#include +#include +#include +#include +#include "mindspore/core/ops/core_ops.h" +#include "ir/tensor.h" +#include "include/common/utils/utils.h" +#include "include/common/debug/anf_ir_dump.h" +#include "utils/log_adapter.h" +#include "kernel/kernel.h" +#include "kernel/common_utils.h" +#include "backend/common/session/kernel_graph.h" +#include "common/graph_kernel/graph_kernel_helper.h" +#include "common/graph_kernel/core/graph_kernel_utils.h" + +namespace mindspore::graphkernel { +namespace { +CNodePtr CreateAssign(const FuncGraphPtr &sub_graph, + const std::vector> ¶meters_infos, size_t idx) { + if (idx >= parameters_infos.size()) { + MS_LOG(EXCEPTION) << "idx " << idx << " is out of range [0, " << parameters_infos.size() << ")"; + } + MS_EXCEPTION_IF_NULL(sub_graph); + + const auto &target_node = parameters_infos[idx].first.op_node; + const auto &new_parameter = parameters_infos[idx].second; + + auto node = + CreateCNode({NewValueNode(prim::kPrimAssign), new_parameter, target_node}, sub_graph, + {.format = GetFormat(target_node), .shape = GetShape(target_node), .type = GetType(target_node)}); + return node; +} +} // namespace + +void CleanInserter::CorrectKernelBuildInfo(const AnfNodePtr &composite_node, + const std::vector> &clean_infos) { + // Change kernel build info. 
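  // The fix-up below only grows the input side of the kernel build info: one
  // format/type pair is appended per clean buffer, and every output entry
  // stays untouched. Standalone model of that invariant (toy types; the real
  // code rebuilds the info via BuildSelectKernelBuildInfo):
  //
  //   #include <string>
  //   #include <utility>
  //   #include <vector>
  //   struct ToyBuildInfo {
  //     std::vector<std::string> input_formats;
  //     std::vector<int> input_device_types;  // stands in for TypeId entries
  //   };
  //   void AppendCleanInputs(ToyBuildInfo *info,
  //                          const std::vector<std::pair<std::string, int>> &clean_inputs) {
  //     for (const auto &[format, type] : clean_inputs) {
  //       info->input_formats.push_back(format);
  //       info->input_device_types.push_back(type);
  //     }
  //   }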
+ auto kernel_info = dynamic_cast(composite_node->kernel_info()); + MS_EXCEPTION_IF_NULL(kernel_info); + const auto &origin_kernel_build_info = kernel_info->GetMutableSelectKernelBuildInfo(); + MS_EXCEPTION_IF_NULL(origin_kernel_build_info); + auto origin_inputs_format = origin_kernel_build_info->GetAllInputFormats(); + auto origin_inputs_type = origin_kernel_build_info->GetAllInputDeviceTypes(); + + std::vector &new_inputs_format = origin_inputs_format; + std::vector &new_inputs_type = origin_inputs_type; + for (const auto &clean_info : clean_infos) { + auto &new_input = clean_info.second; + auto kernel_with_index = common::AnfAlgo::VisitKernel(new_input, 0); + new_inputs_format.push_back(AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second)); + new_inputs_type.push_back(AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second)); + } + + auto new_selected_info = BuildSelectKernelBuildInfo( + new_inputs_format, new_inputs_type, origin_kernel_build_info->GetAllOutputFormats(), + origin_kernel_build_info->GetAllOutputDeviceTypes(), origin_kernel_build_info->processor()); + AnfAlgo::SetSelectKernelBuildInfo(new_selected_info, composite_node.get()); +} + +void CleanInserter::CreateAssignNodeAndCorrectReturn( + const FuncGraphPtr &sub_graph, const std::vector> ¶meters_infos) { + std::map target_indices; + for (size_t i = 0; i < parameters_infos.size(); ++i) { + target_indices[parameters_infos[i].first.real_output_index + 1] = i; + } + + // Change output to Assign node. + auto output = sub_graph->output(); + if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) { + auto output_cnode = output->cast(); + MS_EXCEPTION_IF_NULL(output_cnode); + for (size_t i = 1; i < output_cnode->inputs().size(); ++i) { + auto iter = target_indices.find(i); + if (iter == target_indices.end()) continue; + auto inplace = CreateAssign(sub_graph, parameters_infos, iter->second); + output_cnode->set_input(i, inplace); + } + } else if (parameters_infos.size() == 1) { + auto inplace = CreateAssign(sub_graph, parameters_infos, 0); + sub_graph->set_output(inplace); + } +} + +CNodePtr CleanInserter::InsertUpdateState(const FuncGraphPtr &main_graph, const AnfNodePtr &node) const { + // Insert update_state_node, need mount a monad node. + auto u = NewValueNode(kUMonad); + u->set_abstract(kUMonad->ToAbstract()); + AnfNodePtrList update_state_inputs = {NewValueNode(prim::kPrimUpdateState), u, node}; + auto update_state_cnode = main_graph->NewCNode(update_state_inputs); + update_state_cnode->set_abstract(kUMonad->ToAbstract()); + main_graph->AddNode(update_state_cnode); + return update_state_cnode; +} + +CNodePtr CleanInserter::CreateCleanCompositeNode(const CleanZeroUserInfo &op_info, const FuncGraphPtr &main_graph, + TypeId dst_type) { + std::set data_support = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64}; + + if (!std::any_of(data_support.cbegin(), data_support.cend(), [&dst_type](TypeId type) { return dst_type == type; })) { + MS_LOG(EXCEPTION) << "For CreateCleanCompositeNode, the data type: " << TypeIdToString(dst_type, true) + << " is not in supported list: [float16, float32, float64]."; + } + + // Create zero value which will be broadcast to target shape. + auto format = GetFormat(op_info.op_node); + auto dtype = (dst_type == kNumberTypeFloat16) ? 
kNumberTypeFloat32 : dst_type; + ValueNodePtr value_node; + if (dtype == kNumberTypeFloat32) { + value_node = CreateScalarTensorValueNode({.format = format, .shape = {1}, .type = TypeIdToType(dtype)}, + static_cast(0), sizeof(float)); + } else { + value_node = CreateScalarTensorValueNode({.format = format, .shape = {1}, .type = TypeIdToType(dtype)}, + static_cast(0), sizeof(double)); + } + + // Create composite op's sub-graph. + auto new_sub_graph = std::make_shared(); + + AnfNodePtr broadcast_input_node; + if (dst_type == kNumberTypeFloat16) { + AnfNodePtrList cast_inputs = {NewValueNode(prim::kPrimCast), value_node}; + auto cast_node_inner = + CreateCNode(cast_inputs, new_sub_graph, {.format = format, .shape = {1}, .type = TypeIdToType(dst_type)}); + SetNodeAttrSafely("dst_type", MakeValue("float32"), cast_node_inner); + broadcast_input_node = cast_node_inner; + } else { + broadcast_input_node = value_node; + } + + // Create broadcast basic op. + auto dst_shape_vec = GetShape(op_info.op_node); + AnfNodePtrList clean_inputs = {NewValueNode(prim::kPrimBroadcastTo), broadcast_input_node}; + auto broadcast_to_node_inner = CreateCNode( + clean_inputs, new_sub_graph, {.format = format, .shape = dst_shape_vec, .type = GetType(op_info.op_node)}); + SetNodeAttrSafely("shape", MakeValue(GetDeviceShape(op_info.op_node)), broadcast_to_node_inner); + + // Makeup sub-graph. + new_sub_graph->set_output(broadcast_to_node_inner); + auto broadcast_to_composite_node = main_graph->NewCNode({NewValueNode(new_sub_graph)}); + broadcast_to_composite_node->set_abstract(broadcast_to_node_inner->abstract()); + SetNewKernelInfo(broadcast_to_composite_node, new_sub_graph, {}, {broadcast_to_node_inner}); + auto graph_attr = GkUtils::ExtractGraphKernelName(TopoSort(new_sub_graph->get_return()), "", "atomic_clean"); + new_sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(graph_attr)); + new_sub_graph->set_attr("composite_type", MakeValue("atomic_clean")); + + return broadcast_to_composite_node; +} + +void CleanInserter::ProcessOriginCNode( + const AnfNodePtr &composite_node, + const std::vector> &info_and_broadcast_to_nodes, bool atomic_add_attr) { + auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(composite_node); + auto mng_sub = sub_graph->manager(); + if (mng_sub == nullptr) { + mng_sub = Manage(sub_graph, false); + sub_graph->set_manager(mng_sub); + } + + // Add input + std::vector> parameters_infos; + for (const auto &[atomic_add_info, new_input] : info_and_broadcast_to_nodes) { + // Add atomic attribute to target node. 
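  // Net effect of ProcessOriginCNode on the composite, in the same doc-comment
  // style as the uss/tsa headers (hedged, single-output case):
  //   %clean = SubGraph_clean()          // broadcast_to(0.0) built above
  //   out    = SubGraph'(%in..., %clean) {
  //     %0 = ReduceSum(...)              // attrs{"enable_atomic_add": true}
  //     %1 = Assign(%clean_para, %0)     // write through the new parameter
  //     return %1
  //   }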
+ if (atomic_add_attr) SetNodeAttrSafely("enable_atomic_add", MakeValue(true), atomic_add_info.op_node); + + // add parameter + auto parameter = sub_graph->add_parameter(); + parameter->set_abstract(new_input->abstract()); + parameter->set_kernel_info(new_input->kernel_info_ptr()); + (void)parameters_infos.emplace_back(atomic_add_info, parameter); + } + + auto inputs = composite_node->cast()->inputs(); + (void)std::transform(info_and_broadcast_to_nodes.cbegin(), info_and_broadcast_to_nodes.cend(), + std::back_inserter(inputs), + [](const std::pair &pair_item) { return pair_item.second; }); + composite_node->cast()->set_inputs(inputs); + + CreateAssignNodeAndCorrectReturn(sub_graph, parameters_infos); + CorrectKernelBuildInfo(composite_node, info_and_broadcast_to_nodes); +} +} // namespace mindspore::graphkernel diff --git a/mindspore/ccsrc/common/graph_kernel/clean_inserter.h b/mindspore/ccsrc/common/graph_kernel/clean_inserter.h new file mode 100644 index 00000000000..323c9b7f26b --- /dev/null +++ b/mindspore/ccsrc/common/graph_kernel/clean_inserter.h @@ -0,0 +1,51 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CLEAN_INSERTER_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CLEAN_INSERTER_H_ + +#include +#include +#include +#include "backend/common/optimizer/optimizer.h" + +namespace mindspore::graphkernel { +struct CleanZeroUserInfo { + CNodePtr op_node{nullptr}; + size_t real_output_index{0}; + size_t real_output_num{0}; +}; + +class CleanInserter : public opt::Pass { + public: + explicit CleanInserter(const std::string &name = "clean_inserter") : Pass(name) {} + ~CleanInserter() override = default; + + protected: + virtual void CorrectKernelBuildInfo(const AnfNodePtr &composite_node, + const std::vector> &clean_infos); + virtual CNodePtr CreateCleanCompositeNode(const CleanZeroUserInfo &op_info, const FuncGraphPtr &main_graph, + TypeId dst_type); + CNodePtr InsertUpdateState(const FuncGraphPtr &main_graph, const AnfNodePtr &node) const; + void CreateAssignNodeAndCorrectReturn(const FuncGraphPtr &sub_graph, + const std::vector> ¶meters_infos); + virtual void ProcessOriginCNode( + const AnfNodePtr &composite_node, + const std::vector> &info_and_broadcast_to_nodes, + bool atomic_add_attr = true); +}; +} // namespace mindspore::graphkernel +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CLEAN_INSERTER_H_ diff --git a/mindspore/ccsrc/common/graph_kernel/tsa_atomic_add_to_first_tensor.cc b/mindspore/ccsrc/common/graph_kernel/tsa_atomic_add_to_first_tensor.cc index 60df48a578e..9b8a47b94f5 100644 --- a/mindspore/ccsrc/common/graph_kernel/tsa_atomic_add_to_first_tensor.cc +++ b/mindspore/ccsrc/common/graph_kernel/tsa_atomic_add_to_first_tensor.cc @@ -44,7 +44,7 @@ class TsaChecker : public AtomicAddChecker { } for (auto atomic_add_info : atomic_add_infos_) { - auto tsa_cnode = atomic_add_info.atomic_add_node; + auto tsa_cnode = 
atomic_add_info.op_node; if (!utils::isa(tsa_cnode->input(1))) { return false; } @@ -86,7 +86,7 @@ std::pair TsaAtomicAddToFirstTensor::FindTsaFirstRealInputIn } std::pair TsaAtomicAddToFirstTensor::GetOrCreateNewTsaFirstNode( - const KernelGraphPtr &main_graph, const AtomicAddInfo &atomic_add_info, const AnfNodePtr &node) { + const KernelGraphPtr &main_graph, const CleanZeroUserInfo &atomic_add_info, const AnfNodePtr &node) { auto mng = main_graph->manager(); if (mng == nullptr) { mng = Manage(main_graph, true); @@ -94,7 +94,7 @@ std::pair TsaAtomicAddToFirstTensor::GetOrCreateNewTsaFirstN } // Find first input of tsa - auto tsa_first_input = FindTsaFirstRealInputInGraph(main_graph, atomic_add_info.atomic_add_node, node); + auto tsa_first_input = FindTsaFirstRealInputInGraph(main_graph, atomic_add_info.op_node, node); auto users = mng->node_users()[tsa_first_input.first]; if (users.size() == 1 && !(utils::isa(tsa_first_input.first) || utils::isa(tsa_first_input.first))) { @@ -147,7 +147,7 @@ std::pair TsaAtomicAddToFirstTensor::GetOrCreateNewTsaFirstN } void TsaAtomicAddToFirstTensor::ChangeKernelBuildInfo( - const AnfNodePtr &composite_node, const std::vector> &outer_infos) { + const AnfNodePtr &composite_node, const std::vector> &outer_infos) { // Change kernel build info with modify input auto kernel_info = static_cast(composite_node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); @@ -176,7 +176,7 @@ void TsaAtomicAddToFirstTensor::ChangeKernelBuildInfo( } void TsaAtomicAddToFirstTensor::ProcessOriginalCNode( - const AnfNodePtr &composite_node, const std::vector> &outer_nodes) { + const AnfNodePtr &composite_node, const std::vector> &outer_nodes) { auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(composite_node); auto mng_sub = sub_graph->manager(); if (mng_sub == nullptr) { @@ -185,8 +185,8 @@ void TsaAtomicAddToFirstTensor::ProcessOriginalCNode( } // Modify input - std::vector> parameters_infos; - std::vector> info_and_tsa_outers; + std::vector> parameters_infos; + std::vector> info_and_tsa_outers; for (const auto &[atomic_add_info, outer_node, tsa_first_input_index] : outer_nodes) { composite_node->cast()->set_input(tsa_first_input_index + 1, outer_node); auto parameter = sub_graph->parameters()[tsa_first_input_index]; @@ -194,7 +194,7 @@ void TsaAtomicAddToFirstTensor::ProcessOriginalCNode( (void)info_and_tsa_outers.emplace_back(atomic_add_info, outer_node); } - CreateInplaceAssignNodeAndCorrectReturn(sub_graph, parameters_infos); + CreateAssignNodeAndCorrectReturn(sub_graph, parameters_infos); ChangeKernelBuildInfo(composite_node, outer_nodes); auto old_graph_name = GetValue(sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); @@ -205,14 +205,14 @@ void TsaAtomicAddToFirstTensor::ProcessOriginalCNode( } void TsaAtomicAddToFirstTensor::ProcessTsa(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node, - const std::vector &atomic_add_infos, + const std::vector &atomic_add_infos, const FuncGraphManagerPtr &mng) { auto origin_composite_node = anf_node->cast(); MS_EXCEPTION_IF_NULL(origin_composite_node); // Create identity node. 
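  // When is the extra copy needed? Hedged standalone restatement of the check
  // in GetOrCreateNewTsaFirstNode (toy signature; the real code inspects the
  // manager's node_users and the AnfNode kind):
  //
  //   bool CanWriteFirstInputInPlace(bool is_parameter, bool is_value_node,
  //                                  size_t user_count) {
  //     // Scatter into the first input directly only if this composite is its
  //     // sole user and it is neither a graph input nor a constant; otherwise
  //     // an identity-like copy becomes the accumulation target.
  //     return user_count == 1 && !is_parameter && !is_value_node;
  //   }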
- std::vector> info_and_outer_nodes_with_index; - std::vector> info_and_outer_nodes; + std::vector> info_and_outer_nodes_with_index; + std::vector> info_and_outer_nodes; for (auto atomic_add_info : atomic_add_infos) { auto outer = GetOrCreateNewTsaFirstNode(main_graph, atomic_add_info, anf_node); (void)info_and_outer_nodes_with_index.emplace_back(atomic_add_info, outer.first, outer.second); @@ -220,7 +220,6 @@ void TsaAtomicAddToFirstTensor::ProcessTsa(const KernelGraphPtr &main_graph, con } // Insert extra input(broadcast node output) to composite node, and make origin TensorScatterAdd inplace-assign to it. - // Note: InplaceAssign outputs will increase total memory because of fake out. ProcessOriginalCNode(origin_composite_node, info_and_outer_nodes_with_index); // Insert UpdateState + Load before origin TensorScatterAdd's user to keep execution order. diff --git a/mindspore/ccsrc/common/graph_kernel/tsa_atomic_add_to_first_tensor.h b/mindspore/ccsrc/common/graph_kernel/tsa_atomic_add_to_first_tensor.h index 4010b3826d3..0ad6a9410d0 100644 --- a/mindspore/ccsrc/common/graph_kernel/tsa_atomic_add_to_first_tensor.h +++ b/mindspore/ccsrc/common/graph_kernel/tsa_atomic_add_to_first_tensor.h @@ -1,5 +1,5 @@ /** - * Copyright 2021 Huawei Technologies Co., Ltd + * Copyright 2021-2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -36,26 +36,26 @@ namespace mindspore::graphkernel { * output = Reshape(input_x) * fake_out = SubGraph'(output, indices, update) { * %0 = TensorScatterAdd(%para1, %para2, %para3) - * %1 = InplaceAssign(%para1, %0, %0) // attrs{"fake_output":true} + * %1 = Assign(%para1, %0, umond) // * return %1 * } */ -class TsaAtomicAddToFirstTensor : public AtomicCleanInsertter { +class TsaAtomicAddToFirstTensor : public AtomicCleanInserter { public: - TsaAtomicAddToFirstTensor() : AtomicCleanInsertter("tensor_scatter_add_atomic_add_to_first_tensor") {} + TsaAtomicAddToFirstTensor() : AtomicCleanInserter("tensor_scatter_add_atomic_add_to_first_tensor") {} ~TsaAtomicAddToFirstTensor() override = default; bool Run(const FuncGraphPtr &func_graph) override; private: void ProcessOriginalCNode(const AnfNodePtr &composite_node, - const std::vector> &outer_nodes); + const std::vector> &outer_nodes); void ChangeKernelBuildInfo(const AnfNodePtr &composite_node, - const std::vector> &outer_infos); + const std::vector> &outer_infos); void ProcessTsa(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node, - const std::vector &atomic_add_infos, const FuncGraphManagerPtr &mng); + const std::vector &atomic_add_infos, const FuncGraphManagerPtr &mng); std::pair GetOrCreateNewTsaFirstNode(const KernelGraphPtr &main_graph, - const AtomicAddInfo &atomic_add_info, + const CleanZeroUserInfo &atomic_add_info, const AnfNodePtr &node); std::pair FindTsaFirstRealInputInGraph(const KernelGraphPtr &, const CNodePtr &tsa_node, const AnfNodePtr &node); diff --git a/mindspore/ccsrc/common/graph_kernel/uss_atomic_add.h b/mindspore/ccsrc/common/graph_kernel/uss_atomic_add.h index 9f63347efd8..99ddaaffc03 100644 --- a/mindspore/ccsrc/common/graph_kernel/uss_atomic_add.h +++ b/mindspore/ccsrc/common/graph_kernel/uss_atomic_add.h @@ -36,13 +36,13 @@ namespace mindspore::graphkernel { * output = broadcast_to(0.0) // attrs{"shape": [shape of origin output.]} * fake_out = SubGraph'(input_x, segment_ids, output) { * %0 = UnsortedSegmentSum(%para1, %para2) - * %1 = InplaceAssign(%para3, %0, %0) // 
attrs{"fake_output":true} + * %1 = Assign(%para3, %0, umond) * return %1 * } */ -class UssAtomicAdd : public AtomicCleanInsertter { +class UssAtomicAdd : public AtomicCleanInserter { public: - UssAtomicAdd() : AtomicCleanInsertter("unsorted_segment_sum_atomic_add_process") {} + UssAtomicAdd() : AtomicCleanInserter("unsorted_segment_sum_atomic_add_process") {} ~UssAtomicAdd() override = default; bool Run(const FuncGraphPtr &func_graph) override; };