forked from OSSInnovation/mindspore
!503 The num_samples and numRows in the schema for TFRecordDataset are in conflict
Merge pull request !503 from qianlong21st/fix_numRows_num_samples
commit 6b0ff88b1c
@@ -162,7 +162,11 @@ Status StorageClient::numRowsFromFile(uint32_t &num_rows) const {
     std::ifstream in(schemaFile);
     nlohmann::json js;
     in >> js;
-    num_rows = js.value("numRows", 0);
+    if (js.find("numRows") == js.end()) {
+      num_rows = MAX_INTEGER_INT32;
+    } else {
+      num_rows = js.value("numRows", 0);
+    }
     if (num_rows == 0) {
       std::string err_msg =
         "Storage client has not properly done dataset "
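The hunk above makes a missing numRows field mean "read everything" instead of falling through to the zero-rows error. A minimal Python sketch of the same fallback logic (a standalone illustration, not the MindSpore API; INT32_MAX stands in for the C++ MAX_INTEGER_INT32 constant):

import json

INT32_MAX = 2**31 - 1  # stand-in for the C++ MAX_INTEGER_INT32

def num_rows_from_schema(schema_file):
    with open(schema_file) as f:
        js = json.load(f)
    if "numRows" not in js:
        # Schema omits numRows: treat the row count as unbounded.
        return INT32_MAX
    return js["numRows"]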
@@ -163,6 +163,9 @@ Status TFReaderOp::Init() {
   if (total_rows_ == 0) {
     total_rows_ = data_schema_->num_rows();
   }
+  if (total_rows_ < 0) {
+    RETURN_STATUS_UNEXPECTED("The num_sample or numRows for TFRecordDataset should be greater than 0");
+  }

   // Build the index with our files such that each file corresponds to a key id.
   RETURN_IF_NOT_OK(filename_index_->insert(dataset_files_list_));
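In effect TFReaderOp now resolves the row count in two steps: an explicit num_samples wins, otherwise the schema's numRows is used, and a negative result is rejected. A hedged Python restatement of that rule (illustrative only, not the operator's actual code):

def resolve_total_rows(num_samples, schema_num_rows):
    # An explicit num_samples takes precedence over the schema's numRows.
    total = num_samples if num_samples else schema_num_rows
    if total < 0:
        raise ValueError("num_samples or numRows must be greater than 0")
    return total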
@@ -1455,7 +1455,7 @@ class StorageDataset(SourceDataset):

     Args:
         dataset_files (list[str]): List of files to be read.
-        schema (str): Path to the json schema file.
+        schema (str): Path to the json schema file. If numRows (parsed from the schema) does not exist, read the full dataset.
         distribution (str, optional): Path of distribution config file (default="").
         columns_list (list[str], optional): List of columns to be read (default=None, read all columns).
         num_parallel_workers (int, optional): Number of parallel working threads (default=None).
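A short usage sketch of the documented behaviour (file paths are hypothetical; the pattern follows test_case_no_rows later in this commit):

import mindspore.dataset as ds

# Schema json that contains no "numRows" field: the full dataset is read.
data = ds.StorageDataset(["path/to/part-0000.data"], "path/to/schema_no_rows.json")
print(data.get_dataset_size())  # size of the whole dataset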
@@ -2193,7 +2193,10 @@ class TFRecordDataset(SourceDataset):
         schema (str or Schema, optional): Path to the json schema file or schema object (default=None).
             If the schema is not provided, the meta data from the TFData file is considered the schema.
         columns_list (list[str], optional): List of columns to be read (default=None, read all columns)
-        num_samples (int, optional): number of samples(rows) to read (default=None, reads the full dataset).
+        num_samples (int, optional): number of samples (rows) to read (default=None).
+            If num_samples is None and numRows (parsed from the schema) does not exist, read the full dataset;
+            if num_samples is None and numRows (parsed from the schema) is greater than 0, read numRows rows;
+            if both num_samples and numRows (parsed from the schema) are greater than 0, read num_samples rows.
         num_parallel_workers (int, optional): number of workers to read the data
             (default=None, number set in the config).
         shuffle (bool, Shuffle level, optional): perform reshuffling of the data every epoch (default=Shuffle.GLOBAL).
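The three documented cases map to usage like this (schema file names are hypothetical; compare test_case_tf_read_all_dataset, test_case_num_samples, and test_case_num_samples2 below):

import mindspore.dataset as ds

# 1) No num_samples and the schema has no numRows: read the full dataset.
ds_all = ds.TFRecordDataset(["a.tfrecord"], "schema_no_rows.json")

# 2) No num_samples and the schema says numRows = 7: read 7 rows.
ds_seven = ds.TFRecordDataset(["a.tfrecord"], "schema_7_rows.json")

# 3) num_samples=8 overrides the schema's numRows = 7: read 8 rows.
ds_eight = ds.TFRecordDataset(["a.tfrecord"], "schema_7_rows.json", num_samples=8)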
@@ -2711,10 +2714,10 @@ class Schema:
     """

     def __init__(self, schema_file=None):
+        self.num_rows = None
         if schema_file is None:
             self.columns = []
             self.dataset_type = ''
-            self.num_rows = 0
         else:
             if not os.path.isfile(schema_file) or not os.access(schema_file, os.R_OK):
                 raise ValueError("The file %s does not exist or permission denied!" % schema_file)
@@ -2859,6 +2862,9 @@ class Schema:
             raise RuntimeError("DatasetType field is missing.")
         if self.columns is None:
             raise RuntimeError("Columns are missing.")
+        if self.num_rows is not None:
+            if not isinstance(self.num_rows, int) or self.num_rows <= 0:
+                raise ValueError("numRows must be greater than 0")

     def __str__(self):
         return self.to_json()
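Schema.num_rows now starts as None, so "not specified" is distinguishable from an explicit value, and only explicit values are validated. A standalone restatement of the added check (not the MindSpore API itself):

def check_num_rows(num_rows):
    if num_rows is None:
        return  # numRows absent: caller falls back to reading the full dataset
    if not isinstance(num_rows, int) or num_rows <= 0:
        raise ValueError("numRows must be greater than 0")

check_num_rows(None)  # ok: unspecified
check_num_rows(7)     # ok: positive row count
# check_num_rows(0)   # would raise ValueError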
@@ -0,0 +1,45 @@
+{
+  "datasetType": "TF",
+  "columns": {
+    "col_sint16": {
+      "type": "int16",
+      "rank": 1,
+      "shape": [1]
+    },
+    "col_sint32": {
+      "type": "int32",
+      "rank": 1,
+      "shape": [1]
+    },
+    "col_sint64": {
+      "type": "int64",
+      "rank": 1,
+      "shape": [1]
+    },
+    "col_float": {
+      "type": "float32",
+      "rank": 1,
+      "shape": [1]
+    },
+    "col_1d": {
+      "type": "int64",
+      "rank": 1,
+      "shape": [2]
+    },
+    "col_2d": {
+      "type": "int64",
+      "rank": 2,
+      "shape": [2, 2]
+    },
+    "col_3d": {
+      "type": "int64",
+      "rank": 3,
+      "shape": [2, 2, 2]
+    },
+    "col_binary": {
+      "type": "uint8",
+      "rank": 1,
+      "shape": [1]
+    }
+  }
+}
@@ -0,0 +1,15 @@
+{
+  "datasetType": "TF",
+  "columns": {
+    "image": {
+      "type": "uint8",
+      "rank": 1,
+      "t_impl": "cvmat"
+    },
+    "label" : {
+      "type": "uint64",
+      "rank": 1,
+      "t_impl": "flex"
+    }
+  }
+}
@@ -37,3 +37,15 @@ def test_case_storage():

     filename = "storage_result.npz"
     save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
+
+
+def test_case_no_rows():
+    DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
+    SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetNoRowsSchema.json"
+
+    dataset = ds.StorageDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
+    assert dataset.get_dataset_size() == 3
+    count = 0
+    for data in dataset.create_tuple_iterator():
+        count += 1
+    assert count == 3
@@ -37,6 +37,36 @@ def test_case_tf_shape():
     assert (len(output_shape[-1]) == 1)


+def test_case_tf_read_all_dataset():
+    schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaNoRow.json"
+    ds1 = ds.TFRecordDataset(FILES, schema_file)
+    assert ds1.get_dataset_size() == 12
+    count = 0
+    for data in ds1.create_tuple_iterator():
+        count += 1
+    assert count == 12
+
+
+def test_case_num_samples():
+    schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema7Rows.json"
+    ds1 = ds.TFRecordDataset(FILES, schema_file, num_samples=8)
+    assert ds1.get_dataset_size() == 8
+    count = 0
+    for data in ds1.create_dict_iterator():
+        count += 1
+    assert count == 8
+
+
+def test_case_num_samples2():
+    schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema7Rows.json"
+    ds1 = ds.TFRecordDataset(FILES, schema_file)
+    assert ds1.get_dataset_size() == 7
+    count = 0
+    for data in ds1.create_dict_iterator():
+        count += 1
+    assert count == 7
+
+
 def test_case_tf_shape_2():
     ds1 = ds.TFRecordDataset(FILES, SCHEMA_FILE)
     ds1 = ds1.batch(2)