diff --git a/mindspore/ccsrc/ir/tensor.cc b/mindspore/ccsrc/ir/tensor.cc index dccdbb65b8c..c06ba2a8203 100644 --- a/mindspore/ccsrc/ir/tensor.cc +++ b/mindspore/ccsrc/ir/tensor.cc @@ -30,8 +30,6 @@ namespace mindspore { namespace tensor { -using Bool = unsigned char; - static std::string MakeId() { // Use atomic to make id generator thread safe. static std::atomic last_id{1}; @@ -50,10 +48,7 @@ template std::vector CopyData(const std::vector &shape, void *data, TypeId data_type) { const size_t count = SizeOf(shape); switch (data_type) { - case kNumberTypeBool: { - auto buf = static_cast(data); - return std::vector(buf, buf + count); - } + case kNumberTypeBool: case kNumberTypeUInt8: { auto buf = static_cast(data); return std::vector(buf, buf + count); @@ -104,14 +99,6 @@ std::vector CopyData(const std::vector &shape, void *data, TypeId data_t MS_LOG(EXCEPTION) << "Cannot construct Tensor because of unsupported data type: " << data_type << "."; } -// Convert to bool is not allowed. -template <> -std::vector CopyData(const std::vector &shape, void *data, TypeId data_type) { - MS_LOG(EXCEPTION) << "Cannot convert from " << TypeIdLabel(data_type) << " to " << TypeIdLabel(kNumberTypeBool) - << "."; - return {}; -} - template std::vector CopyData(const std::vector &shape, void *data, size_t data_len) { size_t size = SizeOf(shape); @@ -192,10 +179,6 @@ template TensorDataPtr MakeTensorData(TypeId data_type, const std::vector &shape, Args... args) { switch (data_type) { case kNumberTypeBool: - // std::vector is a specialization of std::vector, - // it may use single bit instead of sizeof(bool) bytes, - // so we use std::vector for bool tensors. - return std::make_shared>(shape, args...); case kNumberTypeUInt8: return std::make_shared>(shape, args...); case kNumberTypeInt8: diff --git a/tests/ut/python/ir/test_tensor.py b/tests/ut/python/ir/test_tensor.py index 357187cdddf..72100b2715e 100644 --- a/tests/ut/python/ir/test_tensor.py +++ b/tests/ut/python/ir/test_tensor.py @@ -430,10 +430,20 @@ def test_tensor_dtype_np_int64(): def test_tensor_dtype_fp32_to_bool(): - with pytest.raises(RuntimeError): - input_ = np.random.randn(2, 3, 4, 5).astype(np.float32) - input_ = ms.Tensor(input_) - _ = ms.Tensor(input_, dtype=ms.bool_) + input_ = np.random.randn(2, 3, 4, 5).astype(np.float32) + input_ = ms.Tensor(input_) + t = ms.Tensor(input_, dtype=ms.bool_) + assert isinstance(t, ms.Tensor) + assert t.shape == (2, 3, 4, 5) + assert t.dtype == ms.bool_ + + +def test_tensor_dtype_fp64_to_uint8(): + array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float64) + t = ms.Tensor(array, ms.uint8) + assert isinstance(t, ms.Tensor) + assert t.shape == (2, 3) + assert t.dtype == ms.uint8 def test_tensor_operation():