fusion atomic clear node

This commit is contained in:
zhaosida 2021-03-31 14:50:50 +08:00
parent d346a861bc
commit 1740aac860
3 changed files with 178 additions and 89 deletions

View File

@ -567,12 +567,14 @@ void Somas::InitAtomicCleanInputs(bool is_all_nop_node, const CNodePtr &kernel)
auto stream = node->GetStream();
MS_EXCEPTION_IF_NULL(stream);
MS_EXCEPTION_IF_NULL(kernel->inputs()[1]);
auto pre_node = (kernel->inputs()[1])->cast<CNodePtr>();
auto input_tensor_num = AnfAlgo::GetInputTensorNum(kernel);
for (size_t i = 0; i < input_tensor_num; i++) {
MS_EXCEPTION_IF_NULL(kernel->inputs()[i + 1]);
auto pre_node = kernel->input(i + 1)->cast<CNodePtr>();
auto iter = nodes_map_.find(pre_node.get());
if (iter == nodes_map_.end()) {
MS_LOG(EXCEPTION) << "Kernel[" << kernel->fullname_with_scope() << "]'s input [" << pre_node->fullname_with_scope()
<< "] is not init.";
MS_LOG(EXCEPTION) << "Kernel[" << kernel->fullname_with_scope() << "]'s input ["
<< pre_node->fullname_with_scope() << "] is not init.";
}
auto pre_somas_node = iter->second;
// set clean output tensors
@ -586,17 +588,9 @@ void Somas::InitAtomicCleanInputs(bool is_all_nop_node, const CNodePtr &kernel)
auto input_somas_tensor = pre_somas_node->output_tensors_[index];
MS_EXCEPTION_IF_NULL(input_somas_tensor);
node->input_tensors_.push_back(input_somas_tensor);
input_somas_tensor->destinations_.insert(node);
input_somas_tensor->destinationStreams_.insert(stream);
if (input_somas_tensor->lifetime_.start_ > node->GetId()) {
input_somas_tensor->lifetime_.start_ = node->GetId();
}
node->ancestor_nodes_.insert(pre_somas_node);
auto input_tensor_stream = input_somas_tensor->GetSourceStream();
if (input_tensor_stream != stream) {
stream->ancestor_streams_.insert(input_tensor_stream);
input_somas_tensor->between_streams_ = true;
}
input_somas_tensor->lifelong_value_ = kLifeLongGraphAll;
MS_LOG(INFO) << "Set " << node->scope_full_name_ << "'s Input node " << pre_somas_node->scope_full_name_
<< " 's output" << index << " to lifelong";
}
}
// set clean workspace tensors
@ -610,16 +604,9 @@ void Somas::InitAtomicCleanInputs(bool is_all_nop_node, const CNodePtr &kernel)
auto input_somas_tensor = pre_somas_node->workspace_tensors_[index];
MS_EXCEPTION_IF_NULL(input_somas_tensor);
node->input_tensors_.push_back(input_somas_tensor);
input_somas_tensor->destinations_.insert(node);
input_somas_tensor->destinationStreams_.insert(stream);
if (input_somas_tensor->lifetime_.start_ > node->GetId()) {
input_somas_tensor->lifetime_.start_ = node->GetId();
}
node->ancestor_nodes_.insert(pre_somas_node);
auto input_tensor_stream = input_somas_tensor->GetSourceStream();
if (input_tensor_stream != stream) {
stream->ancestor_streams_.insert(input_tensor_stream);
input_somas_tensor->between_streams_ = true;
input_somas_tensor->lifelong_value_ = kLifeLongGraphAll;
MS_LOG(INFO) << "Set " << node->scope_full_name_ << "'s Input node " << pre_somas_node->scope_full_name_
<< " 's workspace" << index << " to lifelong";
}
}
}

View File

@ -40,6 +40,8 @@ namespace device {
namespace ascend {
using mindspore::kernel::tbe::TbeUtils;
using std::make_shared;
constexpr size_t kMaxAttrMemListSize = 192;
static kernel::KernelModPtr SerialCompileImpl(const AnfNodePtr &anf_node) {
kernel::KernelModPtr kernel_mod_ptr = nullptr;
KernelType kernel_type = AnfAlgo::GetKernelType(anf_node);
@ -159,6 +161,30 @@ static void AddTbeClearZeroNode(mindspore::session::KernelGraph *const kernel_gr
new_nodes->push_back(clear_zero);
}
static void AddFusionTbeClearZeroNode(mindspore::session::KernelGraph *const kernel_graph,
const mindspore::CNodePtr &stream_node,
const std::vector<AnfNodePtr> &fusion_clear_inputs,
const std::vector<size_t> &clean_size_list,
std::vector<mindspore::CNodePtr> *new_nodes) {
auto clear_zero_prim = std::make_shared<Primitive>(kAtomicAddrCleanOpName);
MS_EXCEPTION_IF_NULL(clear_zero_prim);
auto new_value_node = NewValueNode(clear_zero_prim);
MS_EXCEPTION_IF_NULL(new_value_node);
std::vector<AnfNodePtr> inputs = {new_value_node};
inputs.insert(inputs.end(), fusion_clear_inputs.begin(), fusion_clear_inputs.end());
CNodePtr clear_zero = kernel_graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(clear_zero);
AbstractBasePtr abstract = std::make_shared<abstract::AbstractNone>();
MS_EXCEPTION_IF_NULL(abstract);
clear_zero->set_abstract(abstract);
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
builder->SetKernelType(KernelType::TBE_KERNEL);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), clear_zero.get());
AnfAlgo::SetNodeAttr(kAttrAtomicAddMemSize, MakeValue(clean_size_list), clear_zero);
AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(stream_node.get()), clear_zero.get());
new_nodes->insert(new_nodes->begin(), clear_zero);
}
static bool IsAtomicNode(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
auto kernel_mod = AnfAlgo::GetKernelMod(kernel_node);
@ -264,8 +290,77 @@ std::map<AnfNodePtr, std::vector<size_t>> GetCommunicationOpInputInfo(
return comm_input_info_map;
}
static void TbeClearZeroNodeFusion(mindspore::session::KernelGraph *const kernel_graph) {
std::vector<CNodePtr> new_nodes;
std::vector<size_t> clean_size_list;
std::vector<AnfNodePtr> fusion_clear_inputs;
std::map<AnfNodePtr, std::vector<size_t>> comm_input_info_map = GetCommunicationOpInputInfo(kernel_graph);
for (const auto &anf_node : kernel_graph->execution_order()) {
std::string apply_function_name = AnfAlgo::GetCNodeName(anf_node);
bool is_comm_input = false;
// set communication input output index attr
if (comm_input_info_map.find(anf_node) != comm_input_info_map.end()) {
auto indexes = comm_input_info_map[anf_node];
AnfAlgo::SetNodeAttr(kAttrAtomicOutputIndexs, MakeValue(indexes), anf_node);
is_comm_input = true;
}
if (apply_function_name == prim::kPrimMaxPoolGrad->name() &&
AnfAlgo::GetKernelType(anf_node) == KernelType::AKG_KERNEL) {
auto clear_zero_prim = std::make_shared<Primitive>(kClearZeroOpName);
MS_EXCEPTION_IF_NULL(clear_zero_prim);
auto new_value_node = NewValueNode(clear_zero_prim);
MS_EXCEPTION_IF_NULL(new_value_node);
std::vector<AnfNodePtr> inputs = {new_value_node};
inputs.push_back(anf_node);
CNodePtr clear_zero = kernel_graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(clear_zero);
auto kernel_info = std::make_shared<device::KernelInfo>();
MS_EXCEPTION_IF_NULL(kernel_info);
clear_zero->set_kernel_info(kernel_info);
AbstractBasePtr abstract = std::make_shared<abstract::AbstractNone>();
MS_EXCEPTION_IF_NULL(abstract);
AnfAlgo::SetNodeAttr("input_names", MakeValue(std::vector<std::string>({"x"})), clear_zero);
SelectKernelInfo(clear_zero);
// set the distinction label of clear same with anf
AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(anf_node.get()), clear_zero.get());
new_nodes.push_back(clear_zero);
} else if (is_comm_input ||
(AnfAlgo::GetKernelType(anf_node) == KernelType::TBE_KERNEL && IsAtomicNode(anf_node))) {
auto clean_sizes = CalCleanZerosSize(anf_node);
if (!clean_sizes.empty()) {
auto clean_total_num = clean_size_list.size() + clean_sizes.size();
if (clean_total_num >= kMaxAttrMemListSize) {
// create clean node
auto stream_node = new_nodes.empty() ? anf_node : new_nodes.front();
AddFusionTbeClearZeroNode(kernel_graph, stream_node, fusion_clear_inputs, clean_size_list, &new_nodes);
clean_size_list.clear();
fusion_clear_inputs.clear();
}
clean_size_list.insert(clean_size_list.end(), clean_sizes.begin(), clean_sizes.end());
fusion_clear_inputs.emplace_back(anf_node);
MS_LOG(DEBUG) << "fusion_clear_inputs size: " << fusion_clear_inputs.size()
<< ", clean_size_list: " << clean_size_list.size();
}
}
new_nodes.emplace_back(anf_node);
}
if (!fusion_clear_inputs.empty() && !clean_size_list.empty()) {
// create clean node
auto stream_node = new_nodes.front();
AddFusionTbeClearZeroNode(kernel_graph, stream_node, fusion_clear_inputs, clean_size_list, &new_nodes);
}
kernel_graph->set_execution_order(new_nodes);
}
void KernelBuildPreprocess(mindspore::session::KernelGraph *kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph);
static const auto enable_fusion_clear = (common::GetEnv("ENV_FUSION_CLEAR") == "1");
bool is_dynamic_graph = kernel_graph->is_dynamic_shape();
if (!is_dynamic_graph && enable_fusion_clear) {
TbeClearZeroNodeFusion(kernel_graph);
} else {
std::vector<CNodePtr> new_nodes;
std::map<AnfNodePtr, std::vector<size_t>> comm_input_info_map = GetCommunicationOpInputInfo(kernel_graph);
for (const auto &anf_node : kernel_graph->execution_order()) {
@ -307,6 +402,7 @@ void KernelBuildPreprocess(mindspore::session::KernelGraph *kernel_graph) {
new_nodes.push_back(anf_node);
}
kernel_graph->set_execution_order(new_nodes);
}
}
} // namespace ascend
} // namespace device

View File

@ -92,13 +92,17 @@ void TaskGenerator::LaunchAddrCleanAkgKernel(const CNodePtr &anf_node_ptr, Addre
void TaskGenerator::LaunchAddrCleanKernel(const CNodePtr &anf_node_ptr, AddressPtrList *kernel_inputs) {
MS_EXCEPTION_IF_NULL(anf_node_ptr);
MS_EXCEPTION_IF_NULL(kernel_inputs);
if (anf_node_ptr->inputs().size() != 2) {
// akg process
if (AnfAlgo::GetKernelType(anf_node_ptr) == KernelType::AKG_KERNEL) {
LaunchAddrCleanAkgKernel(anf_node_ptr, kernel_inputs);
return;
}
MS_EXCEPTION_IF_NULL(anf_node_ptr->inputs()[1]);
auto pre_node = (anf_node_ptr->inputs()[1])->cast<CNodePtr>();
// tbe process
auto input_tensor_num = AnfAlgo::GetInputTensorNum(anf_node_ptr);
for (size_t i = 0; i < input_tensor_num; i++) {
// set clean output addr
MS_EXCEPTION_IF_NULL(anf_node_ptr->inputs()[i + 1]);
auto pre_node = anf_node_ptr->input(i + 1)->cast<CNodePtr>();
if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) {
auto clean_output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicOutputIndexs);
for (auto index : clean_output_indexs) {
@ -124,10 +128,12 @@ void TaskGenerator::LaunchAddrCleanKernel(const CNodePtr &anf_node_ptr, AddressP
workspace->size = device_address->size_;
kernel_inputs->push_back(workspace);
}
MS_LOG(DEBUG) << "AtomicAddClean clean workspace size:" << clean_workspace_indexs.size();
}
}
auto clear_mems = AnfAlgo::GetNodeAttr<std::vector<size_t>>(anf_node_ptr, kAttrAtomicAddMemSize);
if (kernel_inputs->size() != clear_mems.size()) {
MS_LOG(EXCEPTION) << "AtomicAddClean kernel inputs size not equal clear memory size,kerenl_inputs size:"
MS_LOG(EXCEPTION) << "AtomicAddClean kernel inputs size not equal clear memory size, kernel inputs size:"
<< kernel_inputs->size() << ",clean mem size" << clear_mems.size();
}
}