forked from mindspore-Ecosystem/mindspore
fusion atomic clear node
This commit is contained in:
parent
d346a861bc
commit
1740aac860
|
@ -567,59 +567,46 @@ void Somas::InitAtomicCleanInputs(bool is_all_nop_node, const CNodePtr &kernel)
|
|||
auto stream = node->GetStream();
|
||||
MS_EXCEPTION_IF_NULL(stream);
|
||||
|
||||
MS_EXCEPTION_IF_NULL(kernel->inputs()[1]);
|
||||
auto pre_node = (kernel->inputs()[1])->cast<CNodePtr>();
|
||||
auto iter = nodes_map_.find(pre_node.get());
|
||||
if (iter == nodes_map_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Kernel[" << kernel->fullname_with_scope() << "]'s input [" << pre_node->fullname_with_scope()
|
||||
<< "] is not init.";
|
||||
}
|
||||
auto pre_somas_node = iter->second;
|
||||
// set clean output tensors
|
||||
if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) {
|
||||
auto clean_output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicOutputIndexs);
|
||||
for (auto index : clean_output_indexs) {
|
||||
if (index > pre_somas_node->output_tensors_.size()) {
|
||||
MS_LOG(EXCEPTION) << "Output index " << index << " exceed input node [" << pre_node->fullname_with_scope()
|
||||
<< "]'s outputs size " << pre_somas_node->output_tensors_.size();
|
||||
}
|
||||
auto input_somas_tensor = pre_somas_node->output_tensors_[index];
|
||||
MS_EXCEPTION_IF_NULL(input_somas_tensor);
|
||||
node->input_tensors_.push_back(input_somas_tensor);
|
||||
input_somas_tensor->destinations_.insert(node);
|
||||
input_somas_tensor->destinationStreams_.insert(stream);
|
||||
if (input_somas_tensor->lifetime_.start_ > node->GetId()) {
|
||||
input_somas_tensor->lifetime_.start_ = node->GetId();
|
||||
}
|
||||
node->ancestor_nodes_.insert(pre_somas_node);
|
||||
auto input_tensor_stream = input_somas_tensor->GetSourceStream();
|
||||
if (input_tensor_stream != stream) {
|
||||
stream->ancestor_streams_.insert(input_tensor_stream);
|
||||
input_somas_tensor->between_streams_ = true;
|
||||
auto input_tensor_num = AnfAlgo::GetInputTensorNum(kernel);
|
||||
for (size_t i = 0; i < input_tensor_num; i++) {
|
||||
MS_EXCEPTION_IF_NULL(kernel->inputs()[i + 1]);
|
||||
auto pre_node = kernel->input(i + 1)->cast<CNodePtr>();
|
||||
auto iter = nodes_map_.find(pre_node.get());
|
||||
if (iter == nodes_map_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Kernel[" << kernel->fullname_with_scope() << "]'s input ["
|
||||
<< pre_node->fullname_with_scope() << "] is not init.";
|
||||
}
|
||||
auto pre_somas_node = iter->second;
|
||||
// set clean output tensors
|
||||
if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) {
|
||||
auto clean_output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicOutputIndexs);
|
||||
for (auto index : clean_output_indexs) {
|
||||
if (index > pre_somas_node->output_tensors_.size()) {
|
||||
MS_LOG(EXCEPTION) << "Output index " << index << " exceed input node [" << pre_node->fullname_with_scope()
|
||||
<< "]'s outputs size " << pre_somas_node->output_tensors_.size();
|
||||
}
|
||||
auto input_somas_tensor = pre_somas_node->output_tensors_[index];
|
||||
MS_EXCEPTION_IF_NULL(input_somas_tensor);
|
||||
node->input_tensors_.push_back(input_somas_tensor);
|
||||
input_somas_tensor->lifelong_value_ = kLifeLongGraphAll;
|
||||
MS_LOG(INFO) << "Set " << node->scope_full_name_ << "'s Input node " << pre_somas_node->scope_full_name_
|
||||
<< " 's output" << index << " to lifelong";
|
||||
}
|
||||
}
|
||||
}
|
||||
// set clean workspace tensors
|
||||
if (AnfAlgo::HasNodeAttr(kAttrAtomicWorkspaceIndexs, pre_node)) {
|
||||
auto clean_workspace_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicWorkspaceIndexs);
|
||||
for (const auto &index : clean_workspace_indexs) {
|
||||
if (index > pre_somas_node->output_tensors_.size()) {
|
||||
MS_LOG(EXCEPTION) << "Workspace index " << index << " exceed input node [" << pre_node->fullname_with_scope()
|
||||
<< "]'s Workspace size " << pre_somas_node->workspace_tensors_.size();
|
||||
}
|
||||
auto input_somas_tensor = pre_somas_node->workspace_tensors_[index];
|
||||
MS_EXCEPTION_IF_NULL(input_somas_tensor);
|
||||
node->input_tensors_.push_back(input_somas_tensor);
|
||||
input_somas_tensor->destinations_.insert(node);
|
||||
input_somas_tensor->destinationStreams_.insert(stream);
|
||||
if (input_somas_tensor->lifetime_.start_ > node->GetId()) {
|
||||
input_somas_tensor->lifetime_.start_ = node->GetId();
|
||||
}
|
||||
node->ancestor_nodes_.insert(pre_somas_node);
|
||||
auto input_tensor_stream = input_somas_tensor->GetSourceStream();
|
||||
if (input_tensor_stream != stream) {
|
||||
stream->ancestor_streams_.insert(input_tensor_stream);
|
||||
input_somas_tensor->between_streams_ = true;
|
||||
// set clean workspace tensors
|
||||
if (AnfAlgo::HasNodeAttr(kAttrAtomicWorkspaceIndexs, pre_node)) {
|
||||
auto clean_workspace_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicWorkspaceIndexs);
|
||||
for (const auto &index : clean_workspace_indexs) {
|
||||
if (index > pre_somas_node->output_tensors_.size()) {
|
||||
MS_LOG(EXCEPTION) << "Workspace index " << index << " exceed input node [" << pre_node->fullname_with_scope()
|
||||
<< "]'s Workspace size " << pre_somas_node->workspace_tensors_.size();
|
||||
}
|
||||
auto input_somas_tensor = pre_somas_node->workspace_tensors_[index];
|
||||
MS_EXCEPTION_IF_NULL(input_somas_tensor);
|
||||
node->input_tensors_.push_back(input_somas_tensor);
|
||||
input_somas_tensor->lifelong_value_ = kLifeLongGraphAll;
|
||||
MS_LOG(INFO) << "Set " << node->scope_full_name_ << "'s Input node " << pre_somas_node->scope_full_name_
|
||||
<< " 's workspace" << index << " to lifelong";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -40,6 +40,8 @@ namespace device {
|
|||
namespace ascend {
|
||||
using mindspore::kernel::tbe::TbeUtils;
|
||||
using std::make_shared;
|
||||
constexpr size_t kMaxAttrMemListSize = 192;
|
||||
|
||||
static kernel::KernelModPtr SerialCompileImpl(const AnfNodePtr &anf_node) {
|
||||
kernel::KernelModPtr kernel_mod_ptr = nullptr;
|
||||
KernelType kernel_type = AnfAlgo::GetKernelType(anf_node);
|
||||
|
@ -159,6 +161,30 @@ static void AddTbeClearZeroNode(mindspore::session::KernelGraph *const kernel_gr
|
|||
new_nodes->push_back(clear_zero);
|
||||
}
|
||||
|
||||
static void AddFusionTbeClearZeroNode(mindspore::session::KernelGraph *const kernel_graph,
|
||||
const mindspore::CNodePtr &stream_node,
|
||||
const std::vector<AnfNodePtr> &fusion_clear_inputs,
|
||||
const std::vector<size_t> &clean_size_list,
|
||||
std::vector<mindspore::CNodePtr> *new_nodes) {
|
||||
auto clear_zero_prim = std::make_shared<Primitive>(kAtomicAddrCleanOpName);
|
||||
MS_EXCEPTION_IF_NULL(clear_zero_prim);
|
||||
auto new_value_node = NewValueNode(clear_zero_prim);
|
||||
MS_EXCEPTION_IF_NULL(new_value_node);
|
||||
std::vector<AnfNodePtr> inputs = {new_value_node};
|
||||
inputs.insert(inputs.end(), fusion_clear_inputs.begin(), fusion_clear_inputs.end());
|
||||
CNodePtr clear_zero = kernel_graph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(clear_zero);
|
||||
AbstractBasePtr abstract = std::make_shared<abstract::AbstractNone>();
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
clear_zero->set_abstract(abstract);
|
||||
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||
builder->SetKernelType(KernelType::TBE_KERNEL);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), clear_zero.get());
|
||||
AnfAlgo::SetNodeAttr(kAttrAtomicAddMemSize, MakeValue(clean_size_list), clear_zero);
|
||||
AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(stream_node.get()), clear_zero.get());
|
||||
new_nodes->insert(new_nodes->begin(), clear_zero);
|
||||
}
|
||||
|
||||
static bool IsAtomicNode(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
auto kernel_mod = AnfAlgo::GetKernelMod(kernel_node);
|
||||
|
@ -264,23 +290,23 @@ std::map<AnfNodePtr, std::vector<size_t>> GetCommunicationOpInputInfo(
|
|||
return comm_input_info_map;
|
||||
}
|
||||
|
||||
void KernelBuildPreprocess(mindspore::session::KernelGraph *kernel_graph) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
static void TbeClearZeroNodeFusion(mindspore::session::KernelGraph *const kernel_graph) {
|
||||
std::vector<CNodePtr> new_nodes;
|
||||
std::vector<size_t> clean_size_list;
|
||||
std::vector<AnfNodePtr> fusion_clear_inputs;
|
||||
std::map<AnfNodePtr, std::vector<size_t>> comm_input_info_map = GetCommunicationOpInputInfo(kernel_graph);
|
||||
for (const auto &anf_node : kernel_graph->execution_order()) {
|
||||
std::string apply_function_name = AnfAlgo::GetCNodeName(anf_node);
|
||||
bool is_comm_input = false;
|
||||
// set communication input output index attr
|
||||
if (comm_input_info_map.find(anf_node) != comm_input_info_map.end()) {
|
||||
auto indexes = comm_input_info_map[anf_node];
|
||||
AnfAlgo::SetNodeAttr(kAttrAtomicOutputIndexs, MakeValue(indexes), anf_node);
|
||||
is_comm_input = true;
|
||||
}
|
||||
|
||||
if (is_comm_input) {
|
||||
AddTbeClearZeroNode(kernel_graph, anf_node, &new_nodes);
|
||||
} else if (apply_function_name == prim::kPrimMaxPoolGrad->name() &&
|
||||
AnfAlgo::GetKernelType(anf_node) == KernelType::AKG_KERNEL) {
|
||||
if (apply_function_name == prim::kPrimMaxPoolGrad->name() &&
|
||||
AnfAlgo::GetKernelType(anf_node) == KernelType::AKG_KERNEL) {
|
||||
auto clear_zero_prim = std::make_shared<Primitive>(kClearZeroOpName);
|
||||
MS_EXCEPTION_IF_NULL(clear_zero_prim);
|
||||
auto new_value_node = NewValueNode(clear_zero_prim);
|
||||
|
@ -299,15 +325,85 @@ void KernelBuildPreprocess(mindspore::session::KernelGraph *kernel_graph) {
|
|||
// set the distinction label of clear same with anf
|
||||
AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(anf_node.get()), clear_zero.get());
|
||||
new_nodes.push_back(clear_zero);
|
||||
} else if (AnfAlgo::GetKernelType(anf_node) == KernelType::TBE_KERNEL) {
|
||||
if (IsAtomicNode(anf_node)) {
|
||||
AddTbeClearZeroNode(kernel_graph, anf_node, &new_nodes);
|
||||
} else if (is_comm_input ||
|
||||
(AnfAlgo::GetKernelType(anf_node) == KernelType::TBE_KERNEL && IsAtomicNode(anf_node))) {
|
||||
auto clean_sizes = CalCleanZerosSize(anf_node);
|
||||
if (!clean_sizes.empty()) {
|
||||
auto clean_total_num = clean_size_list.size() + clean_sizes.size();
|
||||
if (clean_total_num >= kMaxAttrMemListSize) {
|
||||
// create clean node
|
||||
auto stream_node = new_nodes.empty() ? anf_node : new_nodes.front();
|
||||
AddFusionTbeClearZeroNode(kernel_graph, stream_node, fusion_clear_inputs, clean_size_list, &new_nodes);
|
||||
clean_size_list.clear();
|
||||
fusion_clear_inputs.clear();
|
||||
}
|
||||
clean_size_list.insert(clean_size_list.end(), clean_sizes.begin(), clean_sizes.end());
|
||||
fusion_clear_inputs.emplace_back(anf_node);
|
||||
MS_LOG(DEBUG) << "fusion_clear_inputs size: " << fusion_clear_inputs.size()
|
||||
<< ", clean_size_list: " << clean_size_list.size();
|
||||
}
|
||||
}
|
||||
new_nodes.push_back(anf_node);
|
||||
new_nodes.emplace_back(anf_node);
|
||||
}
|
||||
|
||||
if (!fusion_clear_inputs.empty() && !clean_size_list.empty()) {
|
||||
// create clean node
|
||||
auto stream_node = new_nodes.front();
|
||||
AddFusionTbeClearZeroNode(kernel_graph, stream_node, fusion_clear_inputs, clean_size_list, &new_nodes);
|
||||
}
|
||||
kernel_graph->set_execution_order(new_nodes);
|
||||
}
|
||||
|
||||
void KernelBuildPreprocess(mindspore::session::KernelGraph *kernel_graph) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
static const auto enable_fusion_clear = (common::GetEnv("ENV_FUSION_CLEAR") == "1");
|
||||
bool is_dynamic_graph = kernel_graph->is_dynamic_shape();
|
||||
if (!is_dynamic_graph && enable_fusion_clear) {
|
||||
TbeClearZeroNodeFusion(kernel_graph);
|
||||
} else {
|
||||
std::vector<CNodePtr> new_nodes;
|
||||
std::map<AnfNodePtr, std::vector<size_t>> comm_input_info_map = GetCommunicationOpInputInfo(kernel_graph);
|
||||
for (const auto &anf_node : kernel_graph->execution_order()) {
|
||||
std::string apply_function_name = AnfAlgo::GetCNodeName(anf_node);
|
||||
bool is_comm_input = false;
|
||||
if (comm_input_info_map.find(anf_node) != comm_input_info_map.end()) {
|
||||
auto indexes = comm_input_info_map[anf_node];
|
||||
AnfAlgo::SetNodeAttr(kAttrAtomicOutputIndexs, MakeValue(indexes), anf_node);
|
||||
is_comm_input = true;
|
||||
}
|
||||
|
||||
if (is_comm_input) {
|
||||
AddTbeClearZeroNode(kernel_graph, anf_node, &new_nodes);
|
||||
} else if (apply_function_name == prim::kPrimMaxPoolGrad->name() &&
|
||||
AnfAlgo::GetKernelType(anf_node) == KernelType::AKG_KERNEL) {
|
||||
auto clear_zero_prim = std::make_shared<Primitive>(kClearZeroOpName);
|
||||
MS_EXCEPTION_IF_NULL(clear_zero_prim);
|
||||
auto new_value_node = NewValueNode(clear_zero_prim);
|
||||
MS_EXCEPTION_IF_NULL(new_value_node);
|
||||
std::vector<AnfNodePtr> inputs = {new_value_node};
|
||||
inputs.push_back(anf_node);
|
||||
CNodePtr clear_zero = kernel_graph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(clear_zero);
|
||||
auto kernel_info = std::make_shared<device::KernelInfo>();
|
||||
MS_EXCEPTION_IF_NULL(kernel_info);
|
||||
clear_zero->set_kernel_info(kernel_info);
|
||||
AbstractBasePtr abstract = std::make_shared<abstract::AbstractNone>();
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
AnfAlgo::SetNodeAttr("input_names", MakeValue(std::vector<std::string>({"x"})), clear_zero);
|
||||
SelectKernelInfo(clear_zero);
|
||||
// set the distinction label of clear same with anf
|
||||
AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(anf_node.get()), clear_zero.get());
|
||||
new_nodes.push_back(clear_zero);
|
||||
} else if (AnfAlgo::GetKernelType(anf_node) == KernelType::TBE_KERNEL) {
|
||||
if (IsAtomicNode(anf_node)) {
|
||||
AddTbeClearZeroNode(kernel_graph, anf_node, &new_nodes);
|
||||
}
|
||||
}
|
||||
new_nodes.push_back(anf_node);
|
||||
}
|
||||
kernel_graph->set_execution_order(new_nodes);
|
||||
}
|
||||
}
|
||||
} // namespace ascend
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -92,42 +92,48 @@ void TaskGenerator::LaunchAddrCleanAkgKernel(const CNodePtr &anf_node_ptr, Addre
|
|||
void TaskGenerator::LaunchAddrCleanKernel(const CNodePtr &anf_node_ptr, AddressPtrList *kernel_inputs) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node_ptr);
|
||||
MS_EXCEPTION_IF_NULL(kernel_inputs);
|
||||
if (anf_node_ptr->inputs().size() != 2) {
|
||||
// akg process
|
||||
if (AnfAlgo::GetKernelType(anf_node_ptr) == KernelType::AKG_KERNEL) {
|
||||
LaunchAddrCleanAkgKernel(anf_node_ptr, kernel_inputs);
|
||||
return;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(anf_node_ptr->inputs()[1]);
|
||||
auto pre_node = (anf_node_ptr->inputs()[1])->cast<CNodePtr>();
|
||||
// set clean output addr
|
||||
if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) {
|
||||
auto clean_output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicOutputIndexs);
|
||||
for (auto index : clean_output_indexs) {
|
||||
auto device_address = AnfAlgo::GetOutputAddr(pre_node, index);
|
||||
kernel::AddressPtr input = std::make_shared<kernel::Address>();
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
input->addr = device_address->ptr_;
|
||||
MS_EXCEPTION_IF_NULL(input->addr);
|
||||
input->size = device_address->size_;
|
||||
kernel_inputs->push_back(input);
|
||||
// tbe process
|
||||
auto input_tensor_num = AnfAlgo::GetInputTensorNum(anf_node_ptr);
|
||||
for (size_t i = 0; i < input_tensor_num; i++) {
|
||||
// set clean output addr
|
||||
MS_EXCEPTION_IF_NULL(anf_node_ptr->inputs()[i + 1]);
|
||||
auto pre_node = anf_node_ptr->input(i + 1)->cast<CNodePtr>();
|
||||
if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) {
|
||||
auto clean_output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicOutputIndexs);
|
||||
for (auto index : clean_output_indexs) {
|
||||
auto device_address = AnfAlgo::GetOutputAddr(pre_node, index);
|
||||
kernel::AddressPtr input = std::make_shared<kernel::Address>();
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
input->addr = device_address->ptr_;
|
||||
MS_EXCEPTION_IF_NULL(input->addr);
|
||||
input->size = device_address->size_;
|
||||
kernel_inputs->push_back(input);
|
||||
}
|
||||
MS_LOG(DEBUG) << "AtomicAddClean clean output size:" << clean_output_indexs.size();
|
||||
}
|
||||
MS_LOG(DEBUG) << "AtomicAddClean clean output size:" << clean_output_indexs.size();
|
||||
}
|
||||
// set clean workspace address
|
||||
if (AnfAlgo::HasNodeAttr(kAttrAtomicWorkspaceIndexs, pre_node)) {
|
||||
auto clean_workspace_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicWorkspaceIndexs);
|
||||
for (const auto &index : clean_workspace_indexs) {
|
||||
auto device_address = AnfAlgo::GetWorkspaceAddr(pre_node, index);
|
||||
kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
|
||||
MS_EXCEPTION_IF_NULL(workspace);
|
||||
workspace->addr = device_address->ptr_;
|
||||
MS_EXCEPTION_IF_NULL(workspace->addr);
|
||||
workspace->size = device_address->size_;
|
||||
kernel_inputs->push_back(workspace);
|
||||
// set clean workspace address
|
||||
if (AnfAlgo::HasNodeAttr(kAttrAtomicWorkspaceIndexs, pre_node)) {
|
||||
auto clean_workspace_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicWorkspaceIndexs);
|
||||
for (const auto &index : clean_workspace_indexs) {
|
||||
auto device_address = AnfAlgo::GetWorkspaceAddr(pre_node, index);
|
||||
kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
|
||||
MS_EXCEPTION_IF_NULL(workspace);
|
||||
workspace->addr = device_address->ptr_;
|
||||
MS_EXCEPTION_IF_NULL(workspace->addr);
|
||||
workspace->size = device_address->size_;
|
||||
kernel_inputs->push_back(workspace);
|
||||
}
|
||||
MS_LOG(DEBUG) << "AtomicAddClean clean workspace size:" << clean_workspace_indexs.size();
|
||||
}
|
||||
}
|
||||
auto clear_mems = AnfAlgo::GetNodeAttr<std::vector<size_t>>(anf_node_ptr, kAttrAtomicAddMemSize);
|
||||
if (kernel_inputs->size() != clear_mems.size()) {
|
||||
MS_LOG(EXCEPTION) << "AtomicAddClean kernel inputs size not equal clear memory size,kerenl_inputs size:"
|
||||
MS_LOG(EXCEPTION) << "AtomicAddClean kernel inputs size not equal clear memory size, kernel inputs size:"
|
||||
<< kernel_inputs->size() << ",clean mem size" << clear_mems.size();
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue