!5534 Fix Tensor.from_numpy() returns wrong type

Merge pull request !5534 from hewei/fix_tensor_from_numpy
This commit is contained in:
mindspore-ci-bot 2020-08-31 09:15:48 +08:00 committed by Gitee
commit 4d963d96f4
2 changed files with 9 additions and 0 deletions

View File

@ -232,6 +232,11 @@ class Tensor(Tensor_):
raise TypeError("virtual_flag must be bool.") raise TypeError("virtual_flag must be bool.")
self._virtual_flag = value self._virtual_flag = value
@staticmethod
def from_numpy(array):
"""Convert numpy array to Tensor without copy data."""
return Tensor(Tensor_.from_numpy(array))
def asnumpy(self): def asnumpy(self):
"""Convert tensor to numpy array.""" """Convert tensor to numpy array."""
return Tensor_.asnumpy(self) return Tensor_.asnumpy(self)

View File

@ -480,6 +480,7 @@ def test_tensor_operation():
def test_tensor_from_numpy(): def test_tensor_from_numpy():
a = np.ones((2, 3)) a = np.ones((2, 3))
t = ms.Tensor.from_numpy(a) t = ms.Tensor.from_numpy(a)
assert isinstance(t, ms.Tensor)
assert np.all(t.asnumpy() == 1) assert np.all(t.asnumpy() == 1)
# 't' and 'a' share same data. # 't' and 'a' share same data.
a[1] = 2 a[1] = 2
@ -489,3 +490,6 @@ def test_tensor_from_numpy():
del a del a
assert np.all(t.asnumpy()[0] == 1) assert np.all(t.asnumpy()[0] == 1)
assert np.all(t.asnumpy()[1] == 2) assert np.all(t.asnumpy()[1] == 2)
with pytest.raises(TypeError):
# incorrect input.
t = ms.Tensor.from_numpy([1, 2, 3])