!69215 [MD] Fix issues on schema of TFRecordDataset

Merge pull request !69215 from xiaotianci/fix_unknown_shape_schema
This commit is contained in:
i-robot 2024-05-10 09:13:23 +00:00 committed by Gitee
commit d45d702a45
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 138 additions and 34 deletions

View File

@ -95,10 +95,14 @@ class ColDescriptor {
/// \return The column's shape
TensorShape Shape() const;
/// \brief getter function
/// \return TRUE if the column has an assigned fixed shape.
/// \brief Check if the column has a shape assigned in the schema.
/// \return Whether a shape was set (the shape may still contain unknown dimensions).
bool HasShape() const { return tensor_shape_ != nullptr; }
/// \brief Check if the column has a fully known shape.
/// \return Whether a shape was set and TensorShape::known() reports true for it,
///     i.e. no dimension is left unknown.
bool HasKnownShape() const { return HasShape() && Shape().known(); }
/// \brief getter function
/// \return The column's tensor implementation type
TensorImpl GetTensorImpl() const { return tensor_impl_; }

View File

@ -649,8 +649,10 @@ Status ParseSingleKnownShapeColumn(const parsed::Feature &feature, std::shared_p
if (bytes_list.size() != num_elements) {
return ReportUnexpectedDataShape(feature_name);
}
RETURN_IF_NOT_OK(Tensor::CreateFromVector(bytes_list, TensorShape{static_cast<dsize_t>(num_elements)},
DataType(DataType::DE_STRING), column_tensor));
TensorShape string_tensor_shape = TensorShape::CreateUnknownRankShape();
RETURN_IF_NOT_OK(column_descriptor.MaterializeTensorShape(num_elements, &string_tensor_shape));
RETURN_IF_NOT_OK(
Tensor::CreateFromVector(bytes_list, string_tensor_shape, DataType(DataType::DE_STRING), column_tensor));
} else {
// load string or bytes as uint8 tensor
RETURN_IF_NOT_OK(
@ -750,7 +752,7 @@ Status ParseExampleOp::ParseSingleExample(const TensorRow &raw_bytes, TensorRow
for (int32_t column_index = 0; column_index < data_schema_.NumColumns(); ++column_index) {
const ColDescriptor &column_descriptor = data_schema_.Column(column_index);
if (column_descriptor.HasShape()) {
if (column_descriptor.HasKnownShape()) {
if (!column_descriptor.Type().IsString()) {
DataType type;
if (column_descriptor.Type().IsInt() || column_descriptor.Type().IsBool()) {
@ -806,9 +808,10 @@ Status ParseExampleOp::ParseSingleExample(const TensorRow &raw_bytes, TensorRow
bool type_cast_flag = false;
if (example_dtype != column_descriptor.Type()) {
const std::string msg =
"The data type loaded from the example does not match the predefined type in schema, the actual type: " +
example_dtype.ToString() + ", but the predefined type: " + column_descriptor.Type().ToString();
if (!example_dtype.IsString()) {
"The data type loaded from the example for feature name: " + column_descriptor.Name() +
" does not match the predefined type in schema, the actual type: " + example_dtype.ToString() +
", but the predefined type: " + column_descriptor.Type().ToString();
if (!example_dtype.IsString() && !column_descriptor.Type().IsString()) {
MS_LOG(INFO) << msg << ". This will cause a type cast.";
type_cast_flag = true;
} else if (column_descriptor.Type().value() != DataType::DE_UINT8) {
@ -817,7 +820,7 @@ Status ParseExampleOp::ParseSingleExample(const TensorRow &raw_bytes, TensorRow
}
}
if (column_descriptor.HasShape()) {
if (column_descriptor.HasKnownShape()) {
RETURN_IF_NOT_OK(ParseSingleKnownShapeColumn(feature, &(*parsed_row)[column_index], feature_name,
column_descriptor, example_dtype));
} else { // if variable length
@ -833,10 +836,9 @@ Status ParseExampleOp::ParseSingleExample(const TensorRow &raw_bytes, TensorRow
}
for (int32_t column_index = 0; column_index < data_schema_.NumColumns(); ++column_index) {
if (!feature_already_seen[column_index]) {
RETURN_STATUS_UNEXPECTED("Feature name: " + data_schema_.Column(column_index).Name() +
" is required in schema but could not be found in tfrecord file.");
}
CHECK_FAIL_RETURN_UNEXPECTED(feature_already_seen[column_index],
"Feature name: " + data_schema_.Column(column_index).Name() +
" is required in schema but could not be found in tfrecord file.");
}
parsed_row->setPath(file_paths);
@ -1007,7 +1009,7 @@ Status MergeDenseVarLenMiniBatches(const std::vector<std::vector<VarLenTensorBuf
TensorRow *parsed_row, int32_t column_index, const DataSchema &data_schema,
dsize_t batch_size) {
const ColDescriptor &column_descriptor = data_schema.Column(column_index);
if (column_descriptor.HasShape()) {
if (column_descriptor.HasKnownShape()) {
return Status::OK();
}
std::shared_ptr<Tensor> column_tensor;
@ -1035,7 +1037,7 @@ Status ParseExampleOp::ParallelParseExample(const TensorRow &raw_bytes, TensorRo
std::unordered_map<int32_t, std::vector<std::string>> string_column_map;
for (int32_t column_index = 0; column_index < data_schema_.NumColumns(); ++column_index) {
const ColDescriptor &column_descriptor = data_schema_.Column(column_index);
if (column_descriptor.HasShape()) {
if (column_descriptor.HasKnownShape()) {
if (!column_descriptor.Type().IsString()) {
auto column_shape = column_descriptor.Shape().InsertDim(0, batch_size);
DataType type;
@ -1139,8 +1141,9 @@ Status ParseSerializedKnownShapeColumn(const parsed::Feature &feature, TensorRow
std::shared_ptr<Tensor> &column_tensor = (*parsed_row)[column_index];
if (example_dtype != column_descriptor.Type()) {
const std::string msg =
"The data type loaded from the example does not match the predefined type in schema, the actual type: " +
example_dtype.ToString() + ", but the predefined type: " + column_descriptor.Type().ToString();
"The data type loaded from the example for feature name: " + column_descriptor.Name() +
" does not match the predefined type in schema, the actual type: " + example_dtype.ToString() +
", but the predefined type: " + column_descriptor.Type().ToString();
if (example_dtype == column_tensor->type()) {
// if the actual data type is the same as the pre-allocated tensor,
// we can first read it into the tensor, then cast to the type specified by the schema
@ -1233,9 +1236,10 @@ Status ParseSerializedVarLenColumn(const parsed::Feature &feature, VarLenTensorB
bool type_cast_flag = false;
if (example_dtype != column_descriptor.Type()) {
const std::string msg =
"The data type loaded from the example does not match the predefined type in schema, the actual type: " +
example_dtype.ToString() + ", but the predefined type: " + column_descriptor.Type().ToString();
if (!example_dtype.IsString()) {
"The data type loaded from the example for feature name: " + column_descriptor.Name() +
" does not match the predefined type in schema, the actual type: " + example_dtype.ToString() +
", but the predefined type: " + column_descriptor.Type().ToString();
if (!example_dtype.IsString() && !column_descriptor.Type().IsString()) {
MS_LOG(INFO) << msg << ". This will cause a type cast.";
type_cast_flag = true;
} else if (column_descriptor.Type().value() != DataType::DE_UINT8) {
@ -1355,7 +1359,7 @@ Status ParseExampleOp::ParseSerializedExample(const std::string &example_bytes,
feature_already_seen[column_index] = true;
const ColDescriptor &column_descriptor = data_schema_.Column(column_index);
if (column_descriptor.HasShape()) {
if (column_descriptor.HasKnownShape()) {
RETURN_IF_NOT_OK(ParseSerializedKnownShapeColumn(feature, parsed_row, string_column_map, column_index,
tensor_index, feature_name, column_descriptor, example_dtype));
} else { // if variable length

View File

@ -231,9 +231,46 @@ def test_tfrecord_with_full_schema(do_batch, load_type):
assert dataset.output_shapes() == expected_shape
@pytest.mark.parametrize("do_batch", (True, False))
def test_tfrecord_with_empty_or_unknown_shape_schema(do_batch):
    """
    Feature: TFRecordDataset
    Description: Test TFRecordDataset with schema while the shape is empty or unknown
    Expectation: The data can be processed as expected
    """
    # (name, dtype, shape) triples: [] is an empty (scalar-like) shape,
    # -1 marks a dimension whose size is inferred when the data is loaded.
    column_specs = [
        ("col_1d", mstype.int64, [-1]),
        ("col_2d", mstype.int64, [2, -1]),
        ("col_3d", mstype.int64, [2, 2, -1]),
        ("col_binary", mstype.string, []),
        ("col_float", mstype.float32, [-1]),
        ("col_sint16", mstype.int64, []),
        ("col_sint32", mstype.int64, [-1]),
        ("col_sint64", mstype.int64, []),
        ("col_sint8", mstype.int64, [-1]),
    ]
    schema = ds.Schema()
    for column_name, column_type, column_shape in column_specs:
        schema.add_column(column_name, de_type=column_type, shape=column_shape)

    dataset = ds.TFRecordDataset(FILES, schema=schema, shuffle=ds.Shuffle.FILES)
    if do_batch:
        dataset = dataset.batch(2)

    # Iteration must succeed and yield exactly get_dataset_size() rows.
    num_rows = sum(1 for _ in dataset)
    assert dataset.get_dataset_size() == num_rows

    assert dataset.get_col_names() == [spec[0] for spec in column_specs]
    assert dataset.output_types() == [np.int64, np.int64, np.int64, np.str_, np.float32, np.int64, np.int64, np.int64,
                                      np.int64]
    if do_batch:
        expected_shape = [[2, 2], [2, 2, 2], [2, 2, 2, 2], [2], [2, 1], [2], [2, 1], [2], [2, 1]]
    else:
        expected_shape = [[2], [2, 2], [2, 2, 2], [], [1], [], [1], [], [1]]
    assert dataset.output_shapes() == expected_shape
@pytest.mark.parametrize("do_batch", (True, False))
@pytest.mark.parametrize("load_type", ("uint8", "string"))
def test_tfrecord_with_unknown_shape_schema(do_batch, load_type):
def test_tfrecord_with_no_shape_schema(do_batch, load_type):
"""
Feature: TFRecordDataset
Description: Test TFRecordDataset with schema missing feature shape
@ -301,10 +338,10 @@ def test_tfrecord_with_wrong_shape_schema(do_batch):
@pytest.mark.parametrize("do_batch", (True, False))
def test_tfrecord_with_wrong_type_schema(do_batch):
def test_tfrecord_with_mismatch_type_schema(do_batch):
"""
Feature: TFRecordDataset
Description: Test TFRecordDataset with schema containing wrong feature type
Description: Test TFRecordDataset with schema containing a mismatched feature type, which will cause a type cast
Expectation: The output columns can be converted to the specified type
"""
schema = ds.Schema()
@ -337,6 +374,63 @@ def test_tfrecord_with_wrong_type_schema(do_batch):
assert dataset.output_shapes() == expected_shape
@pytest.mark.parametrize("do_batch", (True, False))
@pytest.mark.parametrize("with_shape", (True, False))
def test_tfrecord_with_wrong_type_schema(do_batch, with_shape):
    """
    Feature: TFRecordDataset
    Description: Test TFRecordDataset with schema containing wrong feature type
        that can not be cast to the actual type in the tfrecord file
    Expectation: Raise a RuntimeError as expected
    """

    def check_type_mismatch_raises(column_name, de_type, shape):
        # Build a schema that maps `column_name` to an incompatible type
        # (optionally with an explicit shape) and verify that iterating
        # the dataset raises the expected RuntimeError.
        schema = ds.Schema()
        if with_shape:
            schema.add_column(column_name, de_type=de_type, shape=shape)
        else:
            schema.add_column(column_name, de_type=de_type)
        dataset = ds.TFRecordDataset(FILES, schema=schema, shuffle=ds.Shuffle.FILES)
        if do_batch:
            dataset = dataset.batch(2)
        with pytest.raises(RuntimeError) as e:
            for _ in dataset:
                pass
        assert ("The data type loaded from the example for feature name: " + column_name +
                " does not match the predefined type in schema") in str(e.value)

    # feature of type int64 can not be cast to string
    check_type_mismatch_raises("col_1d", mstype.string, [2])
    # feature of type float32 can not be cast to string
    check_type_mismatch_raises("col_float", mstype.string, [1])
    # feature of type string can not be cast to int64
    check_type_mismatch_raises("col_binary", mstype.int64, [1])
@pytest.mark.parametrize("do_batch", (True, False))
def test_tfrecord_with_column_list(do_batch):
"""
@ -1348,9 +1442,11 @@ if __name__ == '__main__':
test_tfrecord_read_files()
test_tfrecord_multi_files()
test_tfrecord_with_full_schema(True, "string")
test_tfrecord_with_unknown_shape_schema(True, "string")
test_tfrecord_with_empty_or_unknown_shape_schema(True)
test_tfrecord_with_no_shape_schema(True, "string")
test_tfrecord_with_wrong_shape_schema(True)
test_tfrecord_with_wrong_type_schema(True)
test_tfrecord_with_mismatch_type_schema(True)
test_tfrecord_with_wrong_type_schema(True, True)
test_tfrecord_with_column_list(True)
test_tfrecord_without_schema_and_column_list(True)
test_tfrecord_with_both_schema_and_column_list(True, "string")

View File

@ -156,15 +156,15 @@ def test_tfrecord1():
"""
s = ds.Schema()
s.add_column("line", "string", [])
s.add_column("words", "string", [2, 2])
s.add_column("words", "string", [-1])
s.add_column("chinese", "string", [])
data = ds.TFRecordDataset("../data/dataset/testTextTFRecord/text.tfrecord", shuffle=False, schema=s)
for i, d in enumerate(data.create_dict_iterator(num_epochs=1, output_numpy=True)):
assert d["line"].shape == (1,)
assert d["line"].shape == line[i].shape
assert d["words"].shape == words[i].shape
assert d["chinese"].shape == (1,)
assert d["chinese"].shape == chinese[i].shape
np.testing.assert_array_equal(line[i], d["line"])
np.testing.assert_array_equal(words[i], d["words"])
np.testing.assert_array_equal(chinese[i], d["chinese"])
@ -195,17 +195,17 @@ def test_tfrecord3():
"""
s = ds.Schema()
s.add_column("line", mstype.string, [])
s.add_column("words", mstype.string, [2, 2])
s.add_column("words", mstype.string, [-1, 2])
s.add_column("chinese", mstype.string, [])
data = ds.TFRecordDataset("../data/dataset/testTextTFRecord/text.tfrecord", shuffle=False, schema=s)
for i, d in enumerate(data.create_dict_iterator(num_epochs=1, output_numpy=True)):
assert d["line"].shape == (1,)
assert d["words"].shape == words[i].shape
assert d["chinese"].shape == (1,)
assert d["line"].shape == line[i].shape
assert d["words"].shape == words[i].reshape([2, 2]).shape
assert d["chinese"].shape == chinese[i].shape
np.testing.assert_array_equal(line[i], d["line"])
np.testing.assert_array_equal(words[i], d["words"])
np.testing.assert_array_equal(words[i].reshape([2, 2]), d["words"])
np.testing.assert_array_equal(chinese[i], d["chinese"])