Zero copy for atomic addr clean op.
Free the relationship between formal and real parameters. Do not skip a nop node which is an output.
parent b9f13c6ce4
commit 044feecbd7
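For context, the change gates device-address sharing on a per-graph zero-copy flag: when the flag is set (subgraph sink mode), an internal parameter keeps its own device address instead of aliasing the address of the real output, and the partial-parameter sharing is likewise disabled. The minimal sketch below illustrates only that pattern; Graph, Parameter, DeviceAddress, and the flag value are simplified stand-ins, not MindSpore's real API.

// Minimal sketch of the flag-guarded address sharing used in this commit.
// All types and the flag value here are simplified stand-ins, not the real MindSpore API.
#include <iostream>
#include <memory>
#include <set>
#include <string>

constexpr const char kFlagEnableZeroCopyInGraph[] = "enable_zero_copy_in_graph";

struct DeviceAddress {};

struct Graph {
  std::set<std::string> flags;
  bool has_flag(const std::string &flag) const { return flags.count(flag) != 0; }
};

struct Parameter {
  std::shared_ptr<DeviceAddress> addr;  // shared with the real output only when zero copy is disabled
};

// Mirrors the guarded SetOutputAddr call: in zero-copy (subgraph sink) mode the
// internal parameter must not share the real output's device address.
void InitInternalOutputParameter(const Graph &graph, Parameter *parameter,
                                 const std::shared_ptr<DeviceAddress> &real_output_addr) {
  if (!graph.has_flag(kFlagEnableZeroCopyInGraph)) {
    parameter->addr = real_output_addr;
  }
}

int main() {
  Graph graph;
  graph.flags.insert(kFlagEnableZeroCopyInGraph);  // set in subgraph sink mode
  Parameter param;
  auto output_addr = std::make_shared<DeviceAddress>();
  InitInternalOutputParameter(graph, &param, output_addr);
  std::cout << "address shared: " << std::boolalpha << (param.addr == output_addr) << std::endl;  // prints false
  return 0;
}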
@@ -252,7 +252,11 @@ void KernelGraphMgr::InitInternalOutputParameter(const AnfNodePtr &out_node, con
   builder.SetOutputsDeviceType({type});
   builder.SetOutputsFormat({format});
   d_kernel_info->set_select_kernel_build_info(builder.Build());
-  AnfAlgo::SetOutputAddr(address, 0, parameter.get());
+  // If the flag is enable, it means the graph would run in subgraph sink mode, the internal parameter cannot share
+  // the same device address.
+  if (!node_graph->has_flag(kFlagEnableZeroCopyInGraph)) {
+    AnfAlgo::SetOutputAddr(address, 0, parameter.get());
+  }
   auto abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(type),
                                                              parameter->Shape()->cast<abstract::BaseShapePtr>());
   parameter->set_abstract(abstract);
@@ -873,10 +877,16 @@ void KernelGraphMgr::AddParameterToGraphInputs(const std::vector<AnfNodePtr> &pa
 }
 
 KernelGraphPtr KernelGraphMgr::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs,
-                                                    DeviceType device_target, bool common_opt) {
+                                                    DeviceType device_target, bool common_opt,
+                                                    bool is_enable_zero_copy) {
   mindspore::HashMap<AnfNodePtr, AnfNodePtr> other_graph_cnode;
   auto graph = NewKernelGraph();
   MS_EXCEPTION_IF_NULL(graph);
+  // Set the zero copy flag in subgraph sink mode.
+  if (is_enable_zero_copy) {
+    MS_LOG(INFO) << "Set zero copy flag for graph:" << graph->ToString();
+    graph->set_flag(kFlagEnableZeroCopyInGraph, true);
+  }
   MS_LOG(INFO) << "Create graph: " << graph->graph_id();
   for (const auto &node : lst) {
     MS_EXCEPTION_IF_NULL(node);
@@ -1146,6 +1156,17 @@ std::string KernelGraphMgr::AddPartialParametersMap(const AnfNodePtr &partial_no
   return graph_target;
 }
 
+namespace {
+bool IsNeedAddPartialParameter(const AnfNodePtr &user, const std::string &kernel_target,
+                               const std::shared_ptr<KernelGraph> &graph) {
+  // If the flag is enable, it means the graph would run in subgraph sink mode, the real parameter on partial
+  // cannot share the same device address with the formal parameter.
+  MS_EXCEPTION_IF_NULL(graph);
+  return common::AnfAlgo::CheckPrimitiveType(user, prim::kPrimPartial) && kernel_target != kGPUDevice &&
+         !ExistGraphCaller(user) && (!graph->has_flag(kFlagEnableZeroCopyInGraph));
+}
+} // namespace
+
 void KernelGraphMgr::HandleInternalOutput(const AnfNodePtr &input_front_node, const AnfNodePtr &backend_node,
                                           const FuncGraphManagerPtr &front_func_graph_manager,
                                           const std::shared_ptr<KernelGraph> &backend_graph) {
@@ -1176,8 +1197,7 @@ void KernelGraphMgr::HandleInternalOutput(const AnfNodePtr &input_front_node, co
   if (internal_output) {
     auto users = ExtendNodeUsers(front_func_graph_manager, front_node);
     for (auto &user : users) {
-      if (common::AnfAlgo::CheckPrimitiveType(user, prim::kPrimPartial) && kernel_target != kGPUDevice &&
-          !ExistGraphCaller(user)) {
+      if (IsNeedAddPartialParameter(user, kernel_target, backend_graph)) {
         auto partial_target = AddPartialParametersMap(user);
         if (partial_target != kNoTarget && partial_target != kernel_target) {
           unique_target = false;
@@ -46,9 +46,12 @@ class BACKEND_EXPORT KernelGraphMgr {
   KernelGraphMgr() {}
   virtual ~KernelGraphMgr() {}
 
+  // The parameter is_enable_zero_copy means if the parameter in graph can avoid copy when it is executed, and it is
+  // true in subgraph sink mode, and the device address shared for partial parameters and internal parameters in graph
+  // would be disabled.
   std::shared_ptr<KernelGraph> ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs,
                                                     DeviceType device_target = DeviceType::kUnknown,
-                                                    bool common_opt = true);
+                                                    bool common_opt = true, bool is_enable_zero_copy = false);
 
   std::shared_ptr<KernelGraph> ConstructKernelGraph(const FuncGraphPtr &func_graph,
                                                     std::vector<KernelGraphPtr> *all_out_graph,
@@ -335,33 +335,83 @@ bool ZeroCopyTask::UpdateArgs(void *stream) {
 }
 
 namespace {
-void GenerateZeroCopyTaskForInput(const AnfNodePtr &node, const TaskPtr &task, const session::KernelGraph &graph,
+std::vector<KernelWithIndex> GetInputNodeWithIndex(const CNodePtr &node, const TaskPtr &task,
+                                                   const std::vector<KernelWithIndex> &output_with_indexs,
+                                                   std::set<std::pair<AnfNodePtr, size_t>> *node_to_offset) {
+  std::vector<KernelWithIndex> input_node_with_indexs;
+  auto input_num = common::AnfAlgo::GetInputTensorNum(node);
+  if (common::AnfAlgo::GetCNodeName(node) == kAtomicAddrCleanOpName) {
+    // For atomic addr clean op, the args in task is not the input node of kernel, we should get the real input index
+    // from the input node.
+    for (size_t i = 0; i < input_num; ++i) {
+      const auto &input = node->input(i + 1);
+      MS_EXCEPTION_IF_NULL(input);
+      if (input->isa<CNode>() && common::AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, input->cast<CNodePtr>())) {
+        auto clean_output_indexs = common::AnfAlgo::GetNodeAttr<std::vector<size_t>>(input, kAttrAtomicOutputIndexs);
+        for (auto index : clean_output_indexs) {
+          MS_LOG(DEBUG) << "atomic addr clean index:" << index << " for node:" << input->fullname_with_scope();
+          input_node_with_indexs.emplace_back(input, index);
+        }
+      }
+    }
+    if (input_node_with_indexs.size() != (task->ArgsSize() / sizeof(void *))) {
+      MS_LOG(ERROR) << "Invalid input size:" << input_node_with_indexs.size()
+                    << " task size:" << (task->ArgsSize() / sizeof(void *)) << " for node:" << node->DebugString();
+    }
+  } else {
+    for (size_t i = 0; i < input_num; ++i) {
+      if (node_to_offset->find(std::make_pair(node, i)) != node_to_offset->end()) {
+        input_node_with_indexs.emplace_back(nullptr, i);
+        continue;
+      }
+
+      size_t input_index_in_graph = AnfAlgo::GetInputGraphIdxByKernelIdx(node, i);
+      KernelWithIndex input_with_index{node, input_index_in_graph};
+      do {
+        input_with_index = common::AnfAlgo::GetPrevNodeOutput(input_with_index.first, input_with_index.second, false);
+        if (std::find_if(output_with_indexs.begin(), output_with_indexs.end(),
+                         [input_with_index](const KernelWithIndex &output) {
+                           const auto &real_output = common::AnfAlgo::FetchRealNodeSkipMonadControl(output);
+                           return real_output == input_with_index;
+                         }) != output_with_indexs.end()) {
+          break;
+        }
+      } while (input_with_index.first != nullptr && common::AnfAlgo::IsNopNode(input_with_index.first));
+      MS_LOG(DEBUG) << "Add input node:" << input_with_index.first->fullname_with_scope()
+                    << " index:" << input_with_index.second << " for node:" << node->fullname_with_scope();
+      input_node_with_indexs.emplace_back(input_with_index);
+    }
+  }
+  return input_node_with_indexs;
+}
+
+void GenerateZeroCopyTaskForInput(const CNodePtr &node, const TaskPtr &task, const session::KernelGraph &graph,
                                   std::vector<ZeroCopyTaskPtr> *zero_copy_tasks,
                                   std::set<std::pair<AnfNodePtr, size_t>> *node_to_offset) {
   MS_EXCEPTION_IF_NULL(node);
   MS_EXCEPTION_IF_NULL(zero_copy_tasks);
   MS_EXCEPTION_IF_NULL(node_to_offset);
 
   auto input_num = common::AnfAlgo::GetInputTensorNum(node);
   const auto &output_with_indexs = common::AnfAlgo::GetAllOutputWithIndex(graph.output());
   const auto &ref_node_map = graph.GetRefMap();
 
-  for (size_t i = 0; i < input_num; ++i) {
-    if (node_to_offset->find(std::make_pair(node, i)) != node_to_offset->end()) {
+  std::vector<KernelWithIndex> input_node_with_indexs =
+    GetInputNodeWithIndex(node, task, output_with_indexs, node_to_offset);
+
+  for (size_t i = 0; i < input_node_with_indexs.size(); ++i) {
+    KernelWithIndex input_with_index = input_node_with_indexs[i];
+    const auto input = input_with_index.first;
+    if (input == nullptr || node_to_offset->find(std::make_pair(node, i)) != node_to_offset->end()) {
       continue;
     }
-
-    size_t input_index_in_graph = AnfAlgo::GetInputGraphIdxByKernelIdx(node, i);
-    const auto &input_with_index = common::AnfAlgo::GetPrevNodeOutput(node, input_index_in_graph, true);
-    const auto input = input_with_index.first;
-    MS_EXCEPTION_IF_NULL(input);
     if (input->isa<Parameter>()) {
       // 1. Input parameter.
       zero_copy_tasks->emplace_back(
         std::make_shared<tasksink::ParameterZeroCopyTask>(input, task->Args(), i * sizeof(void *), task->task_name()));
       node_to_offset->emplace(node, i);
       MS_LOG(DEBUG) << "Add zero copy task for node:" << node->fullname_with_scope() << " input index:" << i
-                    << " ptr from parameter input:" << input->DebugString();
+                    << " ptr from parameter input:" << input->fullname_with_scope();
     } else if (input->isa<CNode>()) {
       // 2. Input which is graph output.
       if (std::find_if(output_with_indexs.begin(), output_with_indexs.end(),
@@ -373,7 +423,8 @@ void GenerateZeroCopyTaskForInput(const AnfNodePtr &node, const TaskPtr &task, c
           input, input_with_index.second, task->Args(), i * sizeof(void *), task->task_name()));
         node_to_offset->emplace(node, i);
         MS_LOG(DEBUG) << "Add zero copy task for node:" << node->fullname_with_scope() << " input index:" << i
-                      << " ptr from cnode input:" << input->DebugString() << " cnode index:" << input_with_index.second;
+                      << " ptr from cnode input:" << input->fullname_with_scope()
+                      << " cnode index:" << input_with_index.second;
       } else {
         // 3. Input which is a ref node whose input is a parameter, like:
         // refnode(parameter, node1)
@@ -385,7 +436,7 @@ void GenerateZeroCopyTaskForInput(const AnfNodePtr &node, const TaskPtr &task, c
         zero_copy_tasks->emplace_back(std::make_shared<tasksink::ParameterZeroCopyTask>(
           parameter, task->Args(), i * sizeof(void *), task->task_name()));
         MS_LOG(DEBUG) << "Add zero copy task for node:" << node->fullname_with_scope() << " input index:" << i
-                      << " ptr from parameter input:" << parameter->DebugString();
+                      << " ptr from parameter input:" << parameter->fullname_with_scope();
         node_to_offset->emplace(node, i);
       }
     }
@@ -427,7 +478,7 @@ void GenerateZeroCopyTaskForOutput(const AnfNodePtr &node, const TaskPtr &task,
         std::make_shared<tasksink::CNodeZeroCopyTask>(ref_iter->second.first, ref_iter->second.second, task->Args(),
                                                       input_index * sizeof(void *), task->task_name()));
       MS_LOG(DEBUG) << "Add zero copy task for node:" << node->fullname_with_scope() << " input index:" << i
-                    << " ptr from cnode input:" << ref_iter->second.first->DebugString()
+                    << " ptr from cnode input:" << ref_iter->second.first->fullname_with_scope()
                     << " cnode index:" << ref_iter->second.second;
       node_to_offset->emplace(node, input_index);
       zero_copy_ref_nodes->emplace(ref_iter->second);
@@ -436,7 +487,7 @@ void GenerateZeroCopyTaskForOutput(const AnfNodePtr &node, const TaskPtr &task,
       zero_copy_tasks->emplace_back(std::make_shared<tasksink::ParameterZeroCopyTask>(
         ref_iter->second.first, task->Args(), (input_num + i) * sizeof(void *), task->task_name()));
       MS_LOG(DEBUG) << "Add zero copy task for node:" << node->fullname_with_scope() << " output index:" << i
-                    << " ptr from parameter input:" << ref_iter->second.first->DebugString();
+                    << " ptr from parameter input:" << ref_iter->second.first->fullname_with_scope();
       node_to_offset->emplace(node, input_num + i);
     }
   }
@@ -255,12 +255,11 @@ void OptimizeNopNode(KernelGraph *graph) {
   graph->set_execution_order(new_execution_order);
 }
 
-bool SetZeroCopyFlag(const KernelGraphPtr &graph, bool run_in_pynative) {
+bool IsEnableZeroCopy(bool run_in_pynative) {
   if (run_in_pynative) {
     return false;
   }
 
-  MS_EXCEPTION_IF_NULL(graph);
   auto ms_context = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(ms_context);
   bool task_sink = ms_context->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
@@ -283,8 +282,6 @@ bool SetZeroCopyFlag(const KernelGraphPtr &graph, bool run_in_pynative) {
   if (common::GetEnv("DISABLE_ZERO_COPY") == "1") {
     return false;
   }
-  MS_LOG(INFO) << "Set zero copy flag for graph:" << graph->ToString();
-  graph->set_flag(kFlagEnableZeroCopyInGraph, true);
   return true;
 }
 } // namespace
@@ -299,7 +296,8 @@ GraphId GraphCompiler::CompileGraph(const GraphSegmentPtr &segment, const AnfNod
   auto nodes = segment->nodes_;
   auto device_terget = device_context->GetDeviceType();
   // Generate kernel graph.
-  KernelGraphPtr graph = session_->ConstructKernelGraph(nodes, outputs, device_terget);
+  KernelGraphPtr graph =
+    session_->ConstructKernelGraph(nodes, outputs, device_terget, true, IsEnableZeroCopy(run_in_pynative));
   MS_EXCEPTION_IF_NULL(graph);
   opt::EliminateIllegalDataTypePass(graph);
   SetGraphDependency(graph, segment);
@@ -477,10 +475,6 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic
     DumpIRProto(graph, "before_opt_" + std::to_string(graph->graph_id()));
   }
 #endif
-  // If the zero copy flag has been set in graph, the relationship between partial and parameter should be disabled.
-  if (SetZeroCopyFlag(graph, run_in_pynative)) {
-    session_->ClearPartialParameterMap();
-  }
   MS_EXCEPTION_IF_NULL(device_context->kernel_executor_);
   // Execute optimization pass.
   device_context->kernel_executor_->OptimizeGraph(graph);