!69215 [MD] Fix issues on schema of TFRecordDataset
Merge pull request !69215 from xiaotianci/fix_unknown_shape_schema
This commit is contained in:
commit
d45d702a45
|
@ -95,10 +95,14 @@ class ColDescriptor {
|
|||
/// \return The column's shape
|
||||
TensorShape Shape() const;
|
||||
|
||||
/// \brief getter function
|
||||
/// \return TF if the column has an assigned fixed shape.
|
||||
/// \brief Check if the column has a shape.
|
||||
/// \return Whether the column has a shape.
|
||||
bool HasShape() const {
  // A column carries a shape exactly when the shape pointer has been set.
  const bool shape_is_set = !(tensor_shape_ == nullptr);
  return shape_is_set;
}
|
||||
|
||||
/// \brief Check if the column has a known shape.
|
||||
/// \return Whether the column has a known shape.
|
||||
bool HasKnownShape() const {
  // Without any shape there is nothing to query, so Shape() is not called.
  if (!HasShape()) {
    return false;
  }
  return Shape().known();
}
|
||||
|
||||
/// \brief getter function
|
||||
/// \return The column's tensor implementation type
|
||||
TensorImpl GetTensorImpl() const {
  // Plain accessor: returns the stored tensor implementation type by value.
  return this->tensor_impl_;
}
|
||||
|
|
|
@ -649,8 +649,10 @@ Status ParseSingleKnownShapeColumn(const parsed::Feature &feature, std::shared_p
|
|||
if (bytes_list.size() != num_elements) {
|
||||
return ReportUnexpectedDataShape(feature_name);
|
||||
}
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromVector(bytes_list, TensorShape{static_cast<dsize_t>(num_elements)},
|
||||
DataType(DataType::DE_STRING), column_tensor));
|
||||
TensorShape string_tensor_shape = TensorShape::CreateUnknownRankShape();
|
||||
RETURN_IF_NOT_OK(column_descriptor.MaterializeTensorShape(num_elements, &string_tensor_shape));
|
||||
RETURN_IF_NOT_OK(
|
||||
Tensor::CreateFromVector(bytes_list, string_tensor_shape, DataType(DataType::DE_STRING), column_tensor));
|
||||
} else {
|
||||
// load string or bytes as uint8 tensor
|
||||
RETURN_IF_NOT_OK(
|
||||
|
@ -750,7 +752,7 @@ Status ParseExampleOp::ParseSingleExample(const TensorRow &raw_bytes, TensorRow
|
|||
|
||||
for (int32_t column_index = 0; column_index < data_schema_.NumColumns(); ++column_index) {
|
||||
const ColDescriptor &column_descriptor = data_schema_.Column(column_index);
|
||||
if (column_descriptor.HasShape()) {
|
||||
if (column_descriptor.HasKnownShape()) {
|
||||
if (!column_descriptor.Type().IsString()) {
|
||||
DataType type;
|
||||
if (column_descriptor.Type().IsInt() || column_descriptor.Type().IsBool()) {
|
||||
|
@ -806,9 +808,10 @@ Status ParseExampleOp::ParseSingleExample(const TensorRow &raw_bytes, TensorRow
|
|||
bool type_cast_flag = false;
|
||||
if (example_dtype != column_descriptor.Type()) {
|
||||
const std::string msg =
|
||||
"The data type loaded from the example does not match the predefined type in schema, the actual type: " +
|
||||
example_dtype.ToString() + ", but the predefined type: " + column_descriptor.Type().ToString();
|
||||
if (!example_dtype.IsString()) {
|
||||
"The data type loaded from the example for feature name: " + column_descriptor.Name() +
|
||||
" does not match the predefined type in schema, the actual type: " + example_dtype.ToString() +
|
||||
", but the predefined type: " + column_descriptor.Type().ToString();
|
||||
if (!example_dtype.IsString() && !column_descriptor.Type().IsString()) {
|
||||
MS_LOG(INFO) << msg << ". This will cause a type cast.";
|
||||
type_cast_flag = true;
|
||||
} else if (column_descriptor.Type().value() != DataType::DE_UINT8) {
|
||||
|
@ -817,7 +820,7 @@ Status ParseExampleOp::ParseSingleExample(const TensorRow &raw_bytes, TensorRow
|
|||
}
|
||||
}
|
||||
|
||||
if (column_descriptor.HasShape()) {
|
||||
if (column_descriptor.HasKnownShape()) {
|
||||
RETURN_IF_NOT_OK(ParseSingleKnownShapeColumn(feature, &(*parsed_row)[column_index], feature_name,
|
||||
column_descriptor, example_dtype));
|
||||
} else { // if variable length
|
||||
|
@ -833,10 +836,9 @@ Status ParseExampleOp::ParseSingleExample(const TensorRow &raw_bytes, TensorRow
|
|||
}
|
||||
|
||||
for (int32_t column_index = 0; column_index < data_schema_.NumColumns(); ++column_index) {
|
||||
if (!feature_already_seen[column_index]) {
|
||||
RETURN_STATUS_UNEXPECTED("Feature name: " + data_schema_.Column(column_index).Name() +
|
||||
" is required in schema but could not be found in tfrecord file.");
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(feature_already_seen[column_index],
|
||||
"Feature name: " + data_schema_.Column(column_index).Name() +
|
||||
" is required in schema but could not be found in tfrecord file.");
|
||||
}
|
||||
|
||||
parsed_row->setPath(file_paths);
|
||||
|
@ -1007,7 +1009,7 @@ Status MergeDenseVarLenMiniBatches(const std::vector<std::vector<VarLenTensorBuf
|
|||
TensorRow *parsed_row, int32_t column_index, const DataSchema &data_schema,
|
||||
dsize_t batch_size) {
|
||||
const ColDescriptor &column_descriptor = data_schema.Column(column_index);
|
||||
if (column_descriptor.HasShape()) {
|
||||
if (column_descriptor.HasKnownShape()) {
|
||||
return Status::OK();
|
||||
}
|
||||
std::shared_ptr<Tensor> column_tensor;
|
||||
|
@ -1035,7 +1037,7 @@ Status ParseExampleOp::ParallelParseExample(const TensorRow &raw_bytes, TensorRo
|
|||
std::unordered_map<int32_t, std::vector<std::string>> string_column_map;
|
||||
for (int32_t column_index = 0; column_index < data_schema_.NumColumns(); ++column_index) {
|
||||
const ColDescriptor &column_descriptor = data_schema_.Column(column_index);
|
||||
if (column_descriptor.HasShape()) {
|
||||
if (column_descriptor.HasKnownShape()) {
|
||||
if (!column_descriptor.Type().IsString()) {
|
||||
auto column_shape = column_descriptor.Shape().InsertDim(0, batch_size);
|
||||
DataType type;
|
||||
|
@ -1139,8 +1141,9 @@ Status ParseSerializedKnownShapeColumn(const parsed::Feature &feature, TensorRow
|
|||
std::shared_ptr<Tensor> &column_tensor = (*parsed_row)[column_index];
|
||||
if (example_dtype != column_descriptor.Type()) {
|
||||
const std::string msg =
|
||||
"The data type loaded from the example does not match the predefined type in schema, the actual type: " +
|
||||
example_dtype.ToString() + ", but the predefined type: " + column_descriptor.Type().ToString();
|
||||
"The data type loaded from the example for feature name: " + column_descriptor.Name() +
|
||||
" does not match the predefined type in schema, the actual type: " + example_dtype.ToString() +
|
||||
", but the predefined type: " + column_descriptor.Type().ToString();
|
||||
if (example_dtype == column_tensor->type()) {
|
||||
// if the actual data type is the same as the pre-allocated tensor,
|
||||
// we can first read it into the tensor, then cast to the type specified by the schema
|
||||
|
@ -1233,9 +1236,10 @@ Status ParseSerializedVarLenColumn(const parsed::Feature &feature, VarLenTensorB
|
|||
bool type_cast_flag = false;
|
||||
if (example_dtype != column_descriptor.Type()) {
|
||||
const std::string msg =
|
||||
"The data type loaded from the example does not match the predefined type in schema, the actual type: " +
|
||||
example_dtype.ToString() + ", but the predefined type: " + column_descriptor.Type().ToString();
|
||||
if (!example_dtype.IsString()) {
|
||||
"The data type loaded from the example for feature name: " + column_descriptor.Name() +
|
||||
" does not match the predefined type in schema, the actual type: " + example_dtype.ToString() +
|
||||
", but the predefined type: " + column_descriptor.Type().ToString();
|
||||
if (!example_dtype.IsString() && !column_descriptor.Type().IsString()) {
|
||||
MS_LOG(INFO) << msg << ". This will cause a type cast.";
|
||||
type_cast_flag = true;
|
||||
} else if (column_descriptor.Type().value() != DataType::DE_UINT8) {
|
||||
|
@ -1355,7 +1359,7 @@ Status ParseExampleOp::ParseSerializedExample(const std::string &example_bytes,
|
|||
feature_already_seen[column_index] = true;
|
||||
|
||||
const ColDescriptor &column_descriptor = data_schema_.Column(column_index);
|
||||
if (column_descriptor.HasShape()) {
|
||||
if (column_descriptor.HasKnownShape()) {
|
||||
RETURN_IF_NOT_OK(ParseSerializedKnownShapeColumn(feature, parsed_row, string_column_map, column_index,
|
||||
tensor_index, feature_name, column_descriptor, example_dtype));
|
||||
} else { // if variable length
|
||||
|
|
|
@ -231,9 +231,46 @@ def test_tfrecord_with_full_schema(do_batch, load_type):
|
|||
assert dataset.output_shapes() == expected_shape
|
||||
|
||||
|
||||
@pytest.mark.parametrize("do_batch", (True, False))
|
||||
def test_tfrecord_with_empty_or_unknown_shape_schema(do_batch):
|
||||
"""
|
||||
Feature: TFRecordDataset
|
||||
Description: Test TFRecordDataset with schema while the shape is empty or unknown
|
||||
Expectation: The data can be processed as expected
|
||||
"""
|
||||
schema = ds.Schema()
|
||||
schema.add_column("col_1d", de_type=mstype.int64, shape=[-1])
|
||||
schema.add_column("col_2d", de_type=mstype.int64, shape=[2, -1])
|
||||
schema.add_column("col_3d", de_type=mstype.int64, shape=[2, 2, -1])
|
||||
schema.add_column("col_binary", de_type=mstype.string, shape=[])
|
||||
schema.add_column("col_float", de_type=mstype.float32, shape=[-1])
|
||||
schema.add_column("col_sint16", de_type=mstype.int64, shape=[])
|
||||
schema.add_column("col_sint32", de_type=mstype.int64, shape=[-1])
|
||||
schema.add_column("col_sint64", de_type=mstype.int64, shape=[])
|
||||
schema.add_column("col_sint8", de_type=mstype.int64, shape=[-1])
|
||||
dataset = ds.TFRecordDataset(FILES, schema=schema, shuffle=ds.Shuffle.FILES)
|
||||
if do_batch:
|
||||
dataset = dataset.batch(2)
|
||||
|
||||
count = 0
|
||||
for _ in dataset:
|
||||
count += 1
|
||||
assert dataset.get_dataset_size() == count
|
||||
assert dataset.get_col_names() == ["col_1d", "col_2d", "col_3d",
|
||||
"col_binary", "col_float",
|
||||
"col_sint16", "col_sint32", "col_sint64", "col_sint8"]
|
||||
assert dataset.output_types() == [np.int64, np.int64, np.int64, np.str_, np.float32, np.int64, np.int64, np.int64,
|
||||
np.int64]
|
||||
if do_batch:
|
||||
expected_shape = [[2, 2], [2, 2, 2], [2, 2, 2, 2], [2], [2, 1], [2], [2, 1], [2], [2, 1]]
|
||||
else:
|
||||
expected_shape = [[2], [2, 2], [2, 2, 2], [], [1], [], [1], [], [1]]
|
||||
assert dataset.output_shapes() == expected_shape
|
||||
|
||||
|
||||
@pytest.mark.parametrize("do_batch", (True, False))
|
||||
@pytest.mark.parametrize("load_type", ("uint8", "string"))
|
||||
def test_tfrecord_with_unknown_shape_schema(do_batch, load_type):
|
||||
def test_tfrecord_with_no_shape_schema(do_batch, load_type):
|
||||
"""
|
||||
Feature: TFRecordDataset
|
||||
Description: Test TFRecordDataset with schema missing feature shape
|
||||
|
@ -301,10 +338,10 @@ def test_tfrecord_with_wrong_shape_schema(do_batch):
|
|||
|
||||
|
||||
@pytest.mark.parametrize("do_batch", (True, False))
|
||||
def test_tfrecord_with_wrong_type_schema(do_batch):
|
||||
def test_tfrecord_with_mismatch_type_schema(do_batch):
|
||||
"""
|
||||
Feature: TFRecordDataset
|
||||
Description: Test TFRecordDataset with schema containing wrong feature type
|
||||
Description: Test TFRecordDataset with schema containing mismatch feature type and will cause a type cast
|
||||
Expectation: The output columns can be converted to the specified type
|
||||
"""
|
||||
schema = ds.Schema()
|
||||
|
@ -337,6 +374,63 @@ def test_tfrecord_with_wrong_type_schema(do_batch):
|
|||
assert dataset.output_shapes() == expected_shape
|
||||
|
||||
|
||||
@pytest.mark.parametrize("do_batch", (True, False))
|
||||
@pytest.mark.parametrize("with_shape", (True, False))
|
||||
def test_tfrecord_with_wrong_type_schema(do_batch, with_shape):
|
||||
"""
|
||||
Feature: TFRecordDataset
|
||||
Description: Test TFRecordDataset with schema containing wrong feature type that can not be cast
|
||||
Expectation: Raise a RuntimeError as expected
|
||||
"""
|
||||
schema = ds.Schema()
|
||||
# feature of type int64 can not be cast to string
|
||||
if with_shape:
|
||||
schema.add_column("col_1d", de_type=mstype.string, shape=[2])
|
||||
else:
|
||||
schema.add_column("col_1d", de_type=mstype.string)
|
||||
dataset = ds.TFRecordDataset(FILES, schema=schema, shuffle=ds.Shuffle.FILES)
|
||||
if do_batch:
|
||||
dataset = dataset.batch(2)
|
||||
|
||||
with pytest.raises(RuntimeError) as e:
|
||||
for _ in dataset:
|
||||
pass
|
||||
assert ("The data type loaded from the example for feature name: col_1d does not match the "
|
||||
"predefined type in schema") in str(e.value)
|
||||
|
||||
schema = ds.Schema()
|
||||
# feature of type float32 can not be cast to string
|
||||
if with_shape:
|
||||
schema.add_column("col_float", de_type=mstype.string, shape=[1])
|
||||
else:
|
||||
schema.add_column("col_float", de_type=mstype.string)
|
||||
dataset = ds.TFRecordDataset(FILES, schema=schema, shuffle=ds.Shuffle.FILES)
|
||||
if do_batch:
|
||||
dataset = dataset.batch(2)
|
||||
|
||||
with pytest.raises(RuntimeError) as e:
|
||||
for _ in dataset:
|
||||
pass
|
||||
assert ("The data type loaded from the example for feature name: col_float does not match the "
|
||||
"predefined type in schema") in str(e.value)
|
||||
|
||||
schema = ds.Schema()
|
||||
# feature of type string can not be cast to int64
|
||||
if with_shape:
|
||||
schema.add_column("col_binary", de_type=mstype.int64, shape=[1])
|
||||
else:
|
||||
schema.add_column("col_binary", de_type=mstype.int64)
|
||||
dataset = ds.TFRecordDataset(FILES, schema=schema, shuffle=ds.Shuffle.FILES)
|
||||
if do_batch:
|
||||
dataset = dataset.batch(2)
|
||||
|
||||
with pytest.raises(RuntimeError) as e:
|
||||
for _ in dataset:
|
||||
pass
|
||||
assert ("The data type loaded from the example for feature name: col_binary does not match the "
|
||||
"predefined type in schema") in str(e.value)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("do_batch", (True, False))
|
||||
def test_tfrecord_with_column_list(do_batch):
|
||||
"""
|
||||
|
@ -1348,9 +1442,11 @@ if __name__ == '__main__':
|
|||
test_tfrecord_read_files()
|
||||
test_tfrecord_multi_files()
|
||||
test_tfrecord_with_full_schema(True, "string")
|
||||
test_tfrecord_with_unknown_shape_schema(True, "string")
|
||||
test_tfrecord_with_empty_or_unknown_shape_schema(True)
|
||||
test_tfrecord_with_no_shape_schema(True, "string")
|
||||
test_tfrecord_with_wrong_shape_schema(True)
|
||||
test_tfrecord_with_wrong_type_schema(True)
|
||||
test_tfrecord_with_mismatch_type_schema(True)
|
||||
test_tfrecord_with_wrong_type_schema(True, True)
|
||||
test_tfrecord_with_column_list(True)
|
||||
test_tfrecord_without_schema_and_column_list(True)
|
||||
test_tfrecord_with_both_schema_and_column_list(True, "string")
|
||||
|
|
|
@ -156,15 +156,15 @@ def test_tfrecord1():
|
|||
"""
|
||||
s = ds.Schema()
|
||||
s.add_column("line", "string", [])
|
||||
s.add_column("words", "string", [2, 2])
|
||||
s.add_column("words", "string", [-1])
|
||||
s.add_column("chinese", "string", [])
|
||||
|
||||
data = ds.TFRecordDataset("../data/dataset/testTextTFRecord/text.tfrecord", shuffle=False, schema=s)
|
||||
|
||||
for i, d in enumerate(data.create_dict_iterator(num_epochs=1, output_numpy=True)):
|
||||
assert d["line"].shape == (1,)
|
||||
assert d["line"].shape == line[i].shape
|
||||
assert d["words"].shape == words[i].shape
|
||||
assert d["chinese"].shape == (1,)
|
||||
assert d["chinese"].shape == chinese[i].shape
|
||||
np.testing.assert_array_equal(line[i], d["line"])
|
||||
np.testing.assert_array_equal(words[i], d["words"])
|
||||
np.testing.assert_array_equal(chinese[i], d["chinese"])
|
||||
|
@ -195,17 +195,17 @@ def test_tfrecord3():
|
|||
"""
|
||||
s = ds.Schema()
|
||||
s.add_column("line", mstype.string, [])
|
||||
s.add_column("words", mstype.string, [2, 2])
|
||||
s.add_column("words", mstype.string, [-1, 2])
|
||||
s.add_column("chinese", mstype.string, [])
|
||||
|
||||
data = ds.TFRecordDataset("../data/dataset/testTextTFRecord/text.tfrecord", shuffle=False, schema=s)
|
||||
|
||||
for i, d in enumerate(data.create_dict_iterator(num_epochs=1, output_numpy=True)):
|
||||
assert d["line"].shape == (1,)
|
||||
assert d["words"].shape == words[i].shape
|
||||
assert d["chinese"].shape == (1,)
|
||||
assert d["line"].shape == line[i].shape
|
||||
assert d["words"].shape == words[i].reshape([2, 2]).shape
|
||||
assert d["chinese"].shape == chinese[i].shape
|
||||
np.testing.assert_array_equal(line[i], d["line"])
|
||||
np.testing.assert_array_equal(words[i], d["words"])
|
||||
np.testing.assert_array_equal(words[i].reshape([2, 2]), d["words"])
|
||||
np.testing.assert_array_equal(chinese[i], d["chinese"])
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue