forked from mindspore-Ecosystem/mindspore
Fix bool type tensor problem
Problem: creating a 'uint8' type tensor from float data raises a RuntimeError. Fix: use the same internal data type for 'uint8' and 'bool' tensors, and allow converting float data to a bool tensor.
parent 622b97f3b6
commit a0623e1528
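
As context for the diff below, a minimal usage sketch of the behavior this commit enables, derived from the updated tests at the end of this page (it assumes only the public ms.Tensor API):

import numpy as np
import mindspore as ms

# Previously this raised RuntimeError; after the fix a float32 tensor can be re-wrapped as bool.
data = ms.Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
t_bool = ms.Tensor(data, dtype=ms.bool_)
assert t_bool.dtype == ms.bool_

# Float64 data can likewise be converted to a uint8 tensor.
t_u8 = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float64), ms.uint8)
assert t_u8.dtype == ms.uint8
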
@@ -30,8 +30,6 @@
 namespace mindspore {
 namespace tensor {
 
-using Bool = unsigned char;
-
 static std::string MakeId() {
   // Use atomic to make id generator thread safe.
   static std::atomic<uint64_t> last_id{1};
@@ -50,10 +48,7 @@ template <typename T>
 std::vector<T> CopyData(const std::vector<int> &shape, void *data, TypeId data_type) {
   const size_t count = SizeOf(shape);
   switch (data_type) {
-    case kNumberTypeBool: {
-      auto buf = static_cast<Bool *>(data);
-      return std::vector<T>(buf, buf + count);
-    }
+    case kNumberTypeBool:
     case kNumberTypeUInt8: {
       auto buf = static_cast<uint8_t *>(data);
       return std::vector<T>(buf, buf + count);
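
A minimal numpy model of the merged case above (an illustration under the assumption that a bool tensor stores one byte per element, not the actual MindSpore internals): the bool buffer is read as uint8 and copied element-wise into the destination type, exactly like uint8 data.

import numpy as np

bool_buf = np.array([1, 0, 1, 1], dtype=np.uint8)  # one-byte backing store of a bool tensor
to_float = bool_buf.astype(np.float32)             # element-wise copy into the target type T
print(to_float)                                    # [1. 0. 1. 1.]
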
@@ -104,14 +99,6 @@ std::vector<T> CopyData(const std::vector<int> &shape, void *data, TypeId data_type) {
   MS_LOG(EXCEPTION) << "Cannot construct Tensor because of unsupported data type: " << data_type << ".";
 }
 
-// Convert to bool is not allowed.
-template <>
-std::vector<Bool> CopyData<Bool>(const std::vector<int> &shape, void *data, TypeId data_type) {
-  MS_LOG(EXCEPTION) << "Cannot convert from " << TypeIdLabel(data_type) << " to " << TypeIdLabel(kNumberTypeBool)
-                    << ".";
-  return {};
-}
-
 template <typename T>
 std::vector<T> CopyData(const std::vector<int> &shape, void *data, size_t data_len) {
   size_t size = SizeOf(shape);
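
With this specialization removed, converting source data into a bool (or uint8) tensor goes through the same element-wise copy as every other dtype. A conceptual numpy sketch, under the assumption that the generic path performs a plain numeric cast per element rather than an explicit non-zero test:

import numpy as np

src = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float64)
as_uint8 = src.astype(np.uint8)  # element-wise cast, like std::vector<uint8_t>(buf, buf + count)
print(as_uint8)                  # [[1 2 3]
                                 #  [4 5 6]]
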
@@ -192,10 +179,6 @@ template <typename... Args>
 TensorDataPtr MakeTensorData(TypeId data_type, const std::vector<int> &shape, Args... args) {
   switch (data_type) {
     case kNumberTypeBool:
-      // std::vector<bool> is a specialization of std::vector,
-      // it may use single bit instead of sizeof(bool) bytes,
-      // so we use std::vector<Bool> for bool tensors.
-      return std::make_shared<TensorDataImpl<Bool>>(shape, args...);
     case kNumberTypeUInt8:
       return std::make_shared<TensorDataImpl<uint8_t>>(shape, args...);
     case kNumberTypeInt8:
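
Both dtypes therefore allocate the same one-byte storage. A toy dispatch table for illustration only (names taken from the diff; this is not the MindSpore implementation):

import numpy as np

storage_type = {
    "kNumberTypeBool": np.uint8,   # bool tensors now share the uint8 backing type
    "kNumberTypeUInt8": np.uint8,
    "kNumberTypeInt8": np.int8,
}
assert storage_type["kNumberTypeBool"] is storage_type["kNumberTypeUInt8"]
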
@@ -430,10 +430,20 @@ def test_tensor_dtype_np_int64():
 
 
 def test_tensor_dtype_fp32_to_bool():
-    with pytest.raises(RuntimeError):
-        input_ = np.random.randn(2, 3, 4, 5).astype(np.float32)
-        input_ = ms.Tensor(input_)
-        _ = ms.Tensor(input_, dtype=ms.bool_)
+    input_ = np.random.randn(2, 3, 4, 5).astype(np.float32)
+    input_ = ms.Tensor(input_)
+    t = ms.Tensor(input_, dtype=ms.bool_)
+    assert isinstance(t, ms.Tensor)
+    assert t.shape == (2, 3, 4, 5)
+    assert t.dtype == ms.bool_
+
+
+def test_tensor_dtype_fp64_to_uint8():
+    array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float64)
+    t = ms.Tensor(array, ms.uint8)
+    assert isinstance(t, ms.Tensor)
+    assert t.shape == (2, 3)
+    assert t.dtype == ms.uint8
 
 
 def test_tensor_operation():