forked from mindspore-Ecosystem/mindspore
fix bug of RefNode output
This commit is contained in:
parent
d5c63e4381
commit
f1856c83ac
|
@ -180,9 +180,14 @@ void UpdateRefNodeOutputDeviceAddress(const KernelGraphPtr &graph) {
|
||||||
// Just compare shared_ptr of two DeviceAddress.
|
// Just compare shared_ptr of two DeviceAddress.
|
||||||
// The ptr of DeviceAddress may still be nullptr.
|
// The ptr of DeviceAddress may still be nullptr.
|
||||||
if (input_addr != ref_node_output_addr) {
|
if (input_addr != ref_node_output_addr) {
|
||||||
// The output of RefNode will not be used by subsequent Node.
|
// AnfAlgo::SetOutputAddr cannot update the device_address of frontend Tensor
|
||||||
// So update the DeviceAddress of the kernel directly instead of updating the ptr of the DeviceAddress.
|
// if the output of RefNode is used by subsequent nodes.
|
||||||
AnfAlgo::SetOutputAddr(input_addr, output_index, ref_node.get());
|
// Because the frontend Tensor is copied from backend Tensor and the shared_ptr of Tensor is different.
|
||||||
|
if (input_addr->GetMutablePtr() == nullptr) {
|
||||||
|
AnfAlgo::SetOutputAddr(input_addr, output_index, ref_node.get());
|
||||||
|
} else {
|
||||||
|
ref_node_output_addr->set_ptr(input_addr->GetMutablePtr());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,37 @@
|
||||||
|
import pytest
|
||||||
|
import numpy as np
|
||||||
|
from mindspore import ops, Tensor, context
|
||||||
|
from mindspore.common.parameter import Parameter
|
||||||
|
from mindspore.nn import Cell
|
||||||
|
|
||||||
|
|
||||||
|
class AssignNet(Cell):
|
||||||
|
def __init__(self, input_variable):
|
||||||
|
super(AssignNet, self).__init__()
|
||||||
|
self.op = ops.Assign()
|
||||||
|
self.input_data = input_variable
|
||||||
|
|
||||||
|
def construct(self, input_x):
|
||||||
|
return self.op(self.input_data, input_x)
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_assign_as_output():
|
||||||
|
"""
|
||||||
|
Feature: PyNative MindRT
|
||||||
|
Description: Test PyNative MindRT RefNode.
|
||||||
|
Expectation: No exception.
|
||||||
|
"""
|
||||||
|
|
||||||
|
np.random.seed(0)
|
||||||
|
input_np = np.random.randn(5, 5).astype(dtype=np.int32)
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE)
|
||||||
|
input_variable = Parameter(Tensor(np.random.randn(5, 5).astype(dtype=np.float32)))
|
||||||
|
input_x = Tensor(input_np)
|
||||||
|
net = AssignNet(input_variable)
|
||||||
|
out = net(input_x)
|
||||||
|
assert input_np.all() == out.asnumpy().astype(dtype=np.int32).all()
|
Loading…
Reference in New Issue