!25415 insert atomic ops process split

Merge pull request !25415 from liubuyu/master
This commit is contained in:
i-robot 2021-10-29 01:34:50 +00:00 committed by Gitee
commit 324a767fe4
4 changed files with 281 additions and 245 deletions

View File

@ -721,7 +721,7 @@ void AscendSession::BatchBuildKernel(const std::vector<std::shared_ptr<SessionTa
std::vector<CNodePtr> atomic_node_to_build;
for (auto &graph : graphs) {
device::ascend::KernelBuildPreprocess(graph.get());
device::ascend::InsertAtomicCleanOp(graph);
const auto &nodes = graph->execution_order();
std::copy(nodes.begin(), nodes.end(), std::back_inserter(atomic_node_to_build));
}
@ -998,10 +998,10 @@ void AscendSession::BuildOpsInGraph(const GraphId &graph_id, const std::map<AnfN
InitRuntimeResource();
// Compile all kernels parallel
BuildKernel(kernels);
// Some new kernel may be added after KernelBuildPreprocess, so collect and build kernels again
// Some new kernel may be added after InsertAtomicCleanOp, so collect and build kernels again
kernels.clear();
for (const auto &graph_item : single_op_graphs) {
device::ascend::KernelBuildPreprocess(graph_item.first.get());
device::ascend::InsertAtomicCleanOp(graph_item.first);
const auto &execution_order = graph_item.first->execution_order();
std::copy(execution_order.begin(), execution_order.end(), std::back_inserter(kernels));
}
@ -1078,7 +1078,7 @@ void AscendSession::AdjustKernel(const std::shared_ptr<KernelGraph> &kernel_grap
// Insert CLearZero op
// prepare for next step from json get atomic info
BuildKernel(kernel_graph);
device::ascend::KernelBuildPreprocess(kernel_graph.get());
device::ascend::InsertAtomicCleanOp(kernel_graph);
device::KernelAdjust::GetInstance().InsertDeviceLoopCtrl(kernel_graph);
device::KernelAdjust::GetInstance().ProcessLoopSink(kernel_graph);
#ifdef ENABLE_DUMP_IR
@ -1098,7 +1098,7 @@ void AscendSession::RunOpAdjustKernel(const std::shared_ptr<KernelGraph> &kernel
// Insert CLearZero op
// prepare for next step from json get atomic info
BuildKernel(kernel_graph);
device::ascend::KernelBuildPreprocess(kernel_graph.get());
device::ascend::InsertAtomicCleanOp(kernel_graph);
MS_LOG(INFO) << "Finish!";
}

View File

@ -15,7 +15,7 @@
*/
#include "runtime/device/ascend/kernel_build_ascend.h"
#include <algorithm>
#include <vector>
#include <string>
#include <memory>
@ -29,7 +29,6 @@
#include "backend/kernel_compiler/hccl/hccl_kernel_build.h"
#include "backend/kernel_compiler/rts/rt_kernel_build.h"
#include "backend/kernel_compiler/tbe/tbe_utils.h"
#include "backend/kernel_compiler/common_utils.h"
namespace mindspore {
namespace device {
@ -116,7 +115,95 @@ static bool KernelBuildParallelCompile(const std::vector<CNodePtr> &kernels) {
return akg_ret;
}
static std::vector<size_t> CalCleanZerosSize(const CNodePtr &pre_node) {
bool KernelBuild(const std::vector<CNodePtr> &kernels) {
TbeUtils::LoadCache();
return device::ascend::KernelBuildParallelCompile(kernels);
}
namespace {
bool IsAtomicNode(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
auto kernel_mod = AnfAlgo::GetKernelMod(kernel_node);
MS_EXCEPTION_IF_NULL(kernel_mod);
auto parameters_indexes = kernel_mod->GenParameters();
if (parameters_indexes.empty()) {
return false;
}
if (AnfAlgo::IsDynamicShape(kernel_node)) {
if (parameters_indexes.at(0) == 1) {
(void)parameters_indexes.erase(parameters_indexes.begin());
} else {
parameters_indexes.pop_back();
}
}
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
size_t workspace_num = kernel_mod->GetWorkspaceSizeList().size();
size_t param_num = parameters_indexes.size();
size_t total_num = input_num + output_num + workspace_num;
size_t pad_index = param_num;
for (; pad_index < total_num; ++pad_index) {
parameters_indexes.emplace_back(0);
}
for (size_t j = 0; j < input_num; ++j) {
if (parameters_indexes.at(j) == 1) {
MS_LOG(EXCEPTION) << "Atomic clean doesn't support clean input address, input index: " << j;
}
}
if (parameters_indexes.size() < total_num) {
MS_LOG(EXCEPTION) << "Parameters indexes size: " << parameters_indexes.size()
<< " less than total num: " << total_num;
}
// process output
std::vector<size_t> output_indexes = {};
if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, kernel_node)) {
output_indexes = AnfAlgo::GetNodeAttr<std::vector<size_t>>(kernel_node, kAttrAtomicOutputIndexs);
}
for (size_t i = 0; i < output_num; ++i) {
auto param_output = parameters_indexes.at(input_num + i);
if (param_output == 1) {
output_indexes.emplace_back(i);
MS_LOG(INFO) << "Atomic clear output index: " << i;
}
}
if (!output_indexes.empty()) {
std::set<size_t> s(output_indexes.begin(), output_indexes.end());
output_indexes.assign(s.begin(), s.end());
AnfAlgo::SetNodeAttr(kAttrAtomicOutputIndexs, MakeValue(output_indexes), kernel_node);
}
// process workspace
std::vector<size_t> workspace_indexes = {};
for (size_t k = 0; k < workspace_num; ++k) {
auto param_workspace = parameters_indexes.at(input_num + output_num + k);
if (param_workspace == 1) {
workspace_indexes.emplace_back(k);
MS_LOG(INFO) << "Atomic clear workspace index: " << k;
}
}
if (!workspace_indexes.empty()) {
AnfAlgo::SetNodeAttr(kAttrAtomicWorkspaceIndexs, MakeValue(workspace_indexes), kernel_node);
}
return !(workspace_indexes.empty() && output_indexes.empty());
}
bool IfAtomicOpNeedFusion(const size_t clean_total_num, const CNodePtr &first_node, const CNodePtr &current_node) {
if (first_node == nullptr || current_node == nullptr) {
return false;
}
auto first_graph_id = AnfAlgo::GetGraphId(first_node.get());
auto current_graph_id = AnfAlgo::GetGraphId(current_node.get());
if (clean_total_num >= kMaxAttrMemListSize || first_graph_id != current_graph_id) {
return true;
}
return false;
}
std::vector<size_t> GetClearSize(const CNodePtr &pre_node) {
MS_EXCEPTION_IF_NULL(pre_node);
auto kernel_mod = AnfAlgo::GetKernelMod(pre_node);
MS_EXCEPTION_IF_NULL(kernel_mod);
@ -124,37 +211,38 @@ static std::vector<size_t> CalCleanZerosSize(const CNodePtr &pre_node) {
constexpr size_t kAlignBytes = 32 - 1;
// clean output
if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) {
auto output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicOutputIndexs);
auto output_indexes = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicOutputIndexs);
auto output_men_size = kernel_mod->GetOutputSizeList();
for (auto index : output_indexs) {
for (auto index : output_indexes) {
auto clean_item = (output_men_size.at(index) + kMemAlignSize + kAlignBytes) / kMemAlignSize * kMemAlignSize;
clean_size_list.emplace_back(clean_item);
}
}
// clean workspace
if (AnfAlgo::HasNodeAttr(kAttrAtomicWorkspaceIndexs, pre_node)) {
auto workspace_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicWorkspaceIndexs);
auto workspace_indexes = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicWorkspaceIndexs);
auto workspace_men_sizes = kernel_mod->GetWorkspaceSizeList();
for (const auto &index : workspace_indexs) {
for (const auto &index : workspace_indexes) {
auto clean_item = (workspace_men_sizes.at(index) + kMemAlignSize + kAlignBytes) / kMemAlignSize * kMemAlignSize;
clean_size_list.emplace_back(clean_item);
}
}
MS_LOG(INFO) << "clear output size:" << clean_size_list.size() << ",pre_node:" << pre_node->fullname_with_scope();
MS_LOG(INFO) << "Clear output size:" << clean_size_list.size() << ",pre_node:" << pre_node->fullname_with_scope();
return clean_size_list;
}
static void AddTbeClearZeroNode(mindspore::session::KernelGraph *const kernel_graph,
const mindspore::CNodePtr &pre_node, std::vector<mindspore::CNodePtr> *new_nodes) {
MS_EXCEPTION_IF_NULL(kernel_graph);
CNodePtr NewAtomicOp(const CNodePtr &pre_node) {
MS_EXCEPTION_IF_NULL(pre_node);
MS_EXCEPTION_IF_NULL(new_nodes);
auto clear_zero_prim = std::make_shared<Primitive>(kAtomicAddrCleanOpName);
MS_EXCEPTION_IF_NULL(clear_zero_prim);
auto new_value_node = NewValueNode(clear_zero_prim);
MS_EXCEPTION_IF_NULL(new_value_node);
std::vector<AnfNodePtr> inputs = {new_value_node};
inputs.push_back(pre_node);
auto func_graph = pre_node->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph);
CNodePtr clear_zero = kernel_graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(clear_zero);
AbstractBasePtr abstract = std::make_shared<abstract::AbstractNone>();
@ -164,123 +252,115 @@ static void AddTbeClearZeroNode(mindspore::session::KernelGraph *const kernel_gr
MS_EXCEPTION_IF_NULL(builder);
builder->SetKernelType(KernelType::TBE_KERNEL);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), clear_zero.get());
auto clean_size = CalCleanZerosSize(pre_node);
AnfAlgo::SetNodeAttr(kAttrAtomicAddMemSize, MakeValue(clean_size), clear_zero);
AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(pre_node.get()), clear_zero.get());
new_nodes->push_back(clear_zero);
return clear_zero;
}
static void AddFusionTbeClearZeroNode(mindspore::session::KernelGraph *const kernel_graph,
const mindspore::CNodePtr &first_clear_node,
const std::vector<AnfNodePtr> &fusion_clear_inputs,
const std::vector<size_t> &clean_size_list,
std::vector<mindspore::CNodePtr> *new_nodes) {
void InsertFusionAtomicOp(const CNodePtr &first_clear_node, const std::vector<AnfNodePtr> &fusion_clear_inputs,
const std::vector<size_t> &clean_size_list, CleanOpsMap *clean_ops) {
MS_EXCEPTION_IF_NULL(first_clear_node);
auto clear_zero_prim = std::make_shared<Primitive>(kAtomicAddrCleanOpName);
MS_EXCEPTION_IF_NULL(clear_zero_prim);
auto new_value_node = NewValueNode(clear_zero_prim);
MS_EXCEPTION_IF_NULL(new_value_node);
std::vector<AnfNodePtr> inputs = {new_value_node};
inputs.insert(inputs.end(), fusion_clear_inputs.begin(), fusion_clear_inputs.end());
CNodePtr clear_zero = kernel_graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(clear_zero);
AbstractBasePtr abstract = std::make_shared<abstract::AbstractNone>();
MS_EXCEPTION_IF_NULL(abstract);
clear_zero->set_abstract(abstract);
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
MS_EXCEPTION_IF_NULL(builder);
builder->SetKernelType(KernelType::TBE_KERNEL);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), clear_zero.get());
MS_EXCEPTION_IF_NULL(clean_ops);
auto clear_zero = NewAtomicOp(first_clear_node);
AnfAlgo::SetNodeAttr(kAttrAtomicAddMemSize, MakeValue(clean_size_list), clear_zero);
AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(first_clear_node.get()), clear_zero.get());
auto it = std::find(new_nodes->begin(), new_nodes->end(), first_clear_node);
if (it != new_nodes->end()) {
new_nodes->insert(it, clear_zero);
} else {
new_nodes->insert(new_nodes->begin(), clear_zero);
(*clean_ops)[first_clear_node].emplace_back(clear_zero);
}
void InsertAtomicOpForNormalOp(const mindspore::CNodePtr &pre_node, CleanOpsMap *clean_ops) {
MS_EXCEPTION_IF_NULL(pre_node);
MS_EXCEPTION_IF_NULL(clean_ops);
auto clear_zero = NewAtomicOp(pre_node);
auto clean_size = GetClearSize(pre_node);
AnfAlgo::SetNodeAttr(kAttrAtomicAddMemSize, MakeValue(clean_size), clear_zero);
AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(pre_node.get()), clear_zero.get());
(*clean_ops)[pre_node].emplace_back(clear_zero);
}
void SpecialAkgOps(const std::string &op_name, const CNodePtr &node, CleanOpsMap *clean_ops) {
MS_EXCEPTION_IF_NULL(clean_ops);
if (op_name == prim::kPrimMaxPoolGrad->name() && AnfAlgo::GetKernelType(node) == KernelType::AKG_KERNEL) {
auto clear_zero_prim = std::make_shared<Primitive>(kClearZeroOpName);
MS_EXCEPTION_IF_NULL(clear_zero_prim);
auto new_value_node = NewValueNode(clear_zero_prim);
MS_EXCEPTION_IF_NULL(new_value_node);
std::vector<AnfNodePtr> inputs = {new_value_node};
inputs.push_back(node);
auto func_graph = node->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph);
CNodePtr clear_zero = kernel_graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(clear_zero);
auto kernel_info = std::make_shared<device::KernelInfo>();
MS_EXCEPTION_IF_NULL(kernel_info);
clear_zero->set_kernel_info(kernel_info);
AbstractBasePtr abstract = std::make_shared<abstract::AbstractNone>();
MS_EXCEPTION_IF_NULL(abstract);
AnfAlgo::SetNodeAttr("input_names", MakeValue(std::vector<std::string>({"x"})), clear_zero);
SelectKernelInfo(clear_zero);
// set the distinction label of clear same with anf
AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(node.get()), clear_zero.get());
(*clean_ops)[node].emplace_back(clear_zero);
}
}
static bool IsAtomicNode(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
auto kernel_mod = AnfAlgo::GetKernelMod(kernel_node);
MS_EXCEPTION_IF_NULL(kernel_mod);
auto parameters_indexs = kernel_mod->GenParameters();
if (parameters_indexs.empty()) {
return false;
}
if (AnfAlgo::IsDynamicShape(kernel_node)) {
if (parameters_indexs.at(0) == 1) {
(void)parameters_indexs.erase(parameters_indexs.begin());
} else {
parameters_indexs.pop_back();
void ProcessAtomicFusion(const std::vector<CNodePtr> &kernels, CleanOpsMap *clean_ops) {
MS_EXCEPTION_IF_NULL(clean_ops);
std::vector<size_t> clean_size_list;
std::vector<AnfNodePtr> fusion_clear_inputs;
CNodePtr first_node = nullptr;
for (const auto &anf_node : kernels) {
MS_EXCEPTION_IF_NULL(anf_node);
std::string apply_function_name = AnfAlgo::GetCNodeName(anf_node);
SpecialAkgOps(apply_function_name, anf_node, clean_ops);
if (AnfAlgo::HasNodeAttr(kAttrNeedAtomic, anf_node) && AnfAlgo::GetNodeAttr<bool>(anf_node, kAttrNeedAtomic)) {
auto clean_sizes = GetClearSize(anf_node);
if (!clean_sizes.empty()) {
auto clean_total_num = clean_size_list.size() + clean_sizes.size();
if (IfAtomicOpNeedFusion(clean_total_num, first_node, anf_node)) {
// create clean node
InsertFusionAtomicOp(first_node, fusion_clear_inputs, clean_size_list, clean_ops);
clean_size_list.clear();
fusion_clear_inputs.clear();
first_node = nullptr;
}
if (fusion_clear_inputs.empty()) {
first_node = anf_node;
}
clean_size_list.insert(clean_size_list.end(), clean_sizes.begin(), clean_sizes.end());
fusion_clear_inputs.emplace_back(anf_node);
MS_LOG(DEBUG) << "The fusion_clear_inputs size: " << fusion_clear_inputs.size()
<< ", clean_size_list: " << clean_size_list.size();
}
}
}
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
size_t workspace_num = kernel_mod->GetWorkspaceSizeList().size();
size_t param_num = parameters_indexs.size();
size_t total_num = input_num + output_num + workspace_num;
size_t pad_index = param_num;
for (; pad_index < total_num; ++pad_index) {
parameters_indexs.emplace_back(0);
if (!fusion_clear_inputs.empty() && !clean_size_list.empty()) {
// create clean node
InsertFusionAtomicOp(first_node, fusion_clear_inputs, clean_size_list, clean_ops);
}
}
} // namespace
for (size_t j = 0; j < input_num; ++j) {
if (parameters_indexs.at(j) == 1) {
MS_LOG(EXCEPTION) << "Atomic addr clean doesn't support clean input address, input index: " << j;
void InsertAtomicOps(const std::vector<CNodePtr> &kernels, CleanOpsMap *clean_ops) {
// fusion
MS_EXCEPTION_IF_NULL(clean_ops);
static const auto enable_fusion_clear = (common::GetEnv("ENV_FUSION_CLEAR") == "1");
if (enable_fusion_clear) {
ProcessAtomicFusion(kernels, clean_ops);
return;
}
// single
for (const auto &node : kernels) {
std::string apply_function_name = AnfAlgo::GetCNodeName(node);
SpecialAkgOps(apply_function_name, node, clean_ops);
if (AnfAlgo::HasNodeAttr(kAttrNeedAtomic, node) && AnfAlgo::GetNodeAttr<bool>(node, kAttrNeedAtomic)) {
InsertAtomicOpForNormalOp(node, clean_ops);
}
}
if (parameters_indexs.size() < total_num) {
MS_LOG(EXCEPTION) << "Parameters indexes size: " << parameters_indexs.size()
<< " less than total num: " << total_num;
}
// process output
std::vector<size_t> output_indexs = {};
if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, kernel_node)) {
output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(kernel_node, kAttrAtomicOutputIndexs);
}
for (size_t i = 0; i < output_num; ++i) {
auto param_output = parameters_indexs.at(input_num + i);
if (param_output == 1) {
output_indexs.emplace_back(i);
MS_LOG(INFO) << "Atomic clear output index: " << i;
}
}
if (!output_indexs.empty()) {
std::set<size_t> s(output_indexs.begin(), output_indexs.end());
output_indexs.assign(s.begin(), s.end());
AnfAlgo::SetNodeAttr(kAttrAtomicOutputIndexs, MakeValue(output_indexs), kernel_node);
}
// process workspace
std::vector<size_t> workspace_indexs = {};
for (size_t k = 0; k < workspace_num; ++k) {
auto param_workspace = parameters_indexs.at(input_num + output_num + k);
if (param_workspace == 1) {
workspace_indexs.emplace_back(k);
MS_LOG(INFO) << "Atomic clear workspace index: " << k;
}
}
if (!workspace_indexs.empty()) {
AnfAlgo::SetNodeAttr(kAttrAtomicWorkspaceIndexs, MakeValue(workspace_indexs), kernel_node);
}
return !(workspace_indexs.empty() && output_indexs.empty());
}
bool KernelBuild(const std::vector<CNodePtr> &kernels) {
TbeUtils::LoadCache();
return device::ascend::KernelBuildParallelCompile(kernels);
}
std::map<AnfNodePtr, std::vector<size_t>> GetCommunicationOpInputInfo(
const mindspore::session::KernelGraph *kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph);
std::map<AnfNodePtr, std::vector<size_t>> GetCommunicationOpInputInfo(const std::vector<CNodePtr> &exe_orders) {
std::map<AnfNodePtr, std::vector<size_t>> comm_input_info_map;
for (auto &kernel : kernel_graph->execution_order()) {
for (auto &kernel : exe_orders) {
MS_EXCEPTION_IF_NULL(kernel);
auto input_num = AnfAlgo::GetInputTensorNum(kernel);
if (mindspore::session::AnfRuntimeAlgorithm::IsCommunicationOp(kernel)) {
@ -295,9 +375,8 @@ std::map<AnfNodePtr, std::vector<size_t>> GetCommunicationOpInputInfo(
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::IsCommunicationOp(cnode) || AnfAlgo::IsIndependentNode(cnode) ||
AnfAlgo::GetCNodeName(cnode) == kGetNextOpName) {
// no need to add atomic for communication/independent/getnext op 's output
MS_LOG(INFO) << "No need to add atomic clean for op " << kernel_input.first->fullname_with_scope()
<< "'s output";
// no need to add atomic for communication or independent or get_next op's output
MS_LOG(INFO) << "No need insert atomic clean for op " << cnode->fullname_with_scope() << "'s output";
continue;
}
MS_LOG(INFO) << "Add atomic clean for single communication op input, comm:" << kernel->fullname_with_scope()
@ -319,139 +398,71 @@ std::map<AnfNodePtr, std::vector<size_t>> GetCommunicationOpInputInfo(
std::set<size_t> s(info.second.begin(), info.second.end());
info.second.assign(s.begin(), s.end());
}
return comm_input_info_map;
}
bool IsNeedClearZeroNodeFusion(const size_t clean_total_num, const mindspore::CNodePtr &first_node,
const mindspore::CNodePtr &current_node) {
if (first_node == nullptr || current_node == nullptr) {
return false;
void AddNeedInsertAtomicAttrForAllOps(const std::vector<CNodePtr> &exe_orders) {
if (exe_orders.empty()) {
return;
}
auto first_graph_id = AnfAlgo::GetGraphId(first_node.get());
auto current_graph_id = AnfAlgo::GetGraphId(current_node.get());
if (clean_total_num >= kMaxAttrMemListSize || first_graph_id != current_graph_id) {
return true;
}
return false;
}
static void TbeClearZeroNodeFusion(mindspore::session::KernelGraph *const kernel_graph) {
std::vector<CNodePtr> new_nodes;
std::vector<size_t> clean_size_list;
std::vector<AnfNodePtr> fusion_clear_inputs;
CNodePtr first_node = nullptr;
std::map<AnfNodePtr, std::vector<size_t>> comm_input_info_map = GetCommunicationOpInputInfo(kernel_graph);
for (const auto &anf_node : kernel_graph->execution_order()) {
std::string apply_function_name = AnfAlgo::GetCNodeName(anf_node);
bool is_comm_input = false;
// set communication input output index attr
std::map<AnfNodePtr, std::vector<size_t>> comm_input_info_map = GetCommunicationOpInputInfo(exe_orders);
for (const auto &anf_node : exe_orders) {
if (comm_input_info_map.find(anf_node) != comm_input_info_map.end()) {
auto indexes = comm_input_info_map[anf_node];
AnfAlgo::SetNodeAttr(kAttrAtomicOutputIndexs, MakeValue(indexes), anf_node);
is_comm_input = true;
}
if (apply_function_name == prim::kPrimMaxPoolGrad->name() &&
AnfAlgo::GetKernelType(anf_node) == KernelType::AKG_KERNEL) {
auto clear_zero_prim = std::make_shared<Primitive>(kClearZeroOpName);
MS_EXCEPTION_IF_NULL(clear_zero_prim);
auto new_value_node = NewValueNode(clear_zero_prim);
MS_EXCEPTION_IF_NULL(new_value_node);
std::vector<AnfNodePtr> inputs = {new_value_node};
inputs.push_back(anf_node);
CNodePtr clear_zero = kernel_graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(clear_zero);
auto kernel_info = std::make_shared<device::KernelInfo>();
MS_EXCEPTION_IF_NULL(kernel_info);
clear_zero->set_kernel_info(kernel_info);
AbstractBasePtr abstract = std::make_shared<abstract::AbstractNone>();
MS_EXCEPTION_IF_NULL(abstract);
AnfAlgo::SetNodeAttr("input_names", MakeValue(std::vector<std::string>({"x"})), clear_zero);
SelectKernelInfo(clear_zero);
// set the distinction label of clear same with anf
AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(anf_node.get()), clear_zero.get());
new_nodes.push_back(clear_zero);
} else if (is_comm_input ||
(AnfAlgo::GetKernelType(anf_node) == KernelType::TBE_KERNEL && IsAtomicNode(anf_node))) {
auto clean_sizes = CalCleanZerosSize(anf_node);
if (!clean_sizes.empty()) {
auto clean_total_num = clean_size_list.size() + clean_sizes.size();
if (IsNeedClearZeroNodeFusion(clean_total_num, first_node, anf_node)) {
// create clean node
AddFusionTbeClearZeroNode(kernel_graph, first_node, fusion_clear_inputs, clean_size_list, &new_nodes);
clean_size_list.clear();
fusion_clear_inputs.clear();
first_node = nullptr;
}
if (fusion_clear_inputs.empty()) {
first_node = anf_node;
}
clean_size_list.insert(clean_size_list.end(), clean_sizes.begin(), clean_sizes.end());
fusion_clear_inputs.emplace_back(anf_node);
MS_LOG(DEBUG) << "fusion_clear_inputs size: " << fusion_clear_inputs.size()
<< ", clean_size_list: " << clean_size_list.size();
if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, anf_node)) {
auto output_indexes = AnfAlgo::GetNodeAttr<std::vector<size_t>>(anf_node, kAttrAtomicOutputIndexs);
std::copy(indexes.begin(), indexes.end(), std::back_inserter(output_indexes));
std::set<size_t> tmp(output_indexes.begin(), output_indexes.end());
indexes.assign(tmp.begin(), tmp.end());
}
AnfAlgo::SetNodeAttr(kAttrAtomicOutputIndexs, MakeValue(indexes), anf_node);
AnfAlgo::SetNodeAttr(kAttrNeedAtomic, MakeValue(true), anf_node);
} else if (AnfAlgo::GetKernelType(anf_node) == KernelType::TBE_KERNEL && IsAtomicNode(anf_node)) {
AnfAlgo::SetNodeAttr(kAttrNeedAtomic, MakeValue(true), anf_node);
}
new_nodes.emplace_back(anf_node);
}
if (!fusion_clear_inputs.empty() && !clean_size_list.empty()) {
// create clean node
AddFusionTbeClearZeroNode(kernel_graph, first_node, fusion_clear_inputs, clean_size_list, &new_nodes);
}
kernel_graph->set_execution_order(new_nodes);
}
void KernelBuildPreprocess(mindspore::session::KernelGraph *kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph);
static const auto enable_fusion_clear = (common::GetEnv("ENV_FUSION_CLEAR") == "1");
bool is_dynamic_graph = kernel_graph->is_dynamic_shape();
if (!is_dynamic_graph && enable_fusion_clear) {
TbeClearZeroNodeFusion(kernel_graph);
} else {
std::vector<CNodePtr> new_nodes;
std::map<AnfNodePtr, std::vector<size_t>> comm_input_info_map = GetCommunicationOpInputInfo(kernel_graph);
for (const auto &anf_node : kernel_graph->execution_order()) {
std::string apply_function_name = AnfAlgo::GetCNodeName(anf_node);
bool is_comm_input = false;
if (comm_input_info_map.find(anf_node) != comm_input_info_map.end()) {
auto indexes = comm_input_info_map[anf_node];
AnfAlgo::SetNodeAttr(kAttrAtomicOutputIndexs, MakeValue(indexes), anf_node);
is_comm_input = true;
}
if (is_comm_input) {
AddTbeClearZeroNode(kernel_graph, anf_node, &new_nodes);
} else if (apply_function_name == prim::kPrimMaxPoolGrad->name() &&
AnfAlgo::GetKernelType(anf_node) == KernelType::AKG_KERNEL) {
auto clear_zero_prim = std::make_shared<Primitive>(kClearZeroOpName);
MS_EXCEPTION_IF_NULL(clear_zero_prim);
auto new_value_node = NewValueNode(clear_zero_prim);
MS_EXCEPTION_IF_NULL(new_value_node);
std::vector<AnfNodePtr> inputs = {new_value_node};
inputs.push_back(anf_node);
CNodePtr clear_zero = kernel_graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(clear_zero);
auto kernel_info = std::make_shared<device::KernelInfo>();
MS_EXCEPTION_IF_NULL(kernel_info);
clear_zero->set_kernel_info(kernel_info);
AbstractBasePtr abstract = std::make_shared<abstract::AbstractNone>();
MS_EXCEPTION_IF_NULL(abstract);
AnfAlgo::SetNodeAttr("input_names", MakeValue(std::vector<std::string>({"x"})), clear_zero);
SelectKernelInfo(clear_zero);
// set the distinction label of clear same with anf
AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(anf_node.get()), clear_zero.get());
new_nodes.push_back(clear_zero);
} else if (AnfAlgo::GetKernelType(anf_node) == KernelType::TBE_KERNEL) {
if (IsAtomicNode(anf_node)) {
AddTbeClearZeroNode(kernel_graph, anf_node, &new_nodes);
}
}
new_nodes.push_back(anf_node);
}
kernel_graph->set_execution_order(new_nodes);
std::vector<CNodePtr> GatherAllAtomicOps(const CleanOpsMap &node_maps) {
std::vector<CNodePtr> all_atomics;
auto iter = node_maps.begin();
while (iter != node_maps.end()) {
auto tmp = iter->second;
(void)std::copy(tmp.begin(), tmp.end(), std::back_inserter(all_atomics));
iter++;
}
return all_atomics;
}
void InsertAtomicCleanOpForMindRT(const std::vector<CNodePtr> &exe_orders, CleanOpsMap *maps) {
MS_EXCEPTION_IF_NULL(maps);
// assign attr
AddNeedInsertAtomicAttrForAllOps(exe_orders);
// insert atomic
InsertAtomicOps(exe_orders, maps);
std::vector<CNodePtr> all_atomics = GatherAllAtomicOps(*maps);
// build atomic
KernelBuild(all_atomics);
}
void InsertAtomicCleanOp(const KernelGraphPtr &kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph);
const auto &exe_orders = kernel_graph->execution_order();
// assign attr
AddNeedInsertAtomicAttrForAllOps(exe_orders);
// insert atomic
CleanOpsMap node_to_cleans;
InsertAtomicOps(exe_orders, &node_to_cleans);
// update exec order
std::vector<CNodePtr> new_orders;
for (const auto &node : exe_orders) {
if (node_to_cleans.find(node) != node_to_cleans.end()) {
auto atomics = node_to_cleans[node];
(void)std::copy(atomics.begin(), atomics.end(), std::back_inserter(new_orders));
}
new_orders.push_back(node);
}
kernel_graph->set_execution_order(new_orders);
}
} // namespace ascend
} // namespace device

View File

@ -25,19 +25,43 @@ namespace mindspore {
namespace device {
namespace ascend {
using CommOpInputInfo = std::map<AnfNodePtr, std::vector<size_t>>;
using CleanOpsMap = std::map<CNodePtr, std::vector<CNodePtr>>;
/**
* @brief kernel build for ascend.
*/
bool KernelBuild(const std::vector<CNodePtr> &kernels);
/**
* @brief preporcess of kernel build for ascend, e.g. inserting clear_zero node for maxpool, bn.
* @brief preprocess of kernel build for ascend, e.g. inserting clear_zero node for max_pool, bn.
* Must DO these changes just before kernel build, and after all of other optimizations on AnfGraph
*/
void KernelBuildPreprocess(mindspore::session::KernelGraph *kernel_graph);
void InsertAtomicCleanOp(const KernelGraphPtr &kernel_graph);
/**
* @brief Communication Op Input Info.
* @brief preprocess for mind rt
* */
CommOpInputInfo GetCommunicationOpInputInfo(const mindspore::session::KernelGraph *kernel_graph);
void InsertAtomicCleanOpForMindRT(const std::vector<CNodePtr> &exe_orders, CleanOpsMap *maps);
/**
* @brief communication op input info.
* */
CommOpInputInfo GetCommunicationOpInputInfo(const std::vector<CNodePtr> &exe_orders);
/**
* @brief insert atomic
* */
void InsertAtomicOps(const std::vector<CNodePtr> &exe_orders, CleanOpsMap *clean_ops);
/**
* @brief gather all atomics
* */
std::vector<CNodePtr> GatherAllAtomicOps(const CleanOpsMap &node_maps);
/**
* @brief add attr for op if need insert atomic
* */
void AddNeedInsertAtomicAttrForAllOps(const std::vector<CNodePtr> &exe_orders);
} // namespace ascend
} // namespace device
} // namespace mindspore

View File

@ -359,6 +359,7 @@ constexpr auto kAttrPerm = "perm";
constexpr auto kAttrTransposeFirst = "transpose_first";
constexpr auto kAttrAtomicAddMemSize = "automic_add_mem_size";
constexpr auto kAttrAtomicOutputIndexs = "atomic_output_clean_indexs";
constexpr auto kAttrNeedAtomic = "need_atomic";
constexpr auto kAttrAtomicWorkspaceIndexs = "atomic_workspace_clean_indexs";
constexpr auto kAttrSwitchCondition = "switch_condition";
constexpr auto kAttrDataType = "data_type";