!20725 Support the tensor of string dtype.

Merge pull request !20725 from 张清华/opt
This commit is contained in:
i-robot 2021-07-23 02:12:48 +00:00 committed by Gitee
commit e8448f972b
5 changed files with 24 additions and 10 deletions

View File

@ -78,8 +78,11 @@ static TypeId GetDataType(const py::buffer_info &buf) {
case '?':
return TypeId::kNumberTypeBool;
}
} else if (buf.format.size() >= 2 && buf.format.back() == 'w') {
// Support the np.str_ dtype, whose buffer format is "{x}w", where {x} is a number giving the maximum length of the string items.
return TypeId::kObjectTypeString;
}
MS_LOG(WARNING) << "Unsupported DataType format " << buf.format << " item size " << buf.itemsize;
MS_LOG(WARNING) << "Unsupported DataType format " << buf.format << ", item size " << buf.itemsize;
return TypeId::kTypeUnknown;
}
@ -109,6 +112,8 @@ static std::string GetPyTypeFormat(TypeId data_type) {
return py::format_descriptor<int64_t>::format();
case TypeId::kNumberTypeBool:
return py::format_descriptor<bool>::format();
case TypeId::kObjectTypeString:
return py::format_descriptor<uint8_t>::format();
default:
MS_LOG(WARNING) << "Unsupported DataType " << data_type << ".";
return "";
@ -181,6 +186,10 @@ TensorPtr TensorPy::MakeTensor(const py::array &input, const TypePtr &type_ptr)
if (buf_type == TypeId::kTypeUnknown && data_type == TypeId::kTypeUnknown) {
MS_LOG(EXCEPTION) << "Unsupported tensor type!";
}
MS_LOG(DEBUG) << "data_type: " << data_type << ", buf_type: " << buf_type;
if (data_type == TypeId::kObjectTypeString || buf_type == TypeId::kObjectTypeString) {
return TensorPy::MakeTensorOfNumpy(input);
}
// Use buf type as data type if type_ptr not set.
if (data_type == TypeId::kTypeUnknown) {
data_type = buf_type;
@ -210,7 +219,7 @@ TensorPtr TensorPy::MakeTensor(const py::array &input, const TypePtr &type_ptr)
}
/// Creates a Tensor from a numpy array without copy
TensorPtr TensorPy::MakeTensorNoCopy(const py::array &input) {
TensorPtr TensorPy::MakeTensorOfNumpy(const py::array &input) {
// Check format.
if (!IsCContiguous(input)) {
MS_LOG(EXCEPTION) << "Array should be C contiguous.";
@ -504,7 +513,7 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
>>> data.strides
(4, 4)
)mydelimiter")
.def("from_numpy", TensorPy::MakeTensorNoCopy, R"mydelimiter(
.def("from_numpy", TensorPy::MakeTensorOfNumpy, R"mydelimiter(
Creates a Tensor from a numpy.ndarray without copy.
Arg:

View File

@ -102,7 +102,7 @@ class TensorPy {
// brief Create Tensor from a numpy array without copy.
//
// param input [py::array] Data value of the tensor.
static TensorPtr MakeTensorNoCopy(const py::array &input);
static TensorPtr MakeTensorOfNumpy(const py::array &input);
static py::array SyncAsNumpy(const Tensor &tensor);

View File

@ -91,8 +91,9 @@ class Tensor(Tensor_):
validator.check_value_type('input_data', input_data, (Tensor_, np.ndarray, list, tuple, float, int, bool),
'Tensor')
valid_dtypes = (np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64,
np.float16, np.float32, np.float64, np.bool_)
if isinstance(input_data, np.ndarray) and input_data.dtype not in valid_dtypes:
np.float16, np.float32, np.float64, np.bool_, np.str_)
if isinstance(input_data, np.ndarray) and input_data.dtype not in valid_dtypes and \
input_data.dtype.kind != 'U': # Support dtype np.str_
raise TypeError(f"For Tensor, the input_data is a numpy array, "
f"but it's data type: {input_data.dtype} is not in supported list:\
{list(i.__name__ for i in valid_dtypes)}.")
@ -100,7 +101,7 @@ class Tensor(Tensor_):
if np.array(input_data).dtype not in valid_dtypes:
raise TypeError(f"For Tensor, the input_data is {input_data} that contain unsupported element.")
if dtype is not None:
validator.check_type_name('dtype', dtype, mstype.number_type + (mstype.bool_,), "Tensor")
validator.check_type_name('dtype', dtype, mstype.number_type + (mstype.bool_, mstype.string), "Tensor")
if isinstance(input_data, np.ndarray) and (not input_data.flags['FORC']):
input_data = np.ascontiguousarray(input_data)

View File

@ -149,6 +149,10 @@ std::unique_ptr<T[]> CopyData(const ShapeVector &shape, void *const data, TypeId
auto buf = static_cast<double *>(data);
return NewData<T>(buf, size);
}
case kObjectTypeString: {
auto buf = static_cast<uint8_t *>(data);
return NewData<T>(buf, size);
}
default:
break;
}

View File

@ -320,9 +320,9 @@ def test_tensor_input_empty():
def test_tensor_input_ndarray_str():
with pytest.raises(TypeError):
inp = np.array(["88", 2, 4])
ms.Tensor(inp)
inp = np.array(["88", 0, 9])
tensor = ms.Tensor(inp)
assert str(tensor) == "['88' '0' '9']"
def test_tensor_input_ndarray_bool():