forked from mindspore-Ecosystem/mindspore

refactor atomic add

This commit is contained in:
parent 2174f71616
commit 89e6d40524
@@ -166,19 +166,12 @@ PassManagerPtr GraphKernelOptimizer::HighLevelOpt2() const {
   auto recompute_lv = GetPassLevelByFlag(flags.recompute_increment_threshold > 0 || flags.recompute_peak_threshold > 0);
   pm->Add(std::make_shared<GraphKernelRecompute>(), recompute_lv);
 
-  // Replace Assign with InplaceAssign, and replace original output with overridden parameters
-  pm->Add(std::make_shared<OptimizeAssign>(), OptLevel_2);
-
-  pm->Add(std::make_shared<ExtendOutputForUpdateState>(), std::min(recompute_lv, OptLevel_2));
-  pm->Add(std::make_shared<MergeOutputForUpdateState>(), std::min(recompute_lv, OptLevel_2));
-  pm->Add(std::make_shared<EliminateRedundantOutput>(), std::min(recompute_lv, OptLevel_2));
-
   // Enable atomic add
-  pm->Add(std::make_shared<AtomicCleanInsertter>(), OptLevel_2, is_gpu || is_ascend);
+  pm->Add(std::make_shared<AtomicCleanInserter>(), OptLevel_2, is_gpu || is_ascend);
 
   // Enable atomic add for stitch nodes.
   auto level = GetPassLevelByFlag(GraphKernelFlags::GetInstance().enable_stitch_fusion);
-  pm->Add(std::make_shared<StitchAtomicCleanInsertter>(), level, is_gpu);
+  pm->Add(std::make_shared<StitchAtomicCleanInserter>(), level, is_gpu);
 
   // Enable low precision
   auto level_low_precision = GetPassLevelByFlag(GraphKernelFlags::GetInstance().enable_low_precision);
@@ -189,6 +182,12 @@ PassManagerPtr GraphKernelOptimizer::HighLevelOpt2() const {
   pm->Add(std::make_shared<TsaAtomicAddToFirstTensor>(), OptLevel_1, is_gpu);
   pm->Add(std::make_shared<UssAtomicAdd>(), OptLevel_1, is_gpu);
 
+  // Replace Assign with InplaceAssign, and replace original output with overridden parameters
+  pm->Add(std::make_shared<OptimizeAssign>(), OptLevel_2);
+  pm->Add(std::make_shared<ExtendOutputForUpdateState>(), std::min(recompute_lv, OptLevel_2));
+  pm->Add(std::make_shared<MergeOutputForUpdateState>(), std::min(recompute_lv, OptLevel_2));
+  pm->Add(std::make_shared<EliminateRedundantOutput>(), std::min(recompute_lv, OptLevel_2));
+
   return pm;
 }
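The hunks above reorder the pass pipeline: the OptimizeAssign/UpdateState output passes now run after the atomic-add passes, and the renamed AtomicCleanInserter and StitchAtomicCleanInserter are registered instead of the old "Insertter" classes. The reason a "clean" pass exists at all: an atomic-add reduction accumulates partial sums directly into its output buffer, so that buffer must be zero-filled before the kernel runs. A minimal standalone sketch of this invariant in plain C++ (illustrative only, not MindSpore code):

#include <atomic>
#include <iostream>
#include <thread>
#include <vector>

int main() {
  std::atomic<int> out{0};  // the "clean" step: the accumulation target starts at zero
  std::vector<std::thread> workers;
  for (int t = 0; t < 4; ++t) {
    workers.emplace_back([&out] {
      for (int i = 0; i < 1000; ++i) out.fetch_add(1, std::memory_order_relaxed);
    });
  }
  for (auto &w : workers) w.join();
  std::cout << out << "\n";  // 4000; with a garbage initial value the sum would be garbage
  return 0;
}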
@@ -15,6 +15,7 @@
  */
 
 #include "common/graph_kernel/add_atomic_clean.h"
+
 #include <algorithm>
 #include <functional>
 #include <map>
@@ -37,8 +38,6 @@
 
 namespace mindspore::graphkernel {
 namespace {
-constexpr auto kAttrFakeOutput = "fake_output";
-
 std::set<int64_t> GetUniqReduceAxes(const AnfNodePtr &node, bool is_ascend = false) {
   if (!IsPrimitiveCNode(node, prim::kPrimReduceSum)) {
     MS_LOG(EXCEPTION) << "Expect ReduceSum node, but got " << common::AnfAlgo::GetCNodeName(node);
@@ -103,21 +102,6 @@ size_t GetItemIdx(const AnfNodePtr &node) {
   auto item_idx = LongToSize(GetValue<int64_t>(value_node->value()));
   return item_idx;
 }
-
-CNodePtr CreateInplaceAssign(const FuncGraphPtr &sub_graph,
-                             const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &parameters_infos, size_t idx) {
-  if (idx >= parameters_infos.size()) {
-    MS_LOG(EXCEPTION) << "idx " << idx << " is out of range [0, " << parameters_infos.size() << ")";
-  }
-  MS_EXCEPTION_IF_NULL(sub_graph);
-  const auto &atomic_add_node = parameters_infos[idx].first.atomic_add_node;
-  const auto &new_parameter = parameters_infos[idx].second;
-  auto node = CreateCNode(
-    {NewValueNode(prim::kPrimInplaceAssign), new_parameter, atomic_add_node, atomic_add_node}, sub_graph,
-    {.format = GetFormat(atomic_add_node), .shape = GetShape(atomic_add_node), .type = GetType(atomic_add_node)});
-  SetNodeAttrSafely(kAttrFakeOutput, MakeValue(true), node);
-  return node;
-}
 }  // namespace
 
 std::shared_ptr<AtomicAddChecker> AtomicAddChecker::Init() {
@@ -141,19 +125,19 @@ bool AtomicAddChecker::FindCandidate(const AnfNodePtr &anf_node) {
     sub_graph->set_manager(mng_sub);
   }
 
-  auto CheckSuitableTarget = [&mng_sub](const AtomicAddInfo &atomic_add_info) {
+  auto CheckSuitableTarget = [&mng_sub](const CleanZeroUserInfo &atomic_add_info) {
     // Target type should not fuse any other ops in out direction, which means it should be in output list.
-    return mng_sub->node_users()[atomic_add_info.atomic_add_node].size() <= 1;
+    return mng_sub->node_users()[atomic_add_info.op_node].size() <= 1;
   };
 
   auto real_return_node = sub_graph->get_return()->input(kFirstDataInputIndex);
-  AtomicAddInfo atomic_add_info;
+  CleanZeroUserInfo atomic_add_info;
   if (IsPrimitiveCNode(real_return_node, prim::kPrimMakeTuple)) {
     const auto &inputs = real_return_node->cast<CNodePtr>()->inputs();
     for (size_t i = 1; i < inputs.size(); ++i) {
       if (IsPrimitiveCNode(inputs[i], target_type_)) {
-        atomic_add_info.atomic_add_node = inputs[i]->cast<CNodePtr>();
-        atomic_add_info.reduce_real_output_index = i - 1;
+        atomic_add_info.op_node = inputs[i]->cast<CNodePtr>();
+        atomic_add_info.real_output_index = i - 1;
         atomic_add_info.real_output_num = inputs.size() - 1;
         // Target type should not fuse any other ops in out direction, which means it should be in output list.
         if (!CheckSuitableTarget(atomic_add_info)) {
@@ -163,7 +147,7 @@ bool AtomicAddChecker::FindCandidate(const AnfNodePtr &anf_node) {
       }
     }
   } else if (IsPrimitiveCNode(real_return_node, target_type_)) {
-    atomic_add_info.atomic_add_node = real_return_node->cast<CNodePtr>();
+    atomic_add_info.op_node = real_return_node->cast<CNodePtr>();
     atomic_add_info.real_output_num = 1;
     if (CheckSuitableTarget(atomic_add_info)) {
       atomic_add_infos_.push_back(atomic_add_info);
@@ -191,12 +175,12 @@ bool AtomicAddChecker::CanActivateAtomicAdd(const AnfNodePtr &anf_node) {
   }
 
   // Rule 2.
-  if (!SuitableForAtomicAdd(atomic_add_infos_[0].atomic_add_node)) {
+  if (!SuitableForAtomicAdd(atomic_add_infos_[0].op_node)) {
     return false;
   }
 
   // Rule 3.
-  return !HaveReduceInPredecessors(atomic_add_infos_[0].atomic_add_node);
+  return !HaveReduceInPredecessors(atomic_add_infos_[0].op_node);
 }
 
 bool AtomicAddChecker::Check(const AnfNodePtr &node) {
@@ -264,167 +248,15 @@ bool AtomicAddCheckerAscend::SuitableForAtomicAdd(const AnfNodePtr &node) {
   return false;
 }
 
-void AtomicCleanInsertter::CorrectKernelBuildInfo(
-  const AnfNodePtr &composite_node, const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &clean_infos) {
-  // Change kernel build info.
-  auto kernel_info = dynamic_cast<device::KernelInfo *>(composite_node->kernel_info());
-  MS_EXCEPTION_IF_NULL(kernel_info);
-  const auto &origin_kernel_build_info = kernel_info->GetMutableSelectKernelBuildInfo();
-  MS_EXCEPTION_IF_NULL(origin_kernel_build_info);
-  auto origin_inputs_format = origin_kernel_build_info->GetAllInputFormats();
-  auto origin_inputs_type = origin_kernel_build_info->GetAllInputDeviceTypes();
-
-  std::vector<std::string> &new_inputs_format = origin_inputs_format;
-  std::vector<TypeId> &new_inputs_type = origin_inputs_type;
-  for (const auto &clean_info : clean_infos) {
-    auto &new_input = clean_info.second;
-    auto kernel_with_index = common::AnfAlgo::VisitKernel(new_input, 0);
-    new_inputs_format.push_back(AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second));
-    new_inputs_type.push_back(AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second));
-  }
-
-  auto new_selected_info = BuildSelectKernelBuildInfo(
-    new_inputs_format, new_inputs_type, origin_kernel_build_info->GetAllOutputFormats(),
-    origin_kernel_build_info->GetAllOutputDeviceTypes(), origin_kernel_build_info->processor());
-  AnfAlgo::SetSelectKernelBuildInfo(new_selected_info, composite_node.get());
-}
-
-void AtomicCleanInsertter::CreateInplaceAssignNodeAndCorrectReturn(
-  const FuncGraphPtr &sub_graph, const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &parameters_infos) {
-  std::map<size_t, size_t> reduce_indices;
-  for (size_t i = 0; i < parameters_infos.size(); ++i) {
-    reduce_indices[parameters_infos[i].first.reduce_real_output_index + 1] = i;
-  }
-
-  // Change atomic add output to InplaceAssign node.
-  auto output = sub_graph->output();
-  if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
-    auto output_cnode = output->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(output_cnode);
-    for (size_t i = 1; i < output_cnode->inputs().size(); ++i) {
-      auto iter = reduce_indices.find(i);
-      if (iter == reduce_indices.end()) continue;
-      auto inplace = CreateInplaceAssign(sub_graph, parameters_infos, iter->second);
-      output_cnode->set_input(i, inplace);
-    }
-  } else if (parameters_infos.size() == 1) {
-    auto inplace = CreateInplaceAssign(sub_graph, parameters_infos, 0);
-    sub_graph->set_output(inplace);
-  }
-}
-
-void AtomicCleanInsertter::ProcessOriginCNode(
-  const AnfNodePtr &composite_node,
-  const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &info_and_broadcast_to_nodes) {
-  auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(composite_node);
-  auto mng_sub = sub_graph->manager();
-  if (mng_sub == nullptr) {
-    mng_sub = Manage(sub_graph, false);
-    sub_graph->set_manager(mng_sub);
-  }
-
-  // Add input
-  std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> parameters_infos;
-  for (const auto &[atomic_add_info, new_input] : info_and_broadcast_to_nodes) {
-    // Add atomic attribute to ReduceSum node.
-    SetNodeAttrSafely("enable_atomic_add", MakeValue(true), atomic_add_info.atomic_add_node);
-    // add parameter
-    auto parameter = sub_graph->add_parameter();
-    parameter->set_abstract(new_input->abstract());
-    parameter->set_kernel_info(new_input->kernel_info_ptr());
-    (void)parameters_infos.emplace_back(atomic_add_info, parameter);
-  }
-
-  auto inputs = composite_node->cast<CNodePtr>()->inputs();
-  (void)std::transform(info_and_broadcast_to_nodes.cbegin(), info_and_broadcast_to_nodes.cend(),
-                       std::back_inserter(inputs),
-                       [](const std::pair<AtomicAddInfo, AnfNodePtr> &pair_item) { return pair_item.second; });
-  composite_node->cast<CNodePtr>()->set_inputs(inputs);
-
-  CreateInplaceAssignNodeAndCorrectReturn(sub_graph, parameters_infos);
-  CorrectKernelBuildInfo(composite_node, info_and_broadcast_to_nodes);
-
-  auto old_graph_name = GetValue<std::string>(sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
-  auto new_graph_name = GkUtils::ExtractGraphKernelName(TopoSort(sub_graph->get_return()), "", "atomic_add");
-  sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(new_graph_name));
-  MS_LOG(INFO) << "Convert " << old_graph_name << " to atomic add graph " << new_graph_name;
-}
-
-CNodePtr AtomicCleanInsertter::InsertUpdateState(const KernelGraphPtr &main_graph, const AnfNodePtr &node) const {
-  // Insert update_state_node, need mount a monad node.
-  auto u = NewValueNode(kUMonad);
-  u->set_abstract(kUMonad->ToAbstract());
-  AnfNodePtrList update_state_inputs = {NewValueNode(prim::kPrimUpdateState), u, node};
-  auto update_state_cnode = main_graph->NewCNode(update_state_inputs);
-  update_state_cnode->set_abstract(kUMonad->ToAbstract());
-  main_graph->AddNode(update_state_cnode);
-  return update_state_cnode;
-}
-
-CNodePtr AtomicCleanInsertter::CreateAtomicCleanCompositeNode(const AtomicAddInfo &atomic_add_info,
-                                                              const KernelGraphPtr &main_graph, TypeId dst_type) {
-  std::set<TypeId> data_support = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64};
-
-  if (!std::any_of(data_support.cbegin(), data_support.cend(), [&dst_type](TypeId type) { return dst_type == type; })) {
-    MS_LOG(EXCEPTION) << "For AtomicAdd, the data type: " << TypeIdToString(dst_type, true)
-                      << " is not in supported list: [float16, float32, float64].";
-  }
-
-  // Create zero value which will be broadcast to target shape.
-  auto format = GetFormat(atomic_add_info.atomic_add_node);
-  auto dtype = (dst_type == kNumberTypeFloat16) ? kNumberTypeFloat32 : dst_type;
-  ValueNodePtr value_node;
-  if (dtype == kNumberTypeFloat32) {
-    value_node = CreateScalarTensorValueNode<float>({.format = format, .shape = {1}, .type = TypeIdToType(dtype)},
-                                                    static_cast<float>(0), sizeof(float));
-  } else {
-    value_node = CreateScalarTensorValueNode<double>({.format = format, .shape = {1}, .type = TypeIdToType(dtype)},
-                                                     static_cast<double>(0), sizeof(double));
-  }
-
-  // Create composite op's sub-graph.
-  auto new_sub_graph = std::make_shared<FuncGraph>();
-
-  AnfNodePtr broadcast_input_node;
-  if (dst_type == kNumberTypeFloat16) {
-    AnfNodePtrList cast_inputs = {NewValueNode(prim::kPrimCast), value_node};
-    auto cast_node_inner =
-      CreateCNode(cast_inputs, new_sub_graph, {.format = format, .shape = {1}, .type = TypeIdToType(dst_type)});
-    SetNodeAttrSafely("dst_type", MakeValue("float32"), cast_node_inner);
-    broadcast_input_node = cast_node_inner;
-  } else {
-    broadcast_input_node = value_node;
-  }
-
-  // Create broadcast basic op.
-  auto dst_shape_vec = GetShape(atomic_add_info.atomic_add_node);
-  AnfNodePtrList atomic_clean_inputs = {NewValueNode(prim::kPrimBroadcastTo), broadcast_input_node};
-  auto broadcast_to_node_inner =
-    CreateCNode(atomic_clean_inputs, new_sub_graph,
-                {.format = format, .shape = dst_shape_vec, .type = GetType(atomic_add_info.atomic_add_node)});
-  SetNodeAttrSafely("shape", MakeValue(GetDeviceShape(atomic_add_info.atomic_add_node)), broadcast_to_node_inner);
-
-  // Makeup sub-graph.
-  new_sub_graph->set_output(broadcast_to_node_inner);
-  auto broadcast_to_composite_node = main_graph->NewCNode({NewValueNode(new_sub_graph)});
-  broadcast_to_composite_node->set_abstract(broadcast_to_node_inner->abstract());
-  SetNewKernelInfo(broadcast_to_composite_node, new_sub_graph, {}, {broadcast_to_node_inner});
-  auto graph_attr = GkUtils::ExtractGraphKernelName(TopoSort(new_sub_graph->get_return()), "", "atomic_clean");
-  new_sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(graph_attr));
-  new_sub_graph->set_attr("composite_type", MakeValue("atomic_clean"));
-
-  return broadcast_to_composite_node;
-}
-
-std::vector<AtomicAddUserInfo> AtomicCleanInsertter::FindOriginCNodeUsers(
-  const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node,
-  const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &info_and_broadcast_to_nodes,
+std::vector<AtomicAddUserInfo> AtomicCleanInserter::FindOriginCNodeUsers(
+  const FuncGraphPtr &main_graph, const AnfNodePtr &composite_node,
+  const std::vector<std::pair<CleanZeroUserInfo, AnfNodePtr>> &info_and_broadcast_to_nodes,
   const FuncGraphManagerPtr &mng) const {
   std::vector<AtomicAddUserInfo> reduce_user_nodes;
 
   std::map<size_t, AnfNodePtr> real_indices_and_clean_node;
   for (auto &[info, clean] : info_and_broadcast_to_nodes) {
-    (void)real_indices_and_clean_node.emplace(info.reduce_real_output_index, clean);
+    (void)real_indices_and_clean_node.emplace(info.real_output_index, clean);
   }
 
   if (info_and_broadcast_to_nodes[0].first.real_output_num <= 1) {
@@ -462,9 +294,9 @@ std::vector<AtomicAddUserInfo> AtomicCleanInserter::FindOriginCNodeUsers(
   return reduce_user_nodes;
 }
 
-void AtomicCleanInsertter::ProcessOriginCNodeUser(
-  const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node,
-  const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &info_and_broadcast_to_nodes,
+void AtomicCleanInserter::ProcessOriginCNodeUser(
+  const FuncGraphPtr &main_graph, const AnfNodePtr &composite_node,
+  const std::vector<std::pair<CleanZeroUserInfo, AnfNodePtr>> &info_and_broadcast_to_nodes,
   const FuncGraphManagerPtr &mng) {
   // 1. Find users.
   auto reduce_user_nodes = FindOriginCNodeUsers(main_graph, composite_node, info_and_broadcast_to_nodes, mng);
@@ -481,24 +313,22 @@ void AtomicCleanInserter::ProcessOriginCNodeUser(
   }
 }
 
-void AtomicCleanInsertter::InsertAtomicClean(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node,
-                                             const std::vector<AtomicAddInfo> &atomic_add_infos,
-                                             const FuncGraphManagerPtr &mng) {
+void AtomicCleanInserter::InsertAtomicClean(const FuncGraphPtr &main_graph, const AnfNodePtr &anf_node,
+                                            const std::vector<CleanZeroUserInfo> &atomic_add_infos,
+                                            const FuncGraphManagerPtr &mng) {
   auto origin_composite_node = anf_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(origin_composite_node);
 
   // Create broadcast node.
-  std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> info_and_broadcast_to_nodes;
+  std::vector<std::pair<CleanZeroUserInfo, AnfNodePtr>> info_and_broadcast_to_nodes;
   for (auto atomic_add_info : atomic_add_infos) {
-    auto out_type = GetType(atomic_add_info.atomic_add_node)->cast<TensorTypePtr>();
+    auto out_type = GetType(atomic_add_info.op_node)->cast<TensorTypePtr>();
     MS_EXCEPTION_IF_NULL(out_type);
-    auto broadcast_to_node =
-      CreateAtomicCleanCompositeNode(atomic_add_info, main_graph, out_type->element()->type_id());
+    auto broadcast_to_node = CreateCleanCompositeNode(atomic_add_info, main_graph, out_type->element()->type_id());
     (void)info_and_broadcast_to_nodes.emplace_back(atomic_add_info, broadcast_to_node);
   }
 
   // Insert extra input(broadcast node output) to composite node, and make ReduceSum inplace-assign to it.
-  // Note: InplaceAssign outputs will increase total memory because of fake out.
   ProcessOriginCNode(origin_composite_node, info_and_broadcast_to_nodes);
 
   // Insert UpdateState + Load before origin ReduceSum's user to keep execution order.
@@ -512,7 +342,7 @@ void AtomicCleanInserter::InsertAtomicClean(const FuncGraphPtr &main_graph, const AnfNodePtr &anf_node,
   MS_LOG(INFO) << ss.str();
 }
 
-bool AtomicCleanInsertter::Run(const FuncGraphPtr &func_graph) {
+bool AtomicCleanInserter::Run(const FuncGraphPtr &func_graph) {
   auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(func_graph);
   MS_EXCEPTION_IF_NULL(kernel_graph);
   auto mng = kernel_graph->manager();
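The functions removed above (CorrectKernelBuildInfo, CreateInplaceAssignNodeAndCorrectReturn, ProcessOriginCNode, InsertUpdateState, CreateAtomicCleanCompositeNode) are not deleted outright: they reappear in the new clean_inserter.cc further down, generalized from InplaceAssign to Assign, and AtomicCleanInserter keeps only the atomic-add-specific logic. A minimal sketch of that extract-base-class shape in plain C++ (simplified names, not the real MindSpore types):

#include <iostream>
#include <memory>
#include <string>

struct CleanInserterSketch {            // base: shared clean-insertion pipeline
  virtual ~CleanInserterSketch() = default;
  void Process(const std::string &node) {
    std::cout << "insert zero-filled buffer for " << node << "\n";  // shared step
    Annotate(node);                                                 // subclass hook
  }
  virtual void Annotate(const std::string &node) {
    std::cout << "mark " << node << " with enable_atomic_add\n";
  }
};

struct StitchVariant : CleanInserterSketch {  // derived: stitch-specific marking
  void Annotate(const std::string &node) override {
    std::cout << "mark " << node << " with stitch=common\n";
  }
};

int main() {
  std::unique_ptr<CleanInserterSketch> pass = std::make_unique<StitchVariant>();
  pass->Process("ReduceSum_0");
  return 0;
}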
@@ -24,14 +24,9 @@
 #include <string>
 #include "backend/common/optimizer/optimizer.h"
 #include "backend/common/session/kernel_graph.h"
+#include "common/graph_kernel/clean_inserter.h"
 
 namespace mindspore::graphkernel {
-struct AtomicAddInfo {
-  CNodePtr atomic_add_node{nullptr};
-  size_t reduce_real_output_index{0};
-  size_t real_output_num{0};
-};
-
 struct AtomicAddUserInfo {
   AnfNodePtr clean_node{nullptr};
   AnfNodePtr update_state_node{nullptr};
@@ -46,13 +41,13 @@ class AtomicAddChecker {
   static std::shared_ptr<AtomicAddChecker> Init();
 
   bool Check(const AnfNodePtr &node);
-  std::vector<AtomicAddInfo> GetAtomicAddInfo() { return atomic_add_infos_; }
+  std::vector<CleanZeroUserInfo> GetAtomicAddInfo() { return atomic_add_infos_; }
 
  protected:
   virtual bool SuitableForAtomicAdd(const AnfNodePtr &) { return false; }
   virtual bool FindCandidate(const AnfNodePtr &anf_node);
   virtual bool CanActivateAtomicAdd(const AnfNodePtr &anf_node);
-  std::vector<AtomicAddInfo> atomic_add_infos_;
+  std::vector<CleanZeroUserInfo> atomic_add_infos_;
   PrimitivePtr target_type_{prim::kPrimReduceSum};
 };
@@ -74,34 +69,26 @@ class AtomicAddCheckerAscend : public AtomicAddChecker {
   bool SuitableForAtomicAdd(const AnfNodePtr &node) override;
 };
 
-class AtomicCleanInsertter : public opt::Pass {
+class AtomicCleanInserter : public CleanInserter {
  public:
-  explicit AtomicCleanInsertter(const std::string &name = "atomic_clean") : Pass(name) {}
-  ~AtomicCleanInsertter() override = default;
+  explicit AtomicCleanInserter(const std::string &name = "atomic_clean") : CleanInserter(name) {}
+  ~AtomicCleanInserter() override = default;
   bool Run(const FuncGraphPtr &func_graph) override;
 
  protected:
-  virtual void CorrectKernelBuildInfo(const AnfNodePtr &composite_node,
-                                      const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &clean_infos);
-  virtual void ProcessOriginCNode(const AnfNodePtr &composite_node,
-                                  const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &info_and_broadcast_to_nodes);
-  virtual CNodePtr CreateAtomicCleanCompositeNode(const AtomicAddInfo &atomic_add_info,
-                                                  const KernelGraphPtr &main_graph, TypeId dst_type);
-  void InsertAtomicClean(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node,
-                         const std::vector<AtomicAddInfo> &atomic_add_infos, const FuncGraphManagerPtr &mng);
-  CNodePtr InsertUpdateState(const KernelGraphPtr &main_graph, const AnfNodePtr &node) const;
-  void CreateInplaceAssignNodeAndCorrectReturn(
-    const FuncGraphPtr &sub_graph, const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &parameters_infos);
-  void ProcessOriginCNodeUser(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node,
-                              const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &info_and_broadcast_to_nodes,
+  void InsertAtomicClean(const FuncGraphPtr &main_graph, const AnfNodePtr &anf_node,
+                         const std::vector<CleanZeroUserInfo> &atomic_add_infos, const FuncGraphManagerPtr &mng);
+
+  void ProcessOriginCNodeUser(const FuncGraphPtr &main_graph, const AnfNodePtr &composite_node,
+                              const std::vector<std::pair<CleanZeroUserInfo, AnfNodePtr>> &info_and_broadcast_to_nodes,
                               const FuncGraphManagerPtr &mng);
 
  private:
   std::vector<AtomicAddUserInfo> FindOriginCNodeUsers(
-    const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node,
-    const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &info_and_broadcast_to_nodes,
+    const FuncGraphPtr &main_graph, const AnfNodePtr &composite_node,
+    const std::vector<std::pair<CleanZeroUserInfo, AnfNodePtr>> &info_and_broadcast_to_nodes,
     const FuncGraphManagerPtr &mng) const;
 };
-using AtomicCleanInsertterPtr = std::shared_ptr<AtomicCleanInsertter>;
+using AtomicCleanInserterPtr = std::shared_ptr<AtomicCleanInserter>;
 }  // namespace mindspore::graphkernel
 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADD_ATOMIC_CLEAN_H_
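The AtomicAddInfo struct removed here survives as CleanZeroUserInfo in clean_inserter.h: atomic_add_node becomes the more general op_node, and reduce_real_output_index becomes real_output_index, since the record now describes any op that writes into a zero-initialized buffer, not only ReduceSum. A reduced sketch (plain C++, hypothetical names) of how FindCandidate-style code fills such records while scanning a MakeTuple output list:

#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

struct CleanZeroUserInfoLite {  // mirrors the fields of CleanZeroUserInfo
  std::string op_node;
  size_t real_output_index{0};
  size_t real_output_num{0};
};

int main() {
  std::vector<std::string> outputs = {"ReduceSum_0", "Add_1", "ReduceSum_2"};
  std::vector<CleanZeroUserInfoLite> infos;
  for (size_t i = 0; i < outputs.size(); ++i) {
    if (outputs[i].rfind("ReduceSum", 0) == 0) {  // stand-in for the target_type_ check
      infos.push_back({outputs[i], i, outputs.size()});
    }
  }
  for (const auto &info : infos) {
    std::cout << info.op_node << " is output " << info.real_output_index
              << " of " << info.real_output_num << "\n";
  }
  return 0;
}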
@@ -28,8 +28,8 @@
 #include "backend/common/session/kernel_graph.h"
 
 namespace mindspore::graphkernel {
-void StitchAtomicCleanInsertter::CorrectKernelBuildInfo(
-  const AnfNodePtr &composite_node, const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &clean_infos) {
+void StitchAtomicCleanInserter::CorrectKernelBuildInfo(
+  const AnfNodePtr &composite_node, const std::vector<std::pair<CleanZeroUserInfo, AnfNodePtr>> &clean_infos) {
   // Change kernel build info.
   auto kernel_info = dynamic_cast<device::KernelInfo *>(composite_node->kernel_info());
   MS_EXCEPTION_IF_NULL(kernel_info);
@@ -56,9 +56,9 @@ void StitchAtomicCleanInserter::CorrectKernelBuildInfo(
   AnfAlgo::SetSelectKernelBuildInfo(new_selected_info, composite_node.get());
 }
 
-void StitchAtomicCleanInsertter::AddDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &clean_node,
-                                           const AnfNodePtr &composite_node, const AnfNodePtr &user_node,
-                                           int index) const {
+void StitchAtomicCleanInserter::AddDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &clean_node,
+                                          const AnfNodePtr &composite_node, const AnfNodePtr &user_node,
+                                          int index) const {
   // Create depend node to hold execution order.
   AnfNodePtrList d_inputs = {NewValueNode(prim::kPrimDepend), clean_node, composite_node};
   auto depend_cnode = main_graph->NewCNode(d_inputs);
@@ -70,23 +70,22 @@ void StitchAtomicCleanInserter::AddDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &clean_node,
   user_cnode->set_input(IntToSize(index), depend_cnode);
 }
 
-CNodePtr StitchAtomicCleanInsertter::CreateInplaceAssignNode(const FuncGraphPtr &sub_graph,
-                                                             const AnfNodePtr &new_parameter,
-                                                             const AtomicAddInfo &info) const {
-  // add inplaceassign
-  AnfNodePtr out_node = info.atomic_add_node;  // Use result data itself, and set attr "fake_out" true.
-  auto inplace_assign_node =
-    CreateCNode({NewValueNode(prim::kPrimInplaceAssign), new_parameter, out_node, out_node}, sub_graph,
+CNodePtr StitchAtomicCleanInserter::CreateAssignNode(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter,
+                                                     const CleanZeroUserInfo &info) const {
+  // add assign
+  AnfNodePtr out_node = info.op_node;  // Use result data itself
+
+  auto assign_node =
+    CreateCNode({NewValueNode(prim::kPrimAssign), new_parameter, out_node}, sub_graph,
                 {.format = GetFormat(out_node), .shape = GetShape(out_node), .type = GetType(out_node)});
-  SetNodeAttrSafely("fake_output", MakeValue(true), inplace_assign_node);
   common::AnfAlgo::EraseNodeAttr(kAttrStitch, out_node);
-  SetNodeAttrSafely(kAttrStitch, MakeValue("common"), inplace_assign_node);
-  return inplace_assign_node;
+  SetNodeAttrSafely(kAttrStitch, MakeValue("common"), assign_node);
+  return assign_node;
 }
 
-void StitchAtomicCleanInsertter::ProcessOriginCNode(
+void StitchAtomicCleanInserter::ProcessOriginCNode(
   const AnfNodePtr &composite_node,
-  const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &info_and_broadcast_to_nodes) {
+  const std::vector<std::pair<CleanZeroUserInfo, AnfNodePtr>> &info_and_broadcast_to_nodes, bool atomic_add_attr) {
   auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(composite_node);
   auto mng_sub = sub_graph->manager();
   if (mng_sub == nullptr) {
@@ -106,12 +105,12 @@ void StitchAtomicCleanInserter::ProcessOriginCNode(
   parameter->set_abstract(new_input->abstract());
   parameter->set_kernel_info(new_input->kernel_info_ptr());
 
-  auto inplace_assign = CreateInplaceAssignNode(sub_graph, parameter, atomic_add_info);
+  auto assign = CreateAssignNode(sub_graph, parameter, atomic_add_info);
 
-  // Replace atomic ReduceSum's user with atomic clean output, and add depend op after inplaceassign to avoid
+  // Replace atomic ReduceSum's user with atomic clean output, and add depend op after assign to avoid
   // elimination.
   std::vector<std::pair<AnfNodePtr, int>> reduce_user_nodes =
-    FindInnerCNodeUsers(stitch_node_, atomic_add_info.atomic_add_node);
+    FindInnerCNodeUsers(stitch_node_, atomic_add_info.op_node);
   bool connected = false;
   for (const auto &[user_node, index] : reduce_user_nodes) {
     auto user_cnode = user_node->cast<CNodePtr>();
@@ -121,7 +120,7 @@ void StitchAtomicCleanInserter::ProcessOriginCNode(
     std::vector<std::pair<AnfNodePtr, int>> user_user = FindInnerCNodeUsers(stitch_node_, user_cnode);
     if (!user_user.empty()) {
       auto pair = user_user[0];
-      AddDepend(sub_graph, user_cnode, inplace_assign, pair.first, pair.second);
+      AddDepend(sub_graph, user_cnode, assign, pair.first, pair.second);
     }
     connected = true;
   }
@@ -134,8 +133,8 @@ void StitchAtomicCleanInserter::ProcessOriginCNode(
   MS_LOG(INFO) << "Convert " << old_graph_name << " to atomic add graph " << new_graph_name;
 }
 
-std::vector<std::pair<AnfNodePtr, int>> StitchAtomicCleanInsertter::FindInnerCNodeUsers(const AnfNodePtr &inner_node,
-                                                                                        const CNodePtr &target) const {
+std::vector<std::pair<AnfNodePtr, int>> StitchAtomicCleanInserter::FindInnerCNodeUsers(const AnfNodePtr &inner_node,
+                                                                                       const CNodePtr &target) const {
   auto node = inner_node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(node);
   auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
@@ -151,8 +150,8 @@ std::vector<std::pair<AnfNodePtr, int>> StitchAtomicCleanInserter::FindInnerCNodeUsers(const AnfNodePtr &inner_node,
   return inner_user_nodes;
 }
 
-std::pair<bool, AtomicAddInfo> StitchAtomicCleanInsertter::IsStitchWithAtomic(const AnfNodePtr &anf_node) {
-  if (!common::AnfAlgo::IsGraphKernel(anf_node)) return {false, AtomicAddInfo()};
+std::pair<bool, CleanZeroUserInfo> StitchAtomicCleanInserter::IsStitchWithAtomic(const AnfNodePtr &anf_node) {
+  if (!common::AnfAlgo::IsGraphKernel(anf_node)) return {false, CleanZeroUserInfo()};
   auto node = anf_node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(node);
   auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
@@ -163,16 +162,16 @@ std::pair<bool, CleanZeroUserInfo> StitchAtomicCleanInserter::IsStitchWithAtomic(const AnfNodePtr &anf_node) {
         common::AnfAlgo::GetNodeAttr<std::string>(n, kAttrStitch) == "atomic" &&
         IsPrimitiveCNode(n, prim::kPrimReduceSum)) {
       MS_LOG(INFO) << "GOT STITCH WITH ATOMIC!!!";
-      AtomicAddInfo info;
-      info.atomic_add_node = n->cast<CNodePtr>();
+      CleanZeroUserInfo info;
+      info.op_node = n->cast<CNodePtr>();
       stitch_node_ = anf_node;
       return {true, info};
     }
   }
-  return {false, AtomicAddInfo()};
+  return {false, CleanZeroUserInfo()};
 }
 
-bool StitchAtomicCleanInsertter::Run(const FuncGraphPtr &func_graph) {
+bool StitchAtomicCleanInserter::Run(const FuncGraphPtr &func_graph) {
   auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(func_graph);
   MS_EXCEPTION_IF_NULL(kernel_graph);
   auto mng = kernel_graph->manager();
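With the fake output of InplaceAssign gone, the reader of the cleaned value no longer consumes the writer's result, so AddDepend wires a kPrimDepend edge to pin the order explicitly. A toy analog of Depend's contract in plain C++ (hypothetical helper, not the MindSpore primitive): it forwards its first argument unchanged while requiring the dependency to have been produced first.

#include <iostream>

// Depend(value, dependency): return value, but only after dependency exists.
template <typename T, typename U>
T Depend(const T &value, const U & /*dependency, already evaluated*/) {
  return value;
}

int main() {
  int buffer = 0;
  int write_token = (buffer = 5);          // the Assign: write into the buffer
  int read = Depend(buffer, write_token);  // the read is ordered after the write
  std::cout << read << "\n";               // 5
  return 0;
}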
@@ -26,31 +26,31 @@
 #include "backend/common/session/kernel_graph.h"
 
 namespace mindspore::graphkernel {
-class StitchAtomicCleanInsertter : public AtomicCleanInsertter {
+class StitchAtomicCleanInserter : public AtomicCleanInserter {
  public:
-  StitchAtomicCleanInsertter() : AtomicCleanInsertter("stitch_atomic_clean") {}
-  ~StitchAtomicCleanInsertter() override = default;
+  StitchAtomicCleanInserter() : AtomicCleanInserter("stitch_atomic_clean") {}
+  ~StitchAtomicCleanInserter() override = default;
   bool Run(const FuncGraphPtr &func_graph) override;
 
  protected:
   void CorrectKernelBuildInfo(const AnfNodePtr &composite_node,
-                              const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &clean_infos) override;
-  void ProcessOriginCNode(
-    const AnfNodePtr &composite_node,
-    const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &info_and_broadcast_to_nodes) override;
+                              const std::vector<std::pair<CleanZeroUserInfo, AnfNodePtr>> &clean_infos) override;
+  void ProcessOriginCNode(const AnfNodePtr &composite_node,
+                          const std::vector<std::pair<CleanZeroUserInfo, AnfNodePtr>> &info_and_broadcast_to_nodes,
+                          bool atomic_add_attr = true) override;
 
  private:
-  CNodePtr CreateInplaceAssignNode(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter,
-                                   const AtomicAddInfo &info) const;
+  CNodePtr CreateAssignNode(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter,
+                            const CleanZeroUserInfo &info) const;
   std::vector<std::pair<AnfNodePtr, int>> FindInnerCNodeUsers(const AnfNodePtr &inner_node,
                                                               const CNodePtr &target) const;
-  std::pair<bool, AtomicAddInfo> IsStitchWithAtomic(const AnfNodePtr &anf_node);
+  std::pair<bool, CleanZeroUserInfo> IsStitchWithAtomic(const AnfNodePtr &anf_node);
 
   void AddDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &clean_node, const AnfNodePtr &composite_node,
                  const AnfNodePtr &user_node, int index) const;
 
   AnfNodePtr stitch_node_{nullptr};
 };
-using StitchAtomicCleanInsertterPtr = std::shared_ptr<StitchAtomicCleanInsertter>;
+using StitchAtomicCleanInserterPtr = std::shared_ptr<StitchAtomicCleanInserter>;
 }  // namespace mindspore::graphkernel
 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADD_STITCH_ATOMIC_CLEAN_GPU_H_
@@ -0,0 +1,202 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "common/graph_kernel/clean_inserter.h"
+
+#include <algorithm>
+#include <memory>
+#include <utility>
+#include <set>
+#include <string>
+#include <map>
+#include <vector>
+#include "mindspore/core/ops/core_ops.h"
+#include "ir/tensor.h"
+#include "include/common/utils/utils.h"
+#include "include/common/debug/anf_ir_dump.h"
+#include "utils/log_adapter.h"
+#include "kernel/kernel.h"
+#include "kernel/common_utils.h"
+#include "backend/common/session/kernel_graph.h"
+#include "common/graph_kernel/graph_kernel_helper.h"
+#include "common/graph_kernel/core/graph_kernel_utils.h"
+
+namespace mindspore::graphkernel {
+namespace {
+CNodePtr CreateAssign(const FuncGraphPtr &sub_graph,
+                      const std::vector<std::pair<CleanZeroUserInfo, AnfNodePtr>> &parameters_infos, size_t idx) {
+  if (idx >= parameters_infos.size()) {
+    MS_LOG(EXCEPTION) << "idx " << idx << " is out of range [0, " << parameters_infos.size() << ")";
+  }
+  MS_EXCEPTION_IF_NULL(sub_graph);
+
+  const auto &target_node = parameters_infos[idx].first.op_node;
+  const auto &new_parameter = parameters_infos[idx].second;
+
+  auto node =
+    CreateCNode({NewValueNode(prim::kPrimAssign), new_parameter, target_node}, sub_graph,
+                {.format = GetFormat(target_node), .shape = GetShape(target_node), .type = GetType(target_node)});
+  return node;
+}
+}  // namespace
+
+void CleanInserter::CorrectKernelBuildInfo(const AnfNodePtr &composite_node,
+                                           const std::vector<std::pair<CleanZeroUserInfo, AnfNodePtr>> &clean_infos) {
+  // Change kernel build info.
+  auto kernel_info = dynamic_cast<device::KernelInfo *>(composite_node->kernel_info());
+  MS_EXCEPTION_IF_NULL(kernel_info);
+  const auto &origin_kernel_build_info = kernel_info->GetMutableSelectKernelBuildInfo();
+  MS_EXCEPTION_IF_NULL(origin_kernel_build_info);
+  auto origin_inputs_format = origin_kernel_build_info->GetAllInputFormats();
+  auto origin_inputs_type = origin_kernel_build_info->GetAllInputDeviceTypes();
+
+  std::vector<std::string> &new_inputs_format = origin_inputs_format;
+  std::vector<TypeId> &new_inputs_type = origin_inputs_type;
+  for (const auto &clean_info : clean_infos) {
+    auto &new_input = clean_info.second;
+    auto kernel_with_index = common::AnfAlgo::VisitKernel(new_input, 0);
+    new_inputs_format.push_back(AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second));
+    new_inputs_type.push_back(AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second));
+  }
+
+  auto new_selected_info = BuildSelectKernelBuildInfo(
+    new_inputs_format, new_inputs_type, origin_kernel_build_info->GetAllOutputFormats(),
+    origin_kernel_build_info->GetAllOutputDeviceTypes(), origin_kernel_build_info->processor());
+  AnfAlgo::SetSelectKernelBuildInfo(new_selected_info, composite_node.get());
+}
+
+void CleanInserter::CreateAssignNodeAndCorrectReturn(
+  const FuncGraphPtr &sub_graph, const std::vector<std::pair<CleanZeroUserInfo, AnfNodePtr>> &parameters_infos) {
+  std::map<size_t, size_t> target_indices;
+  for (size_t i = 0; i < parameters_infos.size(); ++i) {
+    target_indices[parameters_infos[i].first.real_output_index + 1] = i;
+  }
+
+  // Change output to Assign node.
+  auto output = sub_graph->output();
+  if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
+    auto output_cnode = output->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(output_cnode);
+    for (size_t i = 1; i < output_cnode->inputs().size(); ++i) {
+      auto iter = target_indices.find(i);
+      if (iter == target_indices.end()) continue;
+      auto inplace = CreateAssign(sub_graph, parameters_infos, iter->second);
+      output_cnode->set_input(i, inplace);
+    }
+  } else if (parameters_infos.size() == 1) {
+    auto inplace = CreateAssign(sub_graph, parameters_infos, 0);
+    sub_graph->set_output(inplace);
+  }
+}
+
+CNodePtr CleanInserter::InsertUpdateState(const FuncGraphPtr &main_graph, const AnfNodePtr &node) const {
+  // Insert update_state_node, need mount a monad node.
+  auto u = NewValueNode(kUMonad);
+  u->set_abstract(kUMonad->ToAbstract());
+  AnfNodePtrList update_state_inputs = {NewValueNode(prim::kPrimUpdateState), u, node};
+  auto update_state_cnode = main_graph->NewCNode(update_state_inputs);
+  update_state_cnode->set_abstract(kUMonad->ToAbstract());
+  main_graph->AddNode(update_state_cnode);
+  return update_state_cnode;
+}
+
+CNodePtr CleanInserter::CreateCleanCompositeNode(const CleanZeroUserInfo &op_info, const FuncGraphPtr &main_graph,
+                                                 TypeId dst_type) {
+  std::set<TypeId> data_support = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64};
+
+  if (!std::any_of(data_support.cbegin(), data_support.cend(), [&dst_type](TypeId type) { return dst_type == type; })) {
+    MS_LOG(EXCEPTION) << "For CreateCleanCompositeNode, the data type: " << TypeIdToString(dst_type, true)
+                      << " is not in supported list: [float16, float32, float64].";
+  }
+
+  // Create zero value which will be broadcast to target shape.
+  auto format = GetFormat(op_info.op_node);
+  auto dtype = (dst_type == kNumberTypeFloat16) ? kNumberTypeFloat32 : dst_type;
+  ValueNodePtr value_node;
+  if (dtype == kNumberTypeFloat32) {
+    value_node = CreateScalarTensorValueNode<float>({.format = format, .shape = {1}, .type = TypeIdToType(dtype)},
+                                                    static_cast<float>(0), sizeof(float));
+  } else {
+    value_node = CreateScalarTensorValueNode<double>({.format = format, .shape = {1}, .type = TypeIdToType(dtype)},
+                                                     static_cast<double>(0), sizeof(double));
+  }
+
+  // Create composite op's sub-graph.
+  auto new_sub_graph = std::make_shared<FuncGraph>();
+
+  AnfNodePtr broadcast_input_node;
+  if (dst_type == kNumberTypeFloat16) {
+    AnfNodePtrList cast_inputs = {NewValueNode(prim::kPrimCast), value_node};
+    auto cast_node_inner =
+      CreateCNode(cast_inputs, new_sub_graph, {.format = format, .shape = {1}, .type = TypeIdToType(dst_type)});
+    SetNodeAttrSafely("dst_type", MakeValue("float32"), cast_node_inner);
+    broadcast_input_node = cast_node_inner;
+  } else {
+    broadcast_input_node = value_node;
+  }
+
+  // Create broadcast basic op.
+  auto dst_shape_vec = GetShape(op_info.op_node);
+  AnfNodePtrList clean_inputs = {NewValueNode(prim::kPrimBroadcastTo), broadcast_input_node};
+  auto broadcast_to_node_inner = CreateCNode(
+    clean_inputs, new_sub_graph, {.format = format, .shape = dst_shape_vec, .type = GetType(op_info.op_node)});
+  SetNodeAttrSafely("shape", MakeValue(GetDeviceShape(op_info.op_node)), broadcast_to_node_inner);
+
+  // Makeup sub-graph.
+  new_sub_graph->set_output(broadcast_to_node_inner);
+  auto broadcast_to_composite_node = main_graph->NewCNode({NewValueNode(new_sub_graph)});
+  broadcast_to_composite_node->set_abstract(broadcast_to_node_inner->abstract());
+  SetNewKernelInfo(broadcast_to_composite_node, new_sub_graph, {}, {broadcast_to_node_inner});
+  auto graph_attr = GkUtils::ExtractGraphKernelName(TopoSort(new_sub_graph->get_return()), "", "atomic_clean");
+  new_sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(graph_attr));
+  new_sub_graph->set_attr("composite_type", MakeValue("atomic_clean"));
+
+  return broadcast_to_composite_node;
+}
+
+void CleanInserter::ProcessOriginCNode(
+  const AnfNodePtr &composite_node,
+  const std::vector<std::pair<CleanZeroUserInfo, AnfNodePtr>> &info_and_broadcast_to_nodes, bool atomic_add_attr) {
+  auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(composite_node);
+  auto mng_sub = sub_graph->manager();
+  if (mng_sub == nullptr) {
+    mng_sub = Manage(sub_graph, false);
+    sub_graph->set_manager(mng_sub);
+  }
+
+  // Add input
+  std::vector<std::pair<CleanZeroUserInfo, AnfNodePtr>> parameters_infos;
+  for (const auto &[atomic_add_info, new_input] : info_and_broadcast_to_nodes) {
+    // Add atomic attribute to target node.
+    if (atomic_add_attr) SetNodeAttrSafely("enable_atomic_add", MakeValue(true), atomic_add_info.op_node);
+
+    // add parameter
+    auto parameter = sub_graph->add_parameter();
+    parameter->set_abstract(new_input->abstract());
+    parameter->set_kernel_info(new_input->kernel_info_ptr());
+    (void)parameters_infos.emplace_back(atomic_add_info, parameter);
+  }
+
+  auto inputs = composite_node->cast<CNodePtr>()->inputs();
+  (void)std::transform(info_and_broadcast_to_nodes.cbegin(), info_and_broadcast_to_nodes.cend(),
+                       std::back_inserter(inputs),
+                       [](const std::pair<CleanZeroUserInfo, AnfNodePtr> &pair_item) { return pair_item.second; });
+  composite_node->cast<CNodePtr>()->set_inputs(inputs);
+
+  CreateAssignNodeAndCorrectReturn(sub_graph, parameters_infos);
+  CorrectKernelBuildInfo(composite_node, info_and_broadcast_to_nodes);
+}
+}  // namespace mindspore::graphkernel
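CreateAssignNodeAndCorrectReturn above keys its rewrites by MakeTuple input slot: output i of the kernel is input i + 1 of the MakeTuple CNode, because input 0 is the primitive itself. A standalone sketch of that bookkeeping in plain C++ (illustrative only):

#include <cstddef>
#include <iostream>
#include <map>
#include <vector>

int main() {
  std::vector<size_t> real_output_index = {0, 2};  // outputs that get an Assign
  std::map<size_t, size_t> target_indices;         // MakeTuple slot -> info index
  for (size_t i = 0; i < real_output_index.size(); ++i) {
    target_indices[real_output_index[i] + 1] = i;  // +1 skips the primitive input
  }
  for (const auto &[slot, idx] : target_indices) {
    std::cout << "MakeTuple input " << slot << " -> Assign for info #" << idx << "\n";
  }
  return 0;
}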
@@ -0,0 +1,51 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CLEAN_INSERTER_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CLEAN_INSERTER_H_
+
+#include <utility>
+#include <vector>
+#include <string>
+#include "backend/common/optimizer/optimizer.h"
+
+namespace mindspore::graphkernel {
+struct CleanZeroUserInfo {
+  CNodePtr op_node{nullptr};
+  size_t real_output_index{0};
+  size_t real_output_num{0};
+};
+
+class CleanInserter : public opt::Pass {
+ public:
+  explicit CleanInserter(const std::string &name = "clean_inserter") : Pass(name) {}
+  ~CleanInserter() override = default;
+
+ protected:
+  virtual void CorrectKernelBuildInfo(const AnfNodePtr &composite_node,
+                                      const std::vector<std::pair<CleanZeroUserInfo, AnfNodePtr>> &clean_infos);
+  virtual CNodePtr CreateCleanCompositeNode(const CleanZeroUserInfo &op_info, const FuncGraphPtr &main_graph,
+                                            TypeId dst_type);
+  CNodePtr InsertUpdateState(const FuncGraphPtr &main_graph, const AnfNodePtr &node) const;
+  void CreateAssignNodeAndCorrectReturn(const FuncGraphPtr &sub_graph,
+                                        const std::vector<std::pair<CleanZeroUserInfo, AnfNodePtr>> &parameters_infos);
+  virtual void ProcessOriginCNode(
+    const AnfNodePtr &composite_node,
+    const std::vector<std::pair<CleanZeroUserInfo, AnfNodePtr>> &info_and_broadcast_to_nodes,
+    bool atomic_add_attr = true);
+};
+}  // namespace mindspore::graphkernel
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CLEAN_INSERTER_H_
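CreateCleanCompositeNode, declared above, materializes a scalar zero and broadcasts it to the target op's shape; the atomic-add kernel then accumulates into that buffer. A plain-C++ analog of the effect (not the graph-building code itself):

#include <cstddef>
#include <iostream>
#include <numeric>
#include <vector>

// Analog of broadcasting a scalar 0 to the target shape: a zero-filled buffer.
std::vector<float> BroadcastZero(const std::vector<size_t> &shape) {
  size_t n = 1;
  for (size_t d : shape) n *= d;       // element count of the target shape
  return std::vector<float>(n, 0.0f);  // the "clean" output
}

int main() {
  auto buf = BroadcastZero({2, 3});
  buf[0] += 1.5f;  // the atomic-add kernel accumulates into the cleaned buffer
  std::cout << buf.size() << " elements, sum "
            << std::accumulate(buf.begin(), buf.end(), 0.0f) << "\n";  // 6 elements, sum 1.5
  return 0;
}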
@@ -44,7 +44,7 @@ class TsaChecker : public AtomicAddChecker {
   }
 
   for (auto atomic_add_info : atomic_add_infos_) {
-    auto tsa_cnode = atomic_add_info.atomic_add_node;
+    auto tsa_cnode = atomic_add_info.op_node;
     if (!utils::isa<ParameterPtr>(tsa_cnode->input(1))) {
       return false;
     }
@@ -86,7 +86,7 @@ std::pair<AnfNodePtr, size_t> TsaAtomicAddToFirstTensor::FindTsaFirstRealInputInGraph(
 }
 
 std::pair<AnfNodePtr, size_t> TsaAtomicAddToFirstTensor::GetOrCreateNewTsaFirstNode(
-  const KernelGraphPtr &main_graph, const AtomicAddInfo &atomic_add_info, const AnfNodePtr &node) {
+  const KernelGraphPtr &main_graph, const CleanZeroUserInfo &atomic_add_info, const AnfNodePtr &node) {
   auto mng = main_graph->manager();
   if (mng == nullptr) {
     mng = Manage(main_graph, true);
@@ -94,7 +94,7 @@ std::pair<AnfNodePtr, size_t> TsaAtomicAddToFirstTensor::GetOrCreateNewTsaFirstNode(
   }
 
   // Find first input of tsa
-  auto tsa_first_input = FindTsaFirstRealInputInGraph(main_graph, atomic_add_info.atomic_add_node, node);
+  auto tsa_first_input = FindTsaFirstRealInputInGraph(main_graph, atomic_add_info.op_node, node);
   auto users = mng->node_users()[tsa_first_input.first];
   if (users.size() == 1 &&
       !(utils::isa<ValueNodePtr>(tsa_first_input.first) || utils::isa<ParameterPtr>(tsa_first_input.first))) {
@@ -147,7 +147,7 @@ std::pair<AnfNodePtr, size_t> TsaAtomicAddToFirstTensor::GetOrCreateNewTsaFirstNode(
 }
 
 void TsaAtomicAddToFirstTensor::ChangeKernelBuildInfo(
-  const AnfNodePtr &composite_node, const std::vector<std::tuple<AtomicAddInfo, AnfNodePtr, size_t>> &outer_infos) {
+  const AnfNodePtr &composite_node, const std::vector<std::tuple<CleanZeroUserInfo, AnfNodePtr, size_t>> &outer_infos) {
   // Change kernel build info with modify input
   auto kernel_info = static_cast<device::KernelInfo *>(composite_node->kernel_info());
   MS_EXCEPTION_IF_NULL(kernel_info);
@@ -176,7 +176,7 @@ void TsaAtomicAddToFirstTensor::ChangeKernelBuildInfo(
 }
 
 void TsaAtomicAddToFirstTensor::ProcessOriginalCNode(
-  const AnfNodePtr &composite_node, const std::vector<std::tuple<AtomicAddInfo, AnfNodePtr, size_t>> &outer_nodes) {
+  const AnfNodePtr &composite_node, const std::vector<std::tuple<CleanZeroUserInfo, AnfNodePtr, size_t>> &outer_nodes) {
   auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(composite_node);
   auto mng_sub = sub_graph->manager();
   if (mng_sub == nullptr) {
@@ -185,8 +185,8 @@ void TsaAtomicAddToFirstTensor::ProcessOriginalCNode(
   }
 
   // Modify input
-  std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> parameters_infos;
-  std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> info_and_tsa_outers;
+  std::vector<std::pair<CleanZeroUserInfo, AnfNodePtr>> parameters_infos;
+  std::vector<std::pair<CleanZeroUserInfo, AnfNodePtr>> info_and_tsa_outers;
   for (const auto &[atomic_add_info, outer_node, tsa_first_input_index] : outer_nodes) {
     composite_node->cast<CNodePtr>()->set_input(tsa_first_input_index + 1, outer_node);
     auto parameter = sub_graph->parameters()[tsa_first_input_index];
@@ -194,7 +194,7 @@ void TsaAtomicAddToFirstTensor::ProcessOriginalCNode(
     (void)info_and_tsa_outers.emplace_back(atomic_add_info, outer_node);
   }
 
-  CreateInplaceAssignNodeAndCorrectReturn(sub_graph, parameters_infos);
+  CreateAssignNodeAndCorrectReturn(sub_graph, parameters_infos);
   ChangeKernelBuildInfo(composite_node, outer_nodes);
 
   auto old_graph_name = GetValue<std::string>(sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
@@ -205,14 +205,14 @@ void TsaAtomicAddToFirstTensor::ProcessOriginalCNode(
 }
 
 void TsaAtomicAddToFirstTensor::ProcessTsa(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node,
-                                           const std::vector<AtomicAddInfo> &atomic_add_infos,
+                                           const std::vector<CleanZeroUserInfo> &atomic_add_infos,
                                            const FuncGraphManagerPtr &mng) {
   auto origin_composite_node = anf_node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(origin_composite_node);
 
   // Create identity node.
-  std::vector<std::tuple<AtomicAddInfo, AnfNodePtr, size_t>> info_and_outer_nodes_with_index;
-  std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> info_and_outer_nodes;
+  std::vector<std::tuple<CleanZeroUserInfo, AnfNodePtr, size_t>> info_and_outer_nodes_with_index;
+  std::vector<std::pair<CleanZeroUserInfo, AnfNodePtr>> info_and_outer_nodes;
   for (auto atomic_add_info : atomic_add_infos) {
     auto outer = GetOrCreateNewTsaFirstNode(main_graph, atomic_add_info, anf_node);
     (void)info_and_outer_nodes_with_index.emplace_back(atomic_add_info, outer.first, outer.second);
@@ -220,7 +220,6 @@ void TsaAtomicAddToFirstTensor::ProcessTsa(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node,
   }
 
   // Insert extra input(broadcast node output) to composite node, and make origin TensorScatterAdd inplace-assign to it.
-  // Note: InplaceAssign outputs will increase total memory because of fake out.
   ProcessOriginalCNode(origin_composite_node, info_and_outer_nodes_with_index);
 
   // Insert UpdateState + Load before origin TensorScatterAdd's user to keep execution order.
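TensorScatterAdd already accumulates into (a copy of) its first input, so this pass routes the atomic add into that tensor instead of a freshly zeroed buffer; no clean step is needed. A plain-C++ analog of the op's accumulation semantics:

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  std::vector<float> first = {1.0f, 2.0f, 3.0f};  // tsa's first input tensor
  std::vector<int> indices = {0, 2};
  std::vector<float> updates = {10.0f, 20.0f};
  for (size_t i = 0; i < indices.size(); ++i) {
    first[indices[i]] += updates[i];  // accumulate in place: existing values are kept
  }
  std::cout << first[0] << " " << first[1] << " " << first[2] << "\n";  // 11 2 23
  return 0;
}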
@@ -1,5 +1,5 @@
 /**
- * Copyright 2021 Huawei Technologies Co., Ltd
+ * Copyright 2021-2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -36,26 +36,26 @@ namespace mindspore::graphkernel {
  *   output = Reshape(input_x)
  *   fake_out = SubGraph'(output, indices, update) {
  *     %0 = TensorScatterAdd(%para1, %para2, %para3)
- *     %1 = InplaceAssign(%para1, %0, %0)  // attrs{"fake_output":true}
+ *     %1 = Assign(%para1, %0, umonad)
  *     return %1
  *   }
  */
-class TsaAtomicAddToFirstTensor : public AtomicCleanInsertter {
+class TsaAtomicAddToFirstTensor : public AtomicCleanInserter {
  public:
-  TsaAtomicAddToFirstTensor() : AtomicCleanInsertter("tensor_scatter_add_atomic_add_to_first_tensor") {}
+  TsaAtomicAddToFirstTensor() : AtomicCleanInserter("tensor_scatter_add_atomic_add_to_first_tensor") {}
   ~TsaAtomicAddToFirstTensor() override = default;
 
   bool Run(const FuncGraphPtr &func_graph) override;
 
  private:
   void ProcessOriginalCNode(const AnfNodePtr &composite_node,
-                            const std::vector<std::tuple<AtomicAddInfo, AnfNodePtr, size_t>> &outer_nodes);
+                            const std::vector<std::tuple<CleanZeroUserInfo, AnfNodePtr, size_t>> &outer_nodes);
   void ChangeKernelBuildInfo(const AnfNodePtr &composite_node,
-                             const std::vector<std::tuple<AtomicAddInfo, AnfNodePtr, size_t>> &outer_infos);
+                             const std::vector<std::tuple<CleanZeroUserInfo, AnfNodePtr, size_t>> &outer_infos);
   void ProcessTsa(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node,
-                  const std::vector<AtomicAddInfo> &atomic_add_infos, const FuncGraphManagerPtr &mng);
+                  const std::vector<CleanZeroUserInfo> &atomic_add_infos, const FuncGraphManagerPtr &mng);
   std::pair<AnfNodePtr, size_t> GetOrCreateNewTsaFirstNode(const KernelGraphPtr &main_graph,
-                                                           const AtomicAddInfo &atomic_add_info,
+                                                           const CleanZeroUserInfo &atomic_add_info,
                                                            const AnfNodePtr &node);
   std::pair<AnfNodePtr, size_t> FindTsaFirstRealInputInGraph(const KernelGraphPtr &, const CNodePtr &tsa_node,
                                                              const AnfNodePtr &node);
@@ -36,13 +36,13 @@ namespace mindspore::graphkernel {
  *   output = broadcast_to(0.0)  // attrs{"shape": [shape of origin output.]}
  *   fake_out = SubGraph'(input_x, segment_ids, output) {
  *     %0 = UnsortedSegmentSum(%para1, %para2)
- *     %1 = InplaceAssign(%para3, %0, %0)  // attrs{"fake_output":true}
+ *     %1 = Assign(%para3, %0, umonad)
  *     return %1
  *   }
  */
-class UssAtomicAdd : public AtomicCleanInsertter {
+class UssAtomicAdd : public AtomicCleanInserter {
  public:
-  UssAtomicAdd() : AtomicCleanInsertter("unsorted_segment_sum_atomic_add_process") {}
+  UssAtomicAdd() : AtomicCleanInserter("unsorted_segment_sum_atomic_add_process") {}
   ~UssAtomicAdd() override = default;
   bool Run(const FuncGraphPtr &func_graph) override;
 };
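UnsortedSegmentSum, the op UssAtomicAdd targets, is exactly the accumulate-into-zero pattern from the comment above: the broadcast_to(0.0) output is the clean buffer the segments sum into. A plain-C++ analog of the op:

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  std::vector<float> x = {1.0f, 2.0f, 3.0f, 4.0f};
  std::vector<int> segment_ids = {0, 1, 0, 1};
  std::vector<float> out(2, 0.0f);  // the broadcast_to(0.0) "clean" output
  for (size_t i = 0; i < x.size(); ++i) {
    out[segment_ids[i]] += x[i];  // segments accumulate, so out must start at zero
  }
  std::cout << out[0] << " " << out[1] << "\n";  // 4 6
  return 0;
}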