forked from mindspore-Ecosystem/mindspore
!503 Resolve the conflict between num_samples and numRows in the schema for TFRecordDataset
Merge pull request !503 from qianlong21st/fix_numRows_num_samples
This commit is contained in:
commit 6b0ff88b1c
@@ -162,7 +162,11 @@ Status StorageClient::numRowsFromFile(uint32_t &num_rows) const {
     std::ifstream in(schemaFile);
     nlohmann::json js;
     in >> js;
-    num_rows = js.value("numRows", 0);
+    if (js.find("numRows") == js.end()) {
+      num_rows = MAX_INTEGER_INT32;
+    } else {
+      num_rows = js.value("numRows", 0);
+    }
     if (num_rows == 0) {
       std::string err_msg =
         "Storage client has not properly done dataset "
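The change above means a schema with no "numRows" key is now treated as "row count unknown, read everything" via the MAX_INTEGER_INT32 sentinel, instead of falling through to num_rows = 0 and the error path. A minimal Python sketch of the same rule; the helper name and INT32_MAX constant are illustrative, not part of the patch:

    import json

    INT32_MAX = 2147483647  # stands in for the C++ MAX_INTEGER_INT32 sentinel

    def num_rows_from_schema(schema_file):
        # Mirrors StorageClient::numRowsFromFile after this change:
        # a missing "numRows" key means "read the full dataset".
        with open(schema_file) as f:
            js = json.load(f)
        if "numRows" not in js:
            return INT32_MAX
        # An explicit numRows of 0 still falls through to the
        # "Storage client has not properly done dataset ..." error above.
        return js["numRows"]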
@@ -163,6 +163,9 @@ Status TFReaderOp::Init() {
   if (total_rows_ == 0) {
     total_rows_ = data_schema_->num_rows();
   }
+  if (total_rows_ < 0) {
+    RETURN_STATUS_UNEXPECTED("The num_samples or numRows for TFRecordDataset should be greater than 0");
+  }
 
   // Build the index with our files such that each file corresponds to a key id.
   RETURN_IF_NOT_OK(filename_index_->insert(dataset_files_list_));
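Together with the schema fallback above, TFReaderOp::Init() now resolves the effective row count as: an explicit num_samples (carried in total_rows_) wins; if unset, fall back to the schema's numRows; anything negative is rejected. A hedged Python sketch of that precedence; the function name is illustrative, not MindSpore API:

    def resolve_total_rows(num_samples, schema_num_rows):
        # 0 means "not set", matching the total_rows_ == 0 test above.
        total_rows = num_samples if num_samples != 0 else schema_num_rows
        if total_rows < 0:
            raise ValueError("The num_samples or numRows for TFRecordDataset should be greater than 0")
        return total_rows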
@@ -1455,7 +1455,7 @@ class StorageDataset(SourceDataset):
 
     Args:
         dataset_files (list[str]): List of files to be read.
-        schema (str): Path to the json schema file.
+        schema (str): Path to the json schema file. If numRows (parsed from the schema) does not exist, read the full dataset.
         distribution (str, optional): Path of distribution config file (default="").
         columns_list (list[str], optional): List of columns to be read (default=None, read all columns).
        num_parallel_workers (int, optional): Number of parallel working threads (default=None).
@@ -2193,7 +2193,10 @@ class TFRecordDataset(SourceDataset):
         schema (str or Schema, optional): Path to the json schema file or schema object (default=None).
             If the schema is not provided, the meta data from the TFData file is considered the schema.
         columns_list (list[str], optional): List of columns to be read (default=None, read all columns)
-        num_samples (int, optional): number of samples(rows) to read (default=None, reads the full dataset).
+        num_samples (int, optional): Number of samples (rows) to read (default=None).
+            If num_samples is None and numRows (parsed from the schema) does not exist, read the full dataset;
+            if num_samples is None and numRows (parsed from the schema) is greater than 0, read numRows rows;
+            if both num_samples and numRows (parsed from the schema) are greater than 0, read num_samples rows.
         num_parallel_workers (int, optional): number of workers to read the data
             (default=None, number set in the config).
         shuffle (bool, Shuffle level, optional): perform reshuffling of the data every epoch (default=Shuffle.GLOBAL).
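The three cases in the new docstring correspond to the tests added at the end of this change. Assuming FILES points at the 12-row testTFTestAllTypes data and the schema paths used there:

    import mindspore.dataset as ds

    # numRows=7 in the schema, num_samples not given -> 7 rows are read
    ds1 = ds.TFRecordDataset(FILES, "datasetSchema7Rows.json")
    # numRows=7 in the schema, num_samples=8 -> num_samples wins, 8 rows
    ds2 = ds.TFRecordDataset(FILES, "datasetSchema7Rows.json", num_samples=8)
    # no numRows in the schema, num_samples not given -> full dataset, 12 rows
    ds3 = ds.TFRecordDataset(FILES, "datasetSchemaNoRow.json")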
@@ -2711,10 +2714,10 @@ class Schema:
     """
 
     def __init__(self, schema_file=None):
+        self.num_rows = None
        if schema_file is None:
             self.columns = []
             self.dataset_type = ''
-            self.num_rows = 0
         else:
             if not os.path.isfile(schema_file) or not os.access(schema_file, os.R_OK):
                 raise ValueError("The file %s does not exist or permission denied!" % schema_file)
@@ -2859,6 +2862,9 @@ class Schema:
             raise RuntimeError("DatasetType field is missing.")
         if self.columns is None:
             raise RuntimeError("Columns are missing.")
+        if self.num_rows is not None:
+            if not isinstance(self.num_rows, int) or self.num_rows <= 0:
+                raise ValueError("numRows must be greater than 0")
 
     def __str__(self):
         return self.to_json()
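Net effect on the Python Schema object: num_rows defaults to None ("not specified") rather than 0, and the new check only rejects values that were explicitly set to a non-positive number. A small sketch, assuming Schema is exposed as mindspore.dataset.Schema:

    import mindspore.dataset as ds

    schema = ds.Schema()            # no schema file given
    assert schema.num_rows is None  # previously this was initialized to 0
    schema.num_rows = 7             # an explicit positive int passes the new check
    # Setting schema.num_rows = 0 would now raise
    # ValueError("numRows must be greater than 0") when the schema is validated.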
@@ -0,0 +1,45 @@
+{
+  "datasetType": "TF",
+  "columns": {
+    "col_sint16": {
+      "type": "int16",
+      "rank": 1,
+      "shape": [1]
+    },
+    "col_sint32": {
+      "type": "int32",
+      "rank": 1,
+      "shape": [1]
+    },
+    "col_sint64": {
+      "type": "int64",
+      "rank": 1,
+      "shape": [1]
+    },
+    "col_float": {
+      "type": "float32",
+      "rank": 1,
+      "shape": [1]
+    },
+    "col_1d": {
+      "type": "int64",
+      "rank": 1,
+      "shape": [2]
+    },
+    "col_2d": {
+      "type": "int64",
+      "rank": 2,
+      "shape": [2, 2]
+    },
+    "col_3d": {
+      "type": "int64",
+      "rank": 3,
+      "shape": [2, 2, 2]
+    },
+    "col_binary": {
+      "type": "uint8",
+      "rank": 1,
+      "shape": [1]
+    }
+  }
+}
@@ -0,0 +1,15 @@
+{
+  "datasetType": "TF",
+  "columns": {
+    "image": {
+      "type": "uint8",
+      "rank": 1,
+      "t_impl": "cvmat"
+    },
+    "label": {
+      "type": "uint64",
+      "rank": 1,
+      "t_impl": "flex"
+    }
+  }
+}
@@ -37,3 +37,15 @@ def test_case_storage():
 
     filename = "storage_result.npz"
     save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
+
+
+def test_case_no_rows():
+    DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
+    SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetNoRowsSchema.json"
+
+    dataset = ds.StorageDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
+    assert dataset.get_dataset_size() == 3
+    count = 0
+    for data in dataset.create_tuple_iterator():
+        count += 1
+    assert count == 3
@@ -37,6 +37,36 @@ def test_case_tf_shape():
     assert (len(output_shape[-1]) == 1)
 
 
+def test_case_tf_read_all_dataset():
+    schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaNoRow.json"
+    ds1 = ds.TFRecordDataset(FILES, schema_file)
+    assert ds1.get_dataset_size() == 12
+    count = 0
+    for data in ds1.create_tuple_iterator():
+        count += 1
+    assert count == 12
+
+
+def test_case_num_samples():
+    schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema7Rows.json"
+    ds1 = ds.TFRecordDataset(FILES, schema_file, num_samples=8)
+    assert ds1.get_dataset_size() == 8
+    count = 0
+    for data in ds1.create_dict_iterator():
+        count += 1
+    assert count == 8
+
+
+def test_case_num_samples2():
+    schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema7Rows.json"
+    ds1 = ds.TFRecordDataset(FILES, schema_file)
+    assert ds1.get_dataset_size() == 7
+    count = 0
+    for data in ds1.create_dict_iterator():
+        count += 1
+    assert count == 7
+
+
 def test_case_tf_shape_2():
     ds1 = ds.TFRecordDataset(FILES, SCHEMA_FILE)
     ds1 = ds1.batch(2)