!2292 gpu fix all nop node graph execute
Merge pull request !2292 from limingqi107/master
commit 2e002ab64c
@@ -228,7 +228,7 @@ void GPUKernelRuntime::ClearKernelOutputAddress(const session::KernelGraph *graph
       continue;
     }

-    auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i);
+    auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false);
     if (device_address->ptr_) {
       mem_manager_->FreeMemFromMemPool(device_address);
     }
@@ -289,7 +289,7 @@ bool GPUKernelRuntime::AddMemSwapTask(const AnfNodePtr &kernel) {
   for (auto &mem_swap_info : mem_swap_info_list) {
     auto &kernel_exec_info = mem_swap_manager_->SearchKernelExecutionInfo(mem_swap_info.kernel_);
     const HostAddress &host_address = kernel_exec_info.host_addrs_[mem_swap_info.output_idx_];
-    auto device_address = AnfAlgo::GetMutableOutputAddr(mem_swap_info.kernel_, mem_swap_info.output_idx_);
+    auto device_address = AnfAlgo::GetMutableOutputAddr(mem_swap_info.kernel_, mem_swap_info.output_idx_, false);

     if (mem_swap_info.swap_kind_ == SwapKind::kDeviceToHost) {
       mem_swap_manager_->AddMemSwapTask(SwapKind::kDeviceToHost, device_address, host_address);
@@ -379,7 +379,8 @@ bool GPUKernelRuntime::AllocKernelInputDynamicRes(const mindspore::AnfNodePtr &kernel
   MS_EXCEPTION_IF_NULL(kernel_inputs);
   MS_EXCEPTION_IF_NULL(mem_swap_manager_);
   for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
-    auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i);
+    // Graph may be all nop nodes and not remove nop node, so this can not skip nop node.
+    auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, false);
     MS_EXCEPTION_IF_NULL(device_address);
     if (mem_swap_manager_->trigger_swap()) {
       while (auto device_address_swap_in = mem_swap_manager_->UpdateSwapQueue(SwapKind::kHostToDevice)) {
@@ -437,7 +438,7 @@ bool GPUKernelRuntime::AllocKernelOutputDynamicRes(const mindspore::kernel::KernelMod
   }
   auto output_sizes = kernel_mod.GetOutputSizeList();
   for (size_t i = 0; i < output_sizes.size(); ++i) {
-    auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i);
+    auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false);
     MS_EXCEPTION_IF_NULL(device_address);
     if (device_address->ptr_ == nullptr && !AttemptMallocMem(device_address, output_sizes[i])) {
       return false;
@@ -495,7 +496,7 @@ void GPUKernelRuntime::AllocCommunicationOpInputDynamicRes(const mindspore::AnfNodePtr
   std::vector<size_t> size_list;
   DeviceAddressPtrList addr_list;
   for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
-    auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i);
+    auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, false);
     MS_EXCEPTION_IF_NULL(device_address);
     if (device_address->ptr_ == nullptr) {
       is_need_alloc_memory = true;
@@ -520,7 +521,7 @@ void GPUKernelRuntime::AllocCommunicationOpOutputDynamicRes(const mindspore::AnfNodePtr
   MS_EXCEPTION_IF_NULL(kernel_mod);
   auto output_sizes = kernel_mod->GetOutputSizeList();
   for (size_t i = 0; i < output_sizes.size(); ++i) {
-    auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i);
+    auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false);
     MS_EXCEPTION_IF_NULL(device_address);
     if (device_address->ptr_ == nullptr) {
       is_need_alloc_memory = true;
@@ -578,7 +579,7 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel,
     MS_LOG(EXCEPTION) << "Check dynamic reference count failed.";
   }
   if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) {
-    auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i);
+    auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, false);
     mem_manager_->FreeMemFromMemPool(device_address);
     device_address->set_status(DeviceAddressStatus::kInDevice);
   }
@@ -590,7 +591,7 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel,
       continue;
     }
     if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) {
-      auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i);
+      auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false);
       mem_manager_->FreeMemFromMemPool(device_address);
       device_address->set_status(DeviceAddressStatus::kInDevice);
     }
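All the runtime call sites above follow one pattern: they now pass visit_nop_node = false, so an address query on a nop node answers with that node's own device address instead of forwarding to the node behind it. A nop node performs no computation of its own and normally borrows its input's memory; when a graph consists entirely of nop nodes, nothing was removed and there is no real kernel to forward to, so the GPU runtime must allocate and free on the nop nodes themselves. A toy model of that lookup, for illustration only (Node, device_ptr, and GetAddr are invented here; the committed logic lives in AnfRuntimeAlgorithm further down):

#include <memory>
#include <vector>

// Invented toy graph node: for a nop node, inputs[0] is the primitive and
// inputs[1] the single real input, mirroring the inputs().size() == 2 check
// visible in this diff.
struct Node {
  bool is_nop = false;
  void *device_ptr = nullptr;
  std::vector<std::shared_ptr<Node>> inputs;
};

// Toy analogue of GetOutputAddr: with visit_nop_node = true (the old and
// default behaviour) a nop node forwards the query to its real input; with
// false the nop node is treated as the owner of its own device memory.
void *GetAddr(const std::shared_ptr<Node> &node, bool visit_nop_node = true) {
  if (node->is_nop && visit_nop_node && node->inputs.size() == 2) {
    return GetAddr(node->inputs[1], visit_nop_node);
  }
  return node->device_ptr;
}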
@@ -228,7 +228,8 @@ KernelRefCountPtr MemReuseUtil::GetKernelInputRef(const CNodePtr &kernel, size_t
                       << AnfAlgo::GetInputTensorNum(kernel);
   }
   auto input_node = kernel->input(input_idx + 1);
-  auto kernel_input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true);
+  // Graph may be all nop nodes and not remove nop node, so this can not skip nop node.
+  auto kernel_input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
   if (IsPrimitive(kernel_input.first, prim::kPrimMakeTuple)) {
     MS_LOG(EXCEPTION) << "Input node [" << input_node->DebugString() << "]'s input " << input_idx << " is MakeTuple";
   }
@@ -269,7 +270,8 @@ void MemReuseUtil::SetKernelDefInputs() {
     if (ref_ptr != nullptr) {
       // set the inputs of this kernel_def
       auto input_node = AnfAlgo::GetInputNode(kernel, i);
-      auto input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true);
+      // Graph may be all nop nodes and not remove nop node, so this can not skip nop node.
+      auto input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
       if (IsPrimitive(input.first, prim::kPrimMakeTuple)) {
         MS_LOG(EXCEPTION) << "Input node [" << input_node->DebugString() << "]'s input " << i << " is MakeTuple";
       }
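The memory-reuse pass applies the same rule through AnfAlgo::VisitKernelWithReturnType, whose third argument decides whether the visitor looks through nop nodes to the kernel behind them. A toy sketch of that traversal, hypothetical and not the MindSpore implementation:

#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

// Same toy node model as in the earlier sketch.
struct Node {
  bool is_nop = false;
  std::vector<std::shared_ptr<Node>> inputs;  // {primitive, real input} for a nop node
};

// Toy analogue of VisitKernelWithReturnType(node, index, visit_nop_node):
// true keeps the old behaviour and resolves to the real kernel behind a chain
// of nop nodes; false stops at the nop node itself, which is what reference
// counting needs when the graph keeps all of its nop nodes.
std::pair<std::shared_ptr<Node>, size_t> VisitKernel(const std::shared_ptr<Node> &node,
                                                     size_t index, bool visit_nop_node) {
  if (node->is_nop && visit_nop_node && node->inputs.size() == 2) {
    return VisitKernel(node->inputs[1], 0, visit_nop_node);
  }
  return {node, index};
}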
@@ -544,9 +544,10 @@ TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputDeviceDataType(const AnfNodePtr &anf_node
 }

 // get output device addr of anf_node
-const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node, size_t output_idx) {
+const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node, size_t output_idx,
+                                                        bool visit_nop_node) {
   MS_EXCEPTION_IF_NULL(node);
-  if (opt::IsNopNode(node)) {
+  if (opt::IsNopNode(node) && visit_nop_node) {
     auto cnode = node->cast<CNodePtr>();
     MS_EXCEPTION_IF_NULL(cnode);
     if (cnode->inputs().size() == 2) {
@@ -565,9 +566,10 @@ const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node,
   return addr;
 }

-DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx) {
+DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx,
+                                                           bool visit_nop_node) {
   MS_EXCEPTION_IF_NULL(node);
-  if (opt::IsNopNode(node)) {
+  if (opt::IsNopNode(node) && visit_nop_node) {
     auto cnode = node->cast<CNodePtr>();
     MS_EXCEPTION_IF_NULL(cnode);
     if (cnode->inputs().size() == 2) {
@@ -598,14 +600,16 @@ bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_idx
   return kernel_info->OutputAddrExist(output_idx);
 }

-const DeviceAddress *AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(const AnfNodePtr &anf_node, size_t input_idx) {
+const DeviceAddress *AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(const AnfNodePtr &anf_node, size_t input_idx,
+                                                                bool visit_nop_node) {
   KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
-  return AnfRuntimeAlgorithm::GetOutputAddr(kernel_with_index.first, kernel_with_index.second);
+  return AnfRuntimeAlgorithm::GetOutputAddr(kernel_with_index.first, kernel_with_index.second, visit_nop_node);
 }

-DeviceAddressPtr AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx) {
+DeviceAddressPtr AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx,
+                                                                   bool visit_nop_node) {
   KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
-  return AnfRuntimeAlgorithm::GetMutableOutputAddr(kernel_with_index.first, kernel_with_index.second);
+  return AnfRuntimeAlgorithm::GetMutableOutputAddr(kernel_with_index.first, kernel_with_index.second, visit_nop_node);
 }

 // set output device addr of anf_node
@@ -121,14 +121,16 @@ class AnfRuntimeAlgorithm {
   // get output select data type from prev node,input_index is the input index of current node related to prev node
   static TypeId GetPrevNodeOutputDeviceDataType(const AnfNodePtr &node, size_t input_idx);
   // get output device addr of anf_node
-  static const DeviceAddress *GetOutputAddr(const AnfNodePtr &node, size_t output_idx);
+  static const DeviceAddress *GetOutputAddr(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node = true);
   // get mutable output device addr of anf_node
-  static DeviceAddressPtr GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx);
+  static DeviceAddressPtr GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node = true);
   // check whether output addr is exist or not
   static bool OutputAddrExist(const AnfNodePtr &node, size_t output_idx);
   // get address from prev node,input_index is the input index of current node related to prev node
-  static const DeviceAddress *GetPrevNodeOutputAddr(const AnfNodePtr &node, size_t input_idx);
-  static DeviceAddressPtr GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx);
+  static const DeviceAddress *GetPrevNodeOutputAddr(const AnfNodePtr &node, size_t input_idx,
+                                                    bool visit_nop_node = true);
+  static DeviceAddressPtr GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx,
+                                                       bool visit_nop_node = true);
   // set output device addr of anf_node
   static void SetOutputAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node);
   // set workspace device addr of anf_node
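Because visit_nop_node defaults to true in all four declarations, existing callers compile unchanged and keep the old nop-skipping behaviour; only the call sites that opt out are touched. A hypothetical call site, for illustration only (kernel stands for any AnfNodePtr in scope):

// Default: a query on a nop node is answered by the real node behind it.
auto forwarded = AnfAlgo::GetMutableOutputAddr(kernel, 0);
// Opt-out used by the GPU runtime: the nop node keeps its own address,
// which is required when the whole graph is nop nodes.
auto own = AnfAlgo::GetMutableOutputAddr(kernel, 0, false);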
@@ -31,6 +31,49 @@ class NetFlatten(nn.Cell):
         return self.flatten(x)


+class NetAllFlatten(nn.Cell):
+    def __init__(self):
+        super(NetAllFlatten, self).__init__()
+        self.flatten = P.Flatten()
+
+    def construct(self, x):
+        loop_count = 4
+        while loop_count > 0:
+            x = self.flatten(x)
+            loop_count = loop_count - 1
+        return x
+
+
+class NetFirstFlatten(nn.Cell):
+    def __init__(self):
+        super(NetFirstFlatten, self).__init__()
+        self.flatten = P.Flatten()
+        self.relu = P.ReLU()
+
+    def construct(self, x):
+        loop_count = 4
+        while loop_count > 0:
+            x = self.flatten(x)
+            loop_count = loop_count - 1
+        x = self.relu(x)
+        return x
+
+
+class NetLastFlatten(nn.Cell):
+    def __init__(self):
+        super(NetLastFlatten, self).__init__()
+        self.flatten = P.Flatten()
+        self.relu = P.ReLU()
+
+    def construct(self, x):
+        loop_count = 4
+        x = self.relu(x)
+        while loop_count > 0:
+            x = self.flatten(x)
+            loop_count = loop_count - 1
+        return x
+
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
@@ -46,3 +89,55 @@ def test_flatten():
     flatten = NetFlatten()
     output = flatten(x)
     assert (output.asnumpy() == expect).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_all_flatten():
+    x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]).astype(np.float32))
+    expect = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]).astype(np.float32)
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    flatten = NetAllFlatten()
+    output = flatten(x)
+    assert (output.asnumpy() == expect).all()
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    flatten = NetAllFlatten()
+    output = flatten(x)
+    assert (output.asnumpy() == expect).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_first_flatten():
+    x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]).astype(np.float32))
+    expect = np.array([[0, 0.3, 3.6], [0.4, 0.5, 0]]).astype(np.float32)
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    flatten = NetFirstFlatten()
+    output = flatten(x)
+    assert (output.asnumpy() == expect).all()
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    flatten = NetFirstFlatten()
+    output = flatten(x)
+    assert (output.asnumpy() == expect).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_last_flatten():
+    x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]).astype(np.float32))
+    expect = np.array([[0, 0.3, 3.6], [0.4, 0.5, 0]]).astype(np.float32)
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    flatten = NetLastFlatten()
+    output = flatten(x)
+    assert (output.asnumpy() == expect).all()
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    flatten = NetLastFlatten()
+    output = flatten(x)
+    assert (output.asnumpy() == expect).all()
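Taken together, the three new networks cover the scenarios the commit title names: NetAllFlatten builds a graph that is nothing but nop nodes (Flatten is one of the ops treated as a nop node), NetFirstFlatten puts the nop nodes before the only real kernel, and NetLastFlatten puts them after it. Each case runs in both PyNative and graph mode; for the two ReLU variants the expected output is simply the input with negative entries clamped to zero.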