From 27721ff4eaafe4c65a11b11e0a52399269461ceb Mon Sep 17 00:00:00 2001 From: He Wei Date: Sat, 29 Aug 2020 19:05:44 +0800 Subject: [PATCH] Fix Tensor.from_numpy() returns wrong type Tensor.from_numpy() should return mindspore.Tensor not _c_expression.Tensor. --- mindspore/common/tensor.py | 5 +++++ tests/ut/python/ir/test_tensor.py | 4 ++++ 2 files changed, 9 insertions(+) diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py index c007097ca6c..5bacb5741bc 100644 --- a/mindspore/common/tensor.py +++ b/mindspore/common/tensor.py @@ -232,6 +232,11 @@ class Tensor(Tensor_): raise TypeError("virtual_flag must be bool.") self._virtual_flag = value + @staticmethod + def from_numpy(array): + """Convert numpy array to Tensor without copy data.""" + return Tensor(Tensor_.from_numpy(array)) + def asnumpy(self): """Convert tensor to numpy array.""" return Tensor_.asnumpy(self) diff --git a/tests/ut/python/ir/test_tensor.py b/tests/ut/python/ir/test_tensor.py index 09d2f2eaa8a..9ed92b418d7 100644 --- a/tests/ut/python/ir/test_tensor.py +++ b/tests/ut/python/ir/test_tensor.py @@ -480,6 +480,7 @@ def test_tensor_operation(): def test_tensor_from_numpy(): a = np.ones((2, 3)) t = ms.Tensor.from_numpy(a) + assert isinstance(t, ms.Tensor) assert np.all(t.asnumpy() == 1) # 't' and 'a' share same data. a[1] = 2 @@ -489,3 +490,6 @@ def test_tensor_from_numpy(): del a assert np.all(t.asnumpy()[0] == 1) assert np.all(t.asnumpy()[1] == 2) + with pytest.raises(TypeError): + # incorrect input. + t = ms.Tensor.from_numpy([1, 2, 3])