forked from mindspore-Ecosystem/mindspore
!3388 Transfer tuple getitem's control to new added memcpy_async
Merge pull request !3388 from huanghui/r0.6
This commit is contained in:
commit
927a52fdf8
|
@ -20,11 +20,17 @@
|
||||||
#include "utils/utils.h"
|
#include "utils/utils.h"
|
||||||
#include "backend/kernel_compiler/hccl/hcom_util.h"
|
#include "backend/kernel_compiler/hccl/hcom_util.h"
|
||||||
#include "backend/session/anf_runtime_algorithm.h"
|
#include "backend/session/anf_runtime_algorithm.h"
|
||||||
|
#include "frontend/parallel/context.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
namespace {
|
namespace {
|
||||||
std::string GetKernelFormat(const CNodePtr &kernel_node, size_t index) {
|
std::string GetKernelFormat(const CNodePtr &kernel_node, size_t index) {
|
||||||
|
auto parallel_context_instance = parallel::ParallelContext::GetInstance();
|
||||||
|
MS_EXCEPTION_IF_NULL(parallel_context_instance);
|
||||||
|
if (parallel_context_instance->enable_parallel_optimizer()) {
|
||||||
|
return kOpFormat_DEFAULT;
|
||||||
|
}
|
||||||
const std::set<std::string> kReduceNoSupportedSet = {kOpFormat_FRAC_Z, kOpFormat_FRACTAL_Z_C04, kOpFormat_C1HWNCoC0};
|
const std::set<std::string> kReduceNoSupportedSet = {kOpFormat_FRAC_Z, kOpFormat_FRACTAL_Z_C04, kOpFormat_C1HWNCoC0};
|
||||||
auto op_name = AnfAlgo::GetCNodeName(kernel_node);
|
auto op_name = AnfAlgo::GetCNodeName(kernel_node);
|
||||||
auto format = AnfAlgo::GetPrevNodeOutputFormat(kernel_node, index);
|
auto format = AnfAlgo::GetPrevNodeOutputFormat(kernel_node, index);
|
||||||
|
|
|
@ -40,6 +40,38 @@ bool IsParameterOrValueNode(const AnfNodePtr &node) {
|
||||||
return real_node->isa<ValueNode>();
|
return real_node->isa<ValueNode>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void SetInput(const CNodePtr &control_depend, const int index, const FuncGraphPtr &graph, const CNodePtr &hccl_node,
|
||||||
|
const std::vector<AnfNodePtr> &memcpy_async_list) {
|
||||||
|
MS_EXCEPTION_IF_NULL(control_depend);
|
||||||
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
|
MS_EXCEPTION_IF_NULL(hccl_node);
|
||||||
|
std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
|
||||||
|
make_tuple_inputs.insert(make_tuple_inputs.end(), memcpy_async_list.begin(), memcpy_async_list.end());
|
||||||
|
make_tuple_inputs.emplace_back(hccl_node);
|
||||||
|
auto make_tuple = graph->NewCNode(make_tuple_inputs);
|
||||||
|
MS_EXCEPTION_IF_NULL(make_tuple);
|
||||||
|
control_depend->set_input(IntToSize(index), make_tuple);
|
||||||
|
}
|
||||||
|
|
||||||
|
void DealControlForGetitem(const CNodePtr &tuple_getitem, const FuncGraphPtr &graph, const CNodePtr &hccl_node,
|
||||||
|
const std::vector<AnfNodePtr> &memcpy_async_list) {
|
||||||
|
MS_EXCEPTION_IF_NULL(tuple_getitem);
|
||||||
|
auto manager = graph->manager();
|
||||||
|
MS_EXCEPTION_IF_NULL(manager);
|
||||||
|
auto &node_users = manager->node_users();
|
||||||
|
auto iter = node_users.find(tuple_getitem);
|
||||||
|
if (iter == node_users.end()) {
|
||||||
|
MS_LOG(EXCEPTION) << "node has no output in manager";
|
||||||
|
}
|
||||||
|
for (const auto &node_index : iter->second) {
|
||||||
|
AnfNodePtr output = node_index.first;
|
||||||
|
MS_EXCEPTION_IF_NULL(output);
|
||||||
|
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) {
|
||||||
|
SetInput(output->cast<CNodePtr>(), node_index.second, graph, hccl_node, memcpy_async_list);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void TransferControl(const CNodePtr &hccl_node, const std::vector<AnfNodePtr> &memcpy_async_list,
|
void TransferControl(const CNodePtr &hccl_node, const std::vector<AnfNodePtr> &memcpy_async_list,
|
||||||
const FuncGraphPtr &graph) {
|
const FuncGraphPtr &graph) {
|
||||||
MS_EXCEPTION_IF_NULL(hccl_node);
|
MS_EXCEPTION_IF_NULL(hccl_node);
|
||||||
|
@ -53,25 +85,13 @@ void TransferControl(const CNodePtr &hccl_node, const std::vector<AnfNodePtr> &m
|
||||||
}
|
}
|
||||||
// find hccl_node's output which is a control depend
|
// find hccl_node's output which is a control depend
|
||||||
for (const auto &node_index : iter->second) {
|
for (const auto &node_index : iter->second) {
|
||||||
if (!AnfAlgo::CheckPrimitiveType(node_index.first, prim::kPrimControlDepend)) {
|
AnfNodePtr output = node_index.first;
|
||||||
continue;
|
MS_EXCEPTION_IF_NULL(output);
|
||||||
|
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) {
|
||||||
|
SetInput(output->cast<CNodePtr>(), node_index.second, graph, hccl_node, memcpy_async_list);
|
||||||
|
} else if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimTupleGetItem)) {
|
||||||
|
DealControlForGetitem(output->cast<CNodePtr>(), graph, hccl_node, memcpy_async_list);
|
||||||
}
|
}
|
||||||
CNodePtr control_depend = node_index.first->cast<CNodePtr>();
|
|
||||||
MS_EXCEPTION_IF_NULL(control_depend);
|
|
||||||
std::vector<AnfNodePtr> new_inputs;
|
|
||||||
for (size_t i = 0; i < control_depend->size(); ++i) {
|
|
||||||
if (i == IntToSize(node_index.second)) {
|
|
||||||
std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
|
|
||||||
make_tuple_inputs.insert(make_tuple_inputs.end(), memcpy_async_list.begin(), memcpy_async_list.end());
|
|
||||||
make_tuple_inputs.emplace_back(hccl_node);
|
|
||||||
auto make_tuple = graph->NewCNode(make_tuple_inputs);
|
|
||||||
MS_EXCEPTION_IF_NULL(make_tuple);
|
|
||||||
new_inputs.push_back(make_tuple);
|
|
||||||
} else {
|
|
||||||
new_inputs.push_back(control_depend->input(i));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
control_depend->set_inputs(new_inputs);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -148,11 +168,10 @@ const AnfNodePtr InsertMemcpyAsyncForHcclOp::Process(const FuncGraphPtr &func_gr
|
||||||
if (func_graph == nullptr || node == nullptr || !node->isa<CNode>()) {
|
if (func_graph == nullptr || node == nullptr || !node->isa<CNode>()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
auto cnode = node->cast<CNodePtr>();
|
|
||||||
if (!AnfAlgo::IsCommunicationOp(node)) {
|
if (!AnfAlgo::IsCommunicationOp(node)) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
InsertMemcpyAsync(func_graph, cnode);
|
InsertMemcpyAsync(func_graph, node->cast<CNodePtr>());
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
|
|
Loading…
Reference in New Issue