forked from mindspore-Ecosystem/mindspore
!5534 Fix Tensor.from_numpy() returns wrong type
Merge pull request !5534 from hewei/fix_tensor_from_numpy
This commit is contained in:
commit
4d963d96f4
|
@ -232,6 +232,11 @@ class Tensor(Tensor_):
|
|||
raise TypeError("virtual_flag must be bool.")
|
||||
self._virtual_flag = value
|
||||
|
||||
@staticmethod
|
||||
def from_numpy(array):
|
||||
"""Convert numpy array to Tensor without copy data."""
|
||||
return Tensor(Tensor_.from_numpy(array))
|
||||
|
||||
def asnumpy(self):
|
||||
"""Convert tensor to numpy array."""
|
||||
return Tensor_.asnumpy(self)
|
||||
|
|
|
@ -480,6 +480,7 @@ def test_tensor_operation():
|
|||
def test_tensor_from_numpy():
|
||||
a = np.ones((2, 3))
|
||||
t = ms.Tensor.from_numpy(a)
|
||||
assert isinstance(t, ms.Tensor)
|
||||
assert np.all(t.asnumpy() == 1)
|
||||
# 't' and 'a' share same data.
|
||||
a[1] = 2
|
||||
|
@ -489,3 +490,6 @@ def test_tensor_from_numpy():
|
|||
del a
|
||||
assert np.all(t.asnumpy()[0] == 1)
|
||||
assert np.all(t.asnumpy()[1] == 2)
|
||||
with pytest.raises(TypeError):
|
||||
# incorrect input.
|
||||
t = ms.Tensor.from_numpy([1, 2, 3])
|
||||
|
|
Loading…
Reference in New Issue