Fix a bug in RefNode output handling

caifubi 2021-12-25 15:50:09 +08:00
parent d5c63e4381
commit f1856c83ac
2 changed files with 45 additions and 3 deletions


@@ -180,9 +180,14 @@ void UpdateRefNodeOutputDeviceAddress(const KernelGraphPtr &graph) {
       // Just compare shared_ptr of two DeviceAddress.
       // The ptr of DeviceAddress may still be nullptr.
       if (input_addr != ref_node_output_addr) {
-        // The output of RefNode will not be used by subsequent Node.
-        // So update the DeviceAddress of the kernel directly instead of updating the ptr of the DeviceAddress.
-        AnfAlgo::SetOutputAddr(input_addr, output_index, ref_node.get());
+        // AnfAlgo::SetOutputAddr cannot update the device_address of the frontend Tensor
+        // if the output of the RefNode is used by subsequent nodes, because the frontend Tensor
+        // is copied from the backend Tensor and the shared_ptr of the Tensor is different.
+        if (input_addr->GetMutablePtr() == nullptr) {
+          AnfAlgo::SetOutputAddr(input_addr, output_index, ref_node.get());
+        } else {
+          ref_node_output_addr->set_ptr(input_addr->GetMutablePtr());
+        }
       }
     }
   }
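
The new comments above describe the case this commit targets: in PyNative mode the output of a RefNode (for example Assign) can be consumed by later nodes, and the frontend Tensor the user sees is a copy whose device memory must end up matching what the backend actually wrote. Below is a minimal Python sketch of that scenario, not part of the commit; the net, parameter, and variable names are illustrative, and it assumes a MindSpore build with any supported backend.

import numpy as np
from mindspore import ops, Tensor, context
from mindspore.common.parameter import Parameter
from mindspore.nn import Cell

context.set_context(mode=context.PYNATIVE_MODE)


class AssignThenUseNet(Cell):
    """Illustrative net: the Assign (RefNode) output feeds a subsequent Add node."""
    def __init__(self, init_value):
        super(AssignThenUseNet, self).__init__()
        self.assign = ops.Assign()
        self.param = Parameter(init_value, name="param")
        self.add = ops.Add()

    def construct(self, x):
        out = self.assign(self.param, x)  # RefNode: writes x into self.param in place
        return self.add(out, out)         # the RefNode output is used by a later node


net = AssignThenUseNet(Tensor(np.zeros((5, 5)).astype(np.float32)))
x = Tensor(np.ones((5, 5)).astype(np.float32))
print(net(x))  # expected: a 5x5 tensor filled with 2.0 once the parameter holds x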


@@ -0,0 +1,37 @@
import pytest
import numpy as np
from mindspore import ops, Tensor, context
from mindspore.common.parameter import Parameter
from mindspore.nn import Cell


class AssignNet(Cell):
    def __init__(self, input_variable):
        super(AssignNet, self).__init__()
        self.op = ops.Assign()
        self.input_data = input_variable

    def construct(self, input_x):
        return self.op(self.input_data, input_x)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_assign_as_output():
    """
    Feature: PyNative MindRT
    Description: Test PyNative MindRT RefNode.
    Expectation: No exception.
    """
    np.random.seed(0)
    input_np = np.random.randn(5, 5).astype(dtype=np.int32)
    context.set_context(mode=context.PYNATIVE_MODE)
    input_variable = Parameter(Tensor(np.random.randn(5, 5).astype(dtype=np.float32)))
    input_x = Tensor(input_np)
    net = AssignNet(input_variable)
    out = net(input_x)
    # Compare element-wise; calling .all() on each side separately only compares two booleans.
    assert np.allclose(input_np, out.asnumpy().astype(np.int32))
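
Not part of the commit, but a small standalone follow-up one could run next to this test: since Assign updates the Parameter in place, the parameter's value should match the assigned data when the dtypes agree. The names param, new_value, and out2 below are illustrative.

# Hypothetical extra check (not in the commit): with matching dtypes, the
# Parameter itself should hold the assigned values after Assign runs,
# because the RefNode writes to it in place.
param = Parameter(Tensor(np.zeros((5, 5)).astype(np.float32)))
new_value = Tensor(np.ones((5, 5)).astype(np.float32))
out2 = AssignNet(param)(new_value)
assert np.allclose(param.asnumpy(), new_value.asnumpy())
assert np.allclose(out2.asnumpy(), new_value.asnumpy())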