forked from mindspore-Ecosystem/mindspore
!5534 Fix Tensor.from_numpy() returns wrong type
Merge pull request !5534 from hewei/fix_tensor_from_numpy
This commit is contained in:
commit
4d963d96f4
|
@ -232,6 +232,11 @@ class Tensor(Tensor_):
|
||||||
raise TypeError("virtual_flag must be bool.")
|
raise TypeError("virtual_flag must be bool.")
|
||||||
self._virtual_flag = value
|
self._virtual_flag = value
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_numpy(array):
|
||||||
|
"""Convert numpy array to Tensor without copy data."""
|
||||||
|
return Tensor(Tensor_.from_numpy(array))
|
||||||
|
|
||||||
def asnumpy(self):
|
def asnumpy(self):
|
||||||
"""Convert tensor to numpy array."""
|
"""Convert tensor to numpy array."""
|
||||||
return Tensor_.asnumpy(self)
|
return Tensor_.asnumpy(self)
|
||||||
|
|
|
@ -480,6 +480,7 @@ def test_tensor_operation():
|
||||||
def test_tensor_from_numpy():
|
def test_tensor_from_numpy():
|
||||||
a = np.ones((2, 3))
|
a = np.ones((2, 3))
|
||||||
t = ms.Tensor.from_numpy(a)
|
t = ms.Tensor.from_numpy(a)
|
||||||
|
assert isinstance(t, ms.Tensor)
|
||||||
assert np.all(t.asnumpy() == 1)
|
assert np.all(t.asnumpy() == 1)
|
||||||
# 't' and 'a' share same data.
|
# 't' and 'a' share same data.
|
||||||
a[1] = 2
|
a[1] = 2
|
||||||
|
@ -489,3 +490,6 @@ def test_tensor_from_numpy():
|
||||||
del a
|
del a
|
||||||
assert np.all(t.asnumpy()[0] == 1)
|
assert np.all(t.asnumpy()[0] == 1)
|
||||||
assert np.all(t.asnumpy()[1] == 2)
|
assert np.all(t.asnumpy()[1] == 2)
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
# incorrect input.
|
||||||
|
t = ms.Tensor.from_numpy([1, 2, 3])
|
||||||
|
|
Loading…
Reference in New Issue