Zero copy for atomic addr clean op.

Free the relationship between formal and real parameter.

not skip the nopnode which is an ouput.
This commit is contained in:
gaoyong10 2022-10-31 14:52:22 +08:00
parent b9f13c6ce4
commit 044feecbd7
4 changed files with 95 additions and 27 deletions

View File

@ -252,7 +252,11 @@ void KernelGraphMgr::InitInternalOutputParameter(const AnfNodePtr &out_node, con
builder.SetOutputsDeviceType({type});
builder.SetOutputsFormat({format});
d_kernel_info->set_select_kernel_build_info(builder.Build());
AnfAlgo::SetOutputAddr(address, 0, parameter.get());
// If the flag is enable, it means the graph would run in subgraph sink mode, the internal parameter cannot share
// the same device address.
if (!node_graph->has_flag(kFlagEnableZeroCopyInGraph)) {
AnfAlgo::SetOutputAddr(address, 0, parameter.get());
}
auto abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(type),
parameter->Shape()->cast<abstract::BaseShapePtr>());
parameter->set_abstract(abstract);
@ -873,10 +877,16 @@ void KernelGraphMgr::AddParameterToGraphInputs(const std::vector<AnfNodePtr> &pa
}
KernelGraphPtr KernelGraphMgr::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs,
DeviceType device_target, bool common_opt) {
DeviceType device_target, bool common_opt,
bool is_enable_zero_copy) {
mindspore::HashMap<AnfNodePtr, AnfNodePtr> other_graph_cnode;
auto graph = NewKernelGraph();
MS_EXCEPTION_IF_NULL(graph);
// Set the zero copy flag in subgraph sink mode.
if (is_enable_zero_copy) {
MS_LOG(INFO) << "Set zero copy flag for graph:" << graph->ToString();
graph->set_flag(kFlagEnableZeroCopyInGraph, true);
}
MS_LOG(INFO) << "Create graph: " << graph->graph_id();
for (const auto &node : lst) {
MS_EXCEPTION_IF_NULL(node);
@ -1146,6 +1156,17 @@ std::string KernelGraphMgr::AddPartialParametersMap(const AnfNodePtr &partial_no
return graph_target;
}
namespace {
bool IsNeedAddPartialParameter(const AnfNodePtr &user, const std::string &kernel_target,
const std::shared_ptr<KernelGraph> &graph) {
// If the flag is enable, it means the graph would run in subgraph sink mode, the real parameter on partial
// cannot share the same device address with the formal parameter.
MS_EXCEPTION_IF_NULL(graph);
return common::AnfAlgo::CheckPrimitiveType(user, prim::kPrimPartial) && kernel_target != kGPUDevice &&
!ExistGraphCaller(user) && (!graph->has_flag(kFlagEnableZeroCopyInGraph));
}
} // namespace
void KernelGraphMgr::HandleInternalOutput(const AnfNodePtr &input_front_node, const AnfNodePtr &backend_node,
const FuncGraphManagerPtr &front_func_graph_manager,
const std::shared_ptr<KernelGraph> &backend_graph) {
@ -1176,8 +1197,7 @@ void KernelGraphMgr::HandleInternalOutput(const AnfNodePtr &input_front_node, co
if (internal_output) {
auto users = ExtendNodeUsers(front_func_graph_manager, front_node);
for (auto &user : users) {
if (common::AnfAlgo::CheckPrimitiveType(user, prim::kPrimPartial) && kernel_target != kGPUDevice &&
!ExistGraphCaller(user)) {
if (IsNeedAddPartialParameter(user, kernel_target, backend_graph)) {
auto partial_target = AddPartialParametersMap(user);
if (partial_target != kNoTarget && partial_target != kernel_target) {
unique_target = false;

View File

@ -46,9 +46,12 @@ class BACKEND_EXPORT KernelGraphMgr {
KernelGraphMgr() {}
virtual ~KernelGraphMgr() {}
// The parameter is_enable_zero_copy means if the parameter in graph can avoid copy when it is executed, and it is
// true in subgraph sink mode, and the device address shared for partial parameters and internal parameters in graph
// would be disabled.
std::shared_ptr<KernelGraph> ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs,
DeviceType device_target = DeviceType::kUnknown,
bool common_opt = true);
bool common_opt = true, bool is_enable_zero_copy = false);
std::shared_ptr<KernelGraph> ConstructKernelGraph(const FuncGraphPtr &func_graph,
std::vector<KernelGraphPtr> *all_out_graph,

View File

@ -335,33 +335,83 @@ bool ZeroCopyTask::UpdateArgs(void *stream) {
}
namespace {
void GenerateZeroCopyTaskForInput(const AnfNodePtr &node, const TaskPtr &task, const session::KernelGraph &graph,
std::vector<KernelWithIndex> GetInputNodeWithIndex(const CNodePtr &node, const TaskPtr &task,
const std::vector<KernelWithIndex> &output_with_indexs,
std::set<std::pair<AnfNodePtr, size_t>> *node_to_offset) {
std::vector<KernelWithIndex> input_node_with_indexs;
auto input_num = common::AnfAlgo::GetInputTensorNum(node);
if (common::AnfAlgo::GetCNodeName(node) == kAtomicAddrCleanOpName) {
// For atomic addr clean op, the args in task is not the input node of kernel, we should get the real input index
// from the input node.
for (size_t i = 0; i < input_num; ++i) {
const auto &input = node->input(i + 1);
MS_EXCEPTION_IF_NULL(input);
if (input->isa<CNode>() && common::AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, input->cast<CNodePtr>())) {
auto clean_output_indexs = common::AnfAlgo::GetNodeAttr<std::vector<size_t>>(input, kAttrAtomicOutputIndexs);
for (auto index : clean_output_indexs) {
MS_LOG(DEBUG) << "atomic addr clean index:" << index << " for node:" << input->fullname_with_scope();
input_node_with_indexs.emplace_back(input, index);
}
}
}
if (input_node_with_indexs.size() != (task->ArgsSize() / sizeof(void *))) {
MS_LOG(ERROR) << "Invalid input size:" << input_node_with_indexs.size()
<< " task size:" << (task->ArgsSize() / sizeof(void *)) << " for node:" << node->DebugString();
}
} else {
for (size_t i = 0; i < input_num; ++i) {
if (node_to_offset->find(std::make_pair(node, i)) != node_to_offset->end()) {
input_node_with_indexs.emplace_back(nullptr, i);
continue;
}
size_t input_index_in_graph = AnfAlgo::GetInputGraphIdxByKernelIdx(node, i);
KernelWithIndex input_with_index{node, input_index_in_graph};
do {
input_with_index = common::AnfAlgo::GetPrevNodeOutput(input_with_index.first, input_with_index.second, false);
if (std::find_if(output_with_indexs.begin(), output_with_indexs.end(),
[input_with_index](const KernelWithIndex &output) {
const auto &real_output = common::AnfAlgo::FetchRealNodeSkipMonadControl(output);
return real_output == input_with_index;
}) != output_with_indexs.end()) {
break;
}
} while (input_with_index.first != nullptr && common::AnfAlgo::IsNopNode(input_with_index.first));
MS_LOG(DEBUG) << "Add input node:" << input_with_index.first->fullname_with_scope()
<< " index:" << input_with_index.second << " for node:" << node->fullname_with_scope();
input_node_with_indexs.emplace_back(input_with_index);
}
}
return input_node_with_indexs;
}
void GenerateZeroCopyTaskForInput(const CNodePtr &node, const TaskPtr &task, const session::KernelGraph &graph,
std::vector<ZeroCopyTaskPtr> *zero_copy_tasks,
std::set<std::pair<AnfNodePtr, size_t>> *node_to_offset) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(zero_copy_tasks);
MS_EXCEPTION_IF_NULL(node_to_offset);
auto input_num = common::AnfAlgo::GetInputTensorNum(node);
const auto &output_with_indexs = common::AnfAlgo::GetAllOutputWithIndex(graph.output());
const auto &ref_node_map = graph.GetRefMap();
for (size_t i = 0; i < input_num; ++i) {
if (node_to_offset->find(std::make_pair(node, i)) != node_to_offset->end()) {
std::vector<KernelWithIndex> input_node_with_indexs =
GetInputNodeWithIndex(node, task, output_with_indexs, node_to_offset);
for (size_t i = 0; i < input_node_with_indexs.size(); ++i) {
KernelWithIndex input_with_index = input_node_with_indexs[i];
const auto input = input_with_index.first;
if (input == nullptr || node_to_offset->find(std::make_pair(node, i)) != node_to_offset->end()) {
continue;
}
size_t input_index_in_graph = AnfAlgo::GetInputGraphIdxByKernelIdx(node, i);
const auto &input_with_index = common::AnfAlgo::GetPrevNodeOutput(node, input_index_in_graph, true);
const auto input = input_with_index.first;
MS_EXCEPTION_IF_NULL(input);
if (input->isa<Parameter>()) {
// 1. Input parameter.
zero_copy_tasks->emplace_back(
std::make_shared<tasksink::ParameterZeroCopyTask>(input, task->Args(), i * sizeof(void *), task->task_name()));
node_to_offset->emplace(node, i);
MS_LOG(DEBUG) << "Add zero copy task for node:" << node->fullname_with_scope() << " input index:" << i
<< " ptr from parameter input:" << input->DebugString();
<< " ptr from parameter input:" << input->fullname_with_scope();
} else if (input->isa<CNode>()) {
// 2. Input which is graph output.
if (std::find_if(output_with_indexs.begin(), output_with_indexs.end(),
@ -373,7 +423,8 @@ void GenerateZeroCopyTaskForInput(const AnfNodePtr &node, const TaskPtr &task, c
input, input_with_index.second, task->Args(), i * sizeof(void *), task->task_name()));
node_to_offset->emplace(node, i);
MS_LOG(DEBUG) << "Add zero copy task for node:" << node->fullname_with_scope() << " input index:" << i
<< " ptr from cnode input:" << input->DebugString() << " cnode index:" << input_with_index.second;
<< " ptr from cnode input:" << input->fullname_with_scope()
<< " cnode index:" << input_with_index.second;
} else {
// 3. Input which is a ref node whose input is a parameter, like:
// refnode(parameter, node1)
@ -385,7 +436,7 @@ void GenerateZeroCopyTaskForInput(const AnfNodePtr &node, const TaskPtr &task, c
zero_copy_tasks->emplace_back(std::make_shared<tasksink::ParameterZeroCopyTask>(
parameter, task->Args(), i * sizeof(void *), task->task_name()));
MS_LOG(DEBUG) << "Add zero copy task for node:" << node->fullname_with_scope() << " input index:" << i
<< " ptr from parameter input:" << parameter->DebugString();
<< " ptr from parameter input:" << parameter->fullname_with_scope();
node_to_offset->emplace(node, i);
}
}
@ -427,7 +478,7 @@ void GenerateZeroCopyTaskForOutput(const AnfNodePtr &node, const TaskPtr &task,
std::make_shared<tasksink::CNodeZeroCopyTask>(ref_iter->second.first, ref_iter->second.second, task->Args(),
input_index * sizeof(void *), task->task_name()));
MS_LOG(DEBUG) << "Add zero copy task for node:" << node->fullname_with_scope() << " input index:" << i
<< " ptr from cnode input:" << ref_iter->second.first->DebugString()
<< " ptr from cnode input:" << ref_iter->second.first->fullname_with_scope()
<< " cnode index:" << ref_iter->second.second;
node_to_offset->emplace(node, input_index);
zero_copy_ref_nodes->emplace(ref_iter->second);
@ -436,7 +487,7 @@ void GenerateZeroCopyTaskForOutput(const AnfNodePtr &node, const TaskPtr &task,
zero_copy_tasks->emplace_back(std::make_shared<tasksink::ParameterZeroCopyTask>(
ref_iter->second.first, task->Args(), (input_num + i) * sizeof(void *), task->task_name()));
MS_LOG(DEBUG) << "Add zero copy task for node:" << node->fullname_with_scope() << " output index:" << i
<< " ptr from parameter input:" << ref_iter->second.first->DebugString();
<< " ptr from parameter input:" << ref_iter->second.first->fullname_with_scope();
node_to_offset->emplace(node, input_num + i);
}
}

View File

@ -255,12 +255,11 @@ void OptimizeNopNode(KernelGraph *graph) {
graph->set_execution_order(new_execution_order);
}
bool SetZeroCopyFlag(const KernelGraphPtr &graph, bool run_in_pynative) {
bool IsEnableZeroCopy(bool run_in_pynative) {
if (run_in_pynative) {
return false;
}
MS_EXCEPTION_IF_NULL(graph);
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
bool task_sink = ms_context->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
@ -283,8 +282,6 @@ bool SetZeroCopyFlag(const KernelGraphPtr &graph, bool run_in_pynative) {
if (common::GetEnv("DISABLE_ZERO_COPY") == "1") {
return false;
}
MS_LOG(INFO) << "Set zero copy flag for graph:" << graph->ToString();
graph->set_flag(kFlagEnableZeroCopyInGraph, true);
return true;
}
} // namespace
@ -299,7 +296,8 @@ GraphId GraphCompiler::CompileGraph(const GraphSegmentPtr &segment, const AnfNod
auto nodes = segment->nodes_;
auto device_terget = device_context->GetDeviceType();
// Generate kernel graph.
KernelGraphPtr graph = session_->ConstructKernelGraph(nodes, outputs, device_terget);
KernelGraphPtr graph =
session_->ConstructKernelGraph(nodes, outputs, device_terget, true, IsEnableZeroCopy(run_in_pynative));
MS_EXCEPTION_IF_NULL(graph);
opt::EliminateIllegalDataTypePass(graph);
SetGraphDependency(graph, segment);
@ -477,10 +475,6 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic
DumpIRProto(graph, "before_opt_" + std::to_string(graph->graph_id()));
}
#endif
// If the zero copy flag has been set in graph, the relationship between partial and parameter should be disabled.
if (SetZeroCopyFlag(graph, run_in_pynative)) {
session_->ClearPartialParameterMap();
}
MS_EXCEPTION_IF_NULL(device_context->kernel_executor_);
// Execute optimization pass.
device_context->kernel_executor_->OptimizeGraph(graph);