forked from mindspore-Ecosystem/mindspore
!20725 Support the tensor of string dtype.
Merge pull request !20725 from 张清华/opt
This commit is contained in:
commit
e8448f972b
|
@ -78,8 +78,11 @@ static TypeId GetDataType(const py::buffer_info &buf) {
|
|||
case '?':
|
||||
return TypeId::kNumberTypeBool;
|
||||
}
|
||||
} else if (buf.format.size() >= 2 && buf.format.back() == 'w') {
|
||||
// Support np.str_ dtype, format: {x}w. {x} is a number that means the maximum length of the string items.
|
||||
return TypeId::kObjectTypeString;
|
||||
}
|
||||
MS_LOG(WARNING) << "Unsupported DataType format " << buf.format << " item size " << buf.itemsize;
|
||||
MS_LOG(WARNING) << "Unsupported DataType format " << buf.format << ", item size " << buf.itemsize;
|
||||
return TypeId::kTypeUnknown;
|
||||
}
|
||||
|
||||
|
@ -109,6 +112,8 @@ static std::string GetPyTypeFormat(TypeId data_type) {
|
|||
return py::format_descriptor<int64_t>::format();
|
||||
case TypeId::kNumberTypeBool:
|
||||
return py::format_descriptor<bool>::format();
|
||||
case TypeId::kObjectTypeString:
|
||||
return py::format_descriptor<uint8_t>::format();
|
||||
default:
|
||||
MS_LOG(WARNING) << "Unsupported DataType " << data_type << ".";
|
||||
return "";
|
||||
|
@ -181,6 +186,10 @@ TensorPtr TensorPy::MakeTensor(const py::array &input, const TypePtr &type_ptr)
|
|||
if (buf_type == TypeId::kTypeUnknown && data_type == TypeId::kTypeUnknown) {
|
||||
MS_LOG(EXCEPTION) << "Unsupported tensor type!";
|
||||
}
|
||||
MS_LOG(DEBUG) << "data_type: " << data_type << ", buf_type: " << buf_type;
|
||||
if (data_type == TypeId::kObjectTypeString || buf_type == TypeId::kObjectTypeString) {
|
||||
return TensorPy::MakeTensorOfNumpy(input);
|
||||
}
|
||||
// Use buf type as data type if type_ptr not set.
|
||||
if (data_type == TypeId::kTypeUnknown) {
|
||||
data_type = buf_type;
|
||||
|
@ -210,7 +219,7 @@ TensorPtr TensorPy::MakeTensor(const py::array &input, const TypePtr &type_ptr)
|
|||
}
|
||||
|
||||
/// Creates a Tensor from a numpy array without copy
|
||||
TensorPtr TensorPy::MakeTensorNoCopy(const py::array &input) {
|
||||
TensorPtr TensorPy::MakeTensorOfNumpy(const py::array &input) {
|
||||
// Check format.
|
||||
if (!IsCContiguous(input)) {
|
||||
MS_LOG(EXCEPTION) << "Array should be C contiguous.";
|
||||
|
@ -504,7 +513,7 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
|
|||
>>> data.strides
|
||||
(4, 4)
|
||||
)mydelimiter")
|
||||
.def("from_numpy", TensorPy::MakeTensorNoCopy, R"mydelimiter(
|
||||
.def("from_numpy", TensorPy::MakeTensorOfNumpy, R"mydelimiter(
|
||||
Creates a Tensor from a numpy.ndarray without copy.
|
||||
|
||||
Arg:
|
||||
|
|
|
@ -102,7 +102,7 @@ class TensorPy {
|
|||
// brief Create Tensor from a numpy array without copy.
|
||||
//
|
||||
// param input [py::array] Data value of the tensor.
|
||||
static TensorPtr MakeTensorNoCopy(const py::array &input);
|
||||
static TensorPtr MakeTensorOfNumpy(const py::array &input);
|
||||
|
||||
static py::array SyncAsNumpy(const Tensor &tensor);
|
||||
|
||||
|
|
|
@ -91,8 +91,9 @@ class Tensor(Tensor_):
|
|||
validator.check_value_type('input_data', input_data, (Tensor_, np.ndarray, list, tuple, float, int, bool),
|
||||
'Tensor')
|
||||
valid_dtypes = (np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64,
|
||||
np.float16, np.float32, np.float64, np.bool_)
|
||||
if isinstance(input_data, np.ndarray) and input_data.dtype not in valid_dtypes:
|
||||
np.float16, np.float32, np.float64, np.bool_, np.str_)
|
||||
if isinstance(input_data, np.ndarray) and input_data.dtype not in valid_dtypes and \
|
||||
input_data.dtype.kind != 'U': # Support dtype np.str_
|
||||
raise TypeError(f"For Tensor, the input_data is a numpy array, "
|
||||
f"but it's data type: {input_data.dtype} is not in supported list:\
|
||||
{list(i.__name__ for i in valid_dtypes)}.")
|
||||
|
@ -100,7 +101,7 @@ class Tensor(Tensor_):
|
|||
if np.array(input_data).dtype not in valid_dtypes:
|
||||
raise TypeError(f"For Tensor, the input_data is {input_data} that contain unsupported element.")
|
||||
if dtype is not None:
|
||||
validator.check_type_name('dtype', dtype, mstype.number_type + (mstype.bool_,), "Tensor")
|
||||
validator.check_type_name('dtype', dtype, mstype.number_type + (mstype.bool_, mstype.string), "Tensor")
|
||||
|
||||
if isinstance(input_data, np.ndarray) and (not input_data.flags['FORC']):
|
||||
input_data = np.ascontiguousarray(input_data)
|
||||
|
|
|
@ -149,6 +149,10 @@ std::unique_ptr<T[]> CopyData(const ShapeVector &shape, void *const data, TypeId
|
|||
auto buf = static_cast<double *>(data);
|
||||
return NewData<T>(buf, size);
|
||||
}
|
||||
case kObjectTypeString: {
|
||||
auto buf = static_cast<uint8_t *>(data);
|
||||
return NewData<T>(buf, size);
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
|
|
@ -320,9 +320,9 @@ def test_tensor_input_empty():
|
|||
|
||||
|
||||
def test_tensor_input_ndarray_str():
|
||||
with pytest.raises(TypeError):
|
||||
inp = np.array(["88", 2, 4])
|
||||
ms.Tensor(inp)
|
||||
inp = np.array(["88", 0, 9])
|
||||
tensor = ms.Tensor(inp)
|
||||
assert str(tensor) == "['88' '0' '9']"
|
||||
|
||||
|
||||
def test_tensor_input_ndarray_bool():
|
||||
|
|
Loading…
Reference in New Issue