forked from mindspore-Ecosystem/mindspore
NonTask Split Process
This commit is contained in:
parent
d429ea3f7d
commit
b465f21d90
|
@ -86,6 +86,7 @@ bool Somas::InitSomasTensors(const session::KernelGraph *graph) {
|
|||
IndependentNodeOutputProcess(graph);
|
||||
SummaryInputProcess(graph);
|
||||
RefNodeProcess(graph);
|
||||
NonTaskSplitProcess(graph);
|
||||
UnReuseNodeProcess(graph);
|
||||
GenContiguousList(graph);
|
||||
GetNextOutputProcess(graph);
|
||||
|
@ -535,6 +536,27 @@ void Somas::RefNodeProcess(const session::KernelGraph *graph) {
|
|||
MS_LOG(INFO) << "Special Tensor total size: RefNode: input " << total_input_size << " output " << total_output_size;
|
||||
}
|
||||
|
||||
// Registers a ref-node constraint for every Split kernel that carries the kAttrNonTask attribute.
//
// For each such kernel in the graph's execution order, the first input tensor is marked
// kRefNodeInput, every output tensor is marked kRefNodeOutput, and the ids
// {input0, output0, output1, ...} are appended as one entry to ref_node_constraints_.
//
// NOTE(review): the intent appears to be that the outputs of a NonTask Split alias the
// memory of its first input (no task is generated for the op) — confirm against the
// task generator that consumes these constraints.
//
// Raises (via MS_EXCEPTION_IF_NULL / MS_LOG(EXCEPTION)) when graph is null, when a matching
// kernel has no Somas node, when that node has no input tensor, or when a tensor is null.
void Somas::NonTaskSplitProcess(const session::KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  // The execution order is only iterated, so bind it by reference instead of copying the vector.
  const auto &kernel_cnodes = graph->execution_order();
  for (const auto &kernel : kernel_cnodes) {
    const auto op_name = AnfAlgo::GetCNodeName(kernel);
    if (op_name != kSplitOpName || !AnfAlgo::HasNodeAttr(kAttrNonTask, kernel)) {
      continue;
    }

    // Use find() rather than operator[]: operator[] would silently insert a null entry for an
    // unknown kernel and the dereference below would then crash.
    const auto node_iter = nodes_map_.find(kernel.get());
    if (node_iter == nodes_map_.end()) {
      MS_LOG(EXCEPTION) << "Kernel " << kernel->fullname_with_scope()
                        << " has no Somas node, can not do split non_task process.";
    }
    const auto &node = node_iter->second;
    MS_EXCEPTION_IF_NULL(node);

    // Indexing input_tensors_[0] on an empty vector is undefined behaviour; fail loudly instead.
    if (node->input_tensors_.empty()) {
      MS_LOG(EXCEPTION) << op_name << " has no input tensor, can not do split non_task process.";
    }

    std::vector<size_t> refnode_input_output;
    refnode_input_output.reserve(node->output_tensors_.size() + 1);

    const auto &input_tensor = node->input_tensors_[0];
    MS_EXCEPTION_IF_NULL(input_tensor);
    input_tensor->type_ = kRefNodeInput;
    refnode_input_output.push_back(input_tensor->GetId());

    for (const auto &output_tensor : node->output_tensors_) {
      MS_EXCEPTION_IF_NULL(output_tensor);
      output_tensor->type_ = kRefNodeOutput;
      refnode_input_output.push_back(output_tensor->GetId());
    }
    ref_node_constraints_.push_back(refnode_input_output);
  }
}
|
||||
|
||||
void Somas::UnReuseNodeProcess(const session::KernelGraph *graph) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
vector<string> full_name_list = {};
|
||||
|
|
|
@ -114,6 +114,7 @@ class Somas {
|
|||
void IndependentNodeOutputProcess(const session::KernelGraph *graph);
|
||||
void SummaryInputProcess(const session::KernelGraph *graph);
|
||||
void RefNodeProcess(const session::KernelGraph *graph);
|
||||
void NonTaskSplitProcess(const session::KernelGraph *graph);
|
||||
void UnReuseNodeProcess(const session::KernelGraph *graph);
|
||||
SomasTensorPtr CreateGapTensor(size_t gap_tensor_id);
|
||||
void GenContiguousList(const session::KernelGraph *graph);
|
||||
|
|
|
@ -136,7 +136,12 @@ bool TaskGenerator::LaunchKernel(const CNodePtr &anf_node_ptr, uint32_t stream_i
|
|||
MS_EXCEPTION_IF_NULL(kernel_mod);
|
||||
kernel_mod->set_kernel_name(anf_node_ptr->fullname_with_scope());
|
||||
auto op_name = AnfAlgo::GetCNodeName(anf_node_ptr);
|
||||
if (AnfAlgo::GetCNodeName(anf_node_ptr) != kAtomicAddrCleanOpName) {
|
||||
if (op_name == kSplitOpName && AnfAlgo::HasNodeAttr(kAttrNonTask, anf_node_ptr)) {
|
||||
MS_LOG(INFO) << "Skip task generation for NnTask op " << anf_node_ptr->fullname_with_scope();
|
||||
return true;
|
||||
}
|
||||
|
||||
if (op_name != kAtomicAddrCleanOpName) {
|
||||
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_node_ptr); ++i) {
|
||||
if (op_name == kDynamicRNNOpName && i == 3) {
|
||||
continue;
|
||||
|
@ -153,6 +158,21 @@ bool TaskGenerator::LaunchKernel(const CNodePtr &anf_node_ptr, uint32_t stream_i
|
|||
AddressPtr input = std::make_shared<Address>();
|
||||
input->addr = device_address->ptr_;
|
||||
input->size = device_address->size_;
|
||||
|
||||
auto prenode_with_index = AnfAlgo::GetPrevNodeOutput(anf_node_ptr, i);
|
||||
if (AnfAlgo::GetCNodeName(prenode_with_index.first) == kSplitOpName &&
|
||||
AnfAlgo::HasNodeAttr(kAttrNonTask, prenode_with_index.first->cast<CNodePtr>())) {
|
||||
// use memory offset to implement NonTask Type Split op
|
||||
// when op A -> split(NonTask) -> op B, op B's input addr is split's input0's addr + offset
|
||||
// offset is split's output index * split's output size
|
||||
auto split_input0_device_address = AnfAlgo::GetPrevNodeOutputAddr(prenode_with_index.first, 0);
|
||||
input->addr =
|
||||
static_cast<uint8_t *>(split_input0_device_address->ptr_) + (prenode_with_index.second * input->size);
|
||||
MS_LOG(INFO) << "Change " << anf_node_ptr->fullname_with_scope() << "'s input " << i << " address to "
|
||||
<< split_input0_device_address->ptr_ << " + "
|
||||
<< "prenode_with_index.second * input->size";
|
||||
}
|
||||
|
||||
kernel_inputs.push_back(input);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue