change dataset_file to dataset_files for MindDataset

ms_yan 2021-12-10 17:45:40 +08:00
parent db7d28f5c8
commit a85113c093
4 changed files with 56 additions and 56 deletions
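For users migrating existing pipelines, the call-site change implied by this rename looks roughly like the sketch below (the path is a placeholder and must point at an existing MindRecord file):

import mindspore.dataset as ds

# Before this commit, the keyword argument was named `dataset_file`:
#   data_set = ds.MindDataset(dataset_file=["/path/to/data.mindrecord"], shuffle=False)

# After this commit, the keyword argument is `dataset_files`; it still accepts
# either a single file name (str) or a list of file names (list[str]).
data_set = ds.MindDataset(dataset_files=["/path/to/data.mindrecord"], shuffle=False)
print(data_set.get_dataset_size())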

@@ -4379,7 +4379,7 @@ class MindDataset(MappableDataset):
The columns of generated dataset depend on the source MindRecord files.
Args:
dataset_file (Union[str, list[str]]): If dataset_file is a str, it represents for
dataset_files (Union[str, list[str]]): If dataset_file is a str, it represents for
a file name of one component of a mindrecord source, other files with identical source
in the same path will be found and loaded automatically. If dataset_file is a list,
it represents for a list of dataset files to be read directly.
@@ -4453,15 +4453,15 @@ class MindDataset(MappableDataset):
Examples:
>>> mind_dataset_dir = ["/path/to/mind_dataset_file"] # contains 1 or multiple MindRecord files
>>> dataset = ds.MindDataset(dataset_file=mind_dataset_dir)
>>> dataset = ds.MindDataset(dataset_files=mind_dataset_dir)
"""
def parse(self, children=None):
return cde.MindDataNode(self.dataset_file, self.columns_list, self.sampler, self.new_padded_sample,
return cde.MindDataNode(self.dataset_files, self.columns_list, self.sampler, self.new_padded_sample,
self.num_padded, shuffle_to_shuffle_mode(self.shuffle_option))
@check_minddataset
def __init__(self, dataset_file, columns_list=None, num_parallel_workers=None, shuffle=None, num_shards=None,
def __init__(self, dataset_files, columns_list=None, num_parallel_workers=None, shuffle=None, num_shards=None,
shard_id=None, sampler=None, padded_sample=None, num_padded=None, num_samples=None, cache=None):
super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
shuffle=shuffle_to_bool(shuffle), num_shards=num_shards, shard_id=shard_id, cache=cache)
@@ -4472,11 +4472,11 @@ class MindDataset(MappableDataset):
raise ValueError("'Shuffle.FILES' or 'Shuffle.INFILE' and 'num_samples' "
"cannot be specified at the same time.")
self.shuffle_option = shuffle
if isinstance(dataset_file, list):
if isinstance(dataset_files, list):
self.load_dataset = False
else:
self.load_dataset = True
self.dataset_file = dataset_file
self.dataset_files = dataset_files
self.columns_list = replace_none(columns_list, [])
if shuffle is False:

@@ -515,7 +515,7 @@ def check_minddataset(method):
nreq_param_list = ['columns_list']
nreq_param_dict = ['padded_sample']
dataset_file = param_dict.get('dataset_file')
dataset_file = param_dict.get('dataset_files')
if isinstance(dataset_file, list):
if len(dataset_file) > 4096:
raise ValueError("length of dataset_file should be less than or equal to {}.".format(4096))

@@ -1235,7 +1235,7 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset():
data_value_to_list.append(new_data)
num_readers = 2
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
data_set = ds.MindDataset(dataset_files=mindrecord_file_name,
num_parallel_workers=num_readers,
shuffle=False)
assert data_set.get_dataset_size() == 6
@@ -1252,7 +1252,7 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset():
assert num_iter == 6
num_readers = 2
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
data_set = ds.MindDataset(dataset_files=mindrecord_file_name,
columns_list=["source_sos_ids",
"source_sos_mask", "target_sos_ids"],
num_parallel_workers=num_readers,
@@ -1270,7 +1270,7 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset():
assert num_iter == 6
num_readers = 1
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
data_set = ds.MindDataset(dataset_files=mindrecord_file_name,
columns_list=["image2", "source_sos_mask", "image3", "target_sos_ids"],
num_parallel_workers=num_readers,
shuffle=False)
@@ -1288,7 +1288,7 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset():
assert num_iter == 6
num_readers = 3
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
data_set = ds.MindDataset(dataset_files=mindrecord_file_name,
columns_list=["target_sos_ids",
"image4", "source_sos_ids"],
num_parallel_workers=num_readers,
@@ -1307,7 +1307,7 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset():
assert num_iter == 6
num_readers = 3
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
data_set = ds.MindDataset(dataset_files=mindrecord_file_name,
columns_list=["target_sos_ids", "image5",
"image4", "image3", "source_sos_ids"],
num_parallel_workers=num_readers,
@@ -1326,7 +1326,7 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset():
assert num_iter == 6
num_readers = 1
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
data_set = ds.MindDataset(dataset_files=mindrecord_file_name,
columns_list=["target_eos_mask", "image5",
"image2", "source_sos_mask", "label"],
num_parallel_workers=num_readers,
@@ -1345,7 +1345,7 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset():
assert num_iter == 6
num_readers = 2
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
data_set = ds.MindDataset(dataset_files=mindrecord_file_name,
columns_list=["label", "target_eos_mask", "image1", "target_eos_ids",
"source_sos_mask", "image2", "image4", "image3",
"source_sos_ids", "image5", "file_name"],
@@ -1438,7 +1438,7 @@ def test_write_with_multi_bytes_and_MindDataset():
data_value_to_list.append(new_data)
num_readers = 2
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
data_set = ds.MindDataset(dataset_files=mindrecord_file_name,
num_parallel_workers=num_readers,
shuffle=False)
assert data_set.get_dataset_size() == 6
@@ -1455,7 +1455,7 @@ def test_write_with_multi_bytes_and_MindDataset():
assert num_iter == 6
num_readers = 2
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
data_set = ds.MindDataset(dataset_files=mindrecord_file_name,
columns_list=["image1", "image2", "image5"],
num_parallel_workers=num_readers,
shuffle=False)
@@ -1473,7 +1473,7 @@ def test_write_with_multi_bytes_and_MindDataset():
assert num_iter == 6
num_readers = 2
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
data_set = ds.MindDataset(dataset_files=mindrecord_file_name,
columns_list=["image2", "image4"],
num_parallel_workers=num_readers,
shuffle=False)
@@ -1491,7 +1491,7 @@ def test_write_with_multi_bytes_and_MindDataset():
assert num_iter == 6
num_readers = 2
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
data_set = ds.MindDataset(dataset_files=mindrecord_file_name,
columns_list=["image5", "image2"],
num_parallel_workers=num_readers,
shuffle=False)
@@ -1509,7 +1509,7 @@ def test_write_with_multi_bytes_and_MindDataset():
assert num_iter == 6
num_readers = 2
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
data_set = ds.MindDataset(dataset_files=mindrecord_file_name,
columns_list=["image5", "image2", "label"],
num_parallel_workers=num_readers,
shuffle=False)
@@ -1527,7 +1527,7 @@ def test_write_with_multi_bytes_and_MindDataset():
assert num_iter == 6
num_readers = 2
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
data_set = ds.MindDataset(dataset_files=mindrecord_file_name,
columns_list=["image4", "image5",
"image2", "image3", "file_name"],
num_parallel_workers=num_readers,
@@ -1633,7 +1633,7 @@ def test_write_with_multi_array_and_MindDataset():
data_value_to_list.append(new_data)
num_readers = 2
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
data_set = ds.MindDataset(dataset_files=mindrecord_file_name,
num_parallel_workers=num_readers,
shuffle=False)
assert data_set.get_dataset_size() == 6
@@ -1650,7 +1650,7 @@ def test_write_with_multi_array_and_MindDataset():
assert num_iter == 6
num_readers = 2
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
data_set = ds.MindDataset(dataset_files=mindrecord_file_name,
columns_list=["source_eos_ids", "source_eos_mask",
"target_sos_ids", "target_sos_mask",
"target_eos_ids", "target_eos_mask"],
@@ -1670,7 +1670,7 @@ def test_write_with_multi_array_and_MindDataset():
assert num_iter == 6
num_readers = 2
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
data_set = ds.MindDataset(dataset_files=mindrecord_file_name,
columns_list=["source_sos_ids",
"target_sos_ids",
"target_eos_mask"],
@@ -1690,7 +1690,7 @@ def test_write_with_multi_array_and_MindDataset():
assert num_iter == 6
num_readers = 2
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
data_set = ds.MindDataset(dataset_files=mindrecord_file_name,
columns_list=["target_eos_mask",
"source_eos_mask",
"source_sos_mask"],
@@ -1710,7 +1710,7 @@ def test_write_with_multi_array_and_MindDataset():
assert num_iter == 6
num_readers = 2
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
data_set = ds.MindDataset(dataset_files=mindrecord_file_name,
columns_list=["target_eos_ids"],
num_parallel_workers=num_readers,
shuffle=False)
@@ -1728,7 +1728,7 @@ def test_write_with_multi_array_and_MindDataset():
assert num_iter == 6
num_readers = 1
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
data_set = ds.MindDataset(dataset_files=mindrecord_file_name,
columns_list=["target_eos_mask", "target_eos_ids",
"target_sos_mask", "target_sos_ids",
"source_eos_mask", "source_eos_ids",
@@ -1880,7 +1880,7 @@ def test_write_with_float32_float64_float32_array_float64_array_and_MindDataset(
data_value_to_list.append(new_data)
num_readers = 2
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
data_set = ds.MindDataset(dataset_files=mindrecord_file_name,
num_parallel_workers=num_readers,
shuffle=False)
assert data_set.get_dataset_size() == 5
@@ -1901,7 +1901,7 @@ def test_write_with_float32_float64_float32_array_float64_array_and_MindDataset(
assert num_iter == 5
num_readers = 2
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
data_set = ds.MindDataset(dataset_files=mindrecord_file_name,
columns_list=["float32", "int32"],
num_parallel_workers=num_readers,
shuffle=False)
@@ -1923,7 +1923,7 @@ def test_write_with_float32_float64_float32_array_float64_array_and_MindDataset(
assert num_iter == 5
num_readers = 2
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
data_set = ds.MindDataset(dataset_files=mindrecord_file_name,
columns_list=["float64", "int64"],
num_parallel_workers=num_readers,
shuffle=False)
@@ -2016,7 +2016,7 @@ def test_shuffle_with_global_infile_files(create_multi_mindrecord_files):
# no shuffle parameter
num_readers = 2
data_set = ds.MindDataset(dataset_file=files,
data_set = ds.MindDataset(dataset_files=files,
num_parallel_workers=num_readers)
assert data_set.get_dataset_size() == 52
num_iter = 0
@@ -2047,7 +2047,7 @@ def test_shuffle_with_global_infile_files(create_multi_mindrecord_files):
# shuffle=False
num_readers = 2
data_set = ds.MindDataset(dataset_file=files,
data_set = ds.MindDataset(dataset_files=files,
num_parallel_workers=num_readers,
shuffle=False)
assert data_set.get_dataset_size() == 52
@@ -2079,7 +2079,7 @@ def test_shuffle_with_global_infile_files(create_multi_mindrecord_files):
# shuffle=True
num_readers = 2
data_set = ds.MindDataset(dataset_file=files,
data_set = ds.MindDataset(dataset_files=files,
num_parallel_workers=num_readers,
shuffle=True)
assert data_set.get_dataset_size() == 52
@@ -2111,7 +2111,7 @@ def test_shuffle_with_global_infile_files(create_multi_mindrecord_files):
# shuffle=Shuffle.GLOBAL
num_readers = 2
data_set = ds.MindDataset(dataset_file=files,
data_set = ds.MindDataset(dataset_files=files,
num_parallel_workers=num_readers,
shuffle=ds.Shuffle.GLOBAL)
assert data_set.get_dataset_size() == 52
@@ -2143,7 +2143,7 @@ def test_shuffle_with_global_infile_files(create_multi_mindrecord_files):
# shuffle=Shuffle.INFILE
num_readers = 2
data_set = ds.MindDataset(dataset_file=files,
data_set = ds.MindDataset(dataset_files=files,
num_parallel_workers=num_readers,
shuffle=ds.Shuffle.INFILE)
assert data_set.get_dataset_size() == 52
@@ -2191,7 +2191,7 @@ def test_shuffle_with_global_infile_files(create_multi_mindrecord_files):
# shuffle=Shuffle.FILES
num_readers = 2
data_set = ds.MindDataset(dataset_file=files,
data_set = ds.MindDataset(dataset_files=files,
num_parallel_workers=num_readers,
shuffle=ds.Shuffle.FILES)
assert data_set.get_dataset_size() == 52
@@ -2255,7 +2255,7 @@ def test_distributed_shuffle_with_global_infile_files(create_multi_mindrecord_fi
# no shuffle parameter
num_readers = 2
data_set = ds.MindDataset(dataset_file=files,
data_set = ds.MindDataset(dataset_files=files,
num_parallel_workers=num_readers,
num_shards=4,
shard_id=3)
@@ -2271,7 +2271,7 @@ def test_distributed_shuffle_with_global_infile_files(create_multi_mindrecord_fi
# shuffle=False
num_readers = 2
data_set = ds.MindDataset(dataset_file=files,
data_set = ds.MindDataset(dataset_files=files,
num_parallel_workers=num_readers,
shuffle=False,
num_shards=4,
@@ -2288,7 +2288,7 @@ def test_distributed_shuffle_with_global_infile_files(create_multi_mindrecord_fi
# shuffle=True
num_readers = 2
data_set = ds.MindDataset(dataset_file=files,
data_set = ds.MindDataset(dataset_files=files,
num_parallel_workers=num_readers,
shuffle=True,
num_shards=4,
@@ -2305,7 +2305,7 @@ def test_distributed_shuffle_with_global_infile_files(create_multi_mindrecord_fi
# shuffle=Shuffle.GLOBAL
num_readers = 2
data_set = ds.MindDataset(dataset_file=files,
data_set = ds.MindDataset(dataset_files=files,
num_parallel_workers=num_readers,
shuffle=ds.Shuffle.GLOBAL,
num_shards=4,
@@ -2324,7 +2324,7 @@ def test_distributed_shuffle_with_global_infile_files(create_multi_mindrecord_fi
output_datas = []
for shard_id in range(4):
num_readers = 2
data_set = ds.MindDataset(dataset_file=files,
data_set = ds.MindDataset(dataset_files=files,
num_parallel_workers=num_readers,
shuffle=ds.Shuffle.INFILE,
num_shards=4,
@@ -2383,7 +2383,7 @@ def test_distributed_shuffle_with_global_infile_files(create_multi_mindrecord_fi
data_list = []
for shard_id in range(4):
num_readers = 2
data_set = ds.MindDataset(dataset_file=files,
data_set = ds.MindDataset(dataset_files=files,
num_parallel_workers=num_readers,
shuffle=ds.Shuffle.FILES,
num_shards=4,
@@ -2450,7 +2450,7 @@ def test_distributed_shuffle_with_multi_epochs(create_multi_mindrecord_files):
# no shuffle parameter
for shard_id in range(4):
num_readers = 2
data_set = ds.MindDataset(dataset_file=files,
data_set = ds.MindDataset(dataset_files=files,
num_parallel_workers=num_readers,
num_shards=4,
shard_id=shard_id)
@@ -2472,7 +2472,7 @@ def test_distributed_shuffle_with_multi_epochs(create_multi_mindrecord_files):
# shuffle=False
for shard_id in range(4):
num_readers = 2
data_set = ds.MindDataset(dataset_file=files,
data_set = ds.MindDataset(dataset_files=files,
num_parallel_workers=num_readers,
shuffle=False,
num_shards=4,
@@ -2493,7 +2493,7 @@ def test_distributed_shuffle_with_multi_epochs(create_multi_mindrecord_files):
# shuffle=True
for shard_id in range(4):
num_readers = 2
data_set = ds.MindDataset(dataset_file=files,
data_set = ds.MindDataset(dataset_files=files,
num_parallel_workers=num_readers,
shuffle=True,
num_shards=4,
@@ -2516,7 +2516,7 @@ def test_distributed_shuffle_with_multi_epochs(create_multi_mindrecord_files):
# shuffle=Shuffle.GLOBAL
for shard_id in range(4):
num_readers = 2
data_set = ds.MindDataset(dataset_file=files,
data_set = ds.MindDataset(dataset_files=files,
num_parallel_workers=num_readers,
shuffle=ds.Shuffle.GLOBAL,
num_shards=4,
@@ -2539,7 +2539,7 @@ def test_distributed_shuffle_with_multi_epochs(create_multi_mindrecord_files):
# shuffle=Shuffle.INFILE
for shard_id in range(4):
num_readers = 2
data_set = ds.MindDataset(dataset_file=files,
data_set = ds.MindDataset(dataset_files=files,
num_parallel_workers=num_readers,
shuffle=ds.Shuffle.INFILE,
num_shards=4,
@@ -2565,7 +2565,7 @@ def test_distributed_shuffle_with_multi_epochs(create_multi_mindrecord_files):
datas_epoch3 = []
for shard_id in range(4):
num_readers = 2
data_set = ds.MindDataset(dataset_file=files,
data_set = ds.MindDataset(dataset_files=files,
num_parallel_workers=num_readers,
shuffle=ds.Shuffle.FILES,
num_shards=4,
@@ -2628,7 +2628,7 @@ def test_field_is_null_numpy():
writer.write_raw_data(data)
writer.commit()
data_set = ds.MindDataset(dataset_file=file_name + "0",
data_set = ds.MindDataset(dataset_files=file_name + "0",
columns_list=["label", "array_a", "array_b", "array_d"],
num_parallel_workers=2,
shuffle=False)

@@ -94,7 +94,7 @@ def test_case_00():
new_data['image5'] = np.asarray(list(item["image5"]), dtype=np.uint8)
data_value_to_list.append(new_data)
d2 = ds.MindDataset(dataset_file=file_name_auto,
d2 = ds.MindDataset(dataset_files=file_name_auto,
num_parallel_workers=num_readers,
shuffle=False)
assert d2.get_dataset_size() == 5
@@ -143,7 +143,7 @@ def test_case_00():
new_data['label'] = np.asarray(list([item["label"]]), dtype=np.int32)
data_value_to_list.append(new_data)
d2 = ds.MindDataset(dataset_file=file_name_auto,
d2 = ds.MindDataset(dataset_files=file_name_auto,
num_parallel_workers=num_readers,
shuffle=False)
assert d2.get_dataset_size() == 6
@@ -291,7 +291,7 @@ def test_case_02(): # muti-bytes
new_data['image5'] = np.asarray(list(item["image5"]), dtype=np.uint8)
data_value_to_list.append(new_data)
d2 = ds.MindDataset(dataset_file=file_name_auto,
d2 = ds.MindDataset(dataset_files=file_name_auto,
num_parallel_workers=num_readers,
shuffle=False)
assert d2.get_dataset_size() == 6
@@ -333,7 +333,7 @@ def test_case_03():
d1.save(file_name_auto)
d2 = ds.MindDataset(dataset_file=file_name_auto,
d2 = ds.MindDataset(dataset_files=file_name_auto,
num_parallel_workers=num_readers,
shuffle=False)
@@ -366,7 +366,7 @@ def type_tester(t):
data1.save(file_name_auto)
d2 = ds.MindDataset(dataset_file=file_name_auto,
d2 = ds.MindDataset(dataset_files=file_name_auto,
num_parallel_workers=num_readers,
shuffle=False)
@@ -446,7 +446,7 @@ def test_case_07():
for x in d1.create_dict_iterator(num_epochs=1, output_numpy=True):
tf_data.append(x)
d1.save(file_name_auto, FILES_NUM)
d2 = ds.MindDataset(dataset_file=file_name_auto,
d2 = ds.MindDataset(dataset_files=file_name_auto,
num_parallel_workers=num_readers,
shuffle=False)
mr_data = []
@@ -503,7 +503,7 @@ def test_case_08():
d1.save(file_name_auto)
d2 = ds.MindDataset(dataset_file=file_name_auto,
d2 = ds.MindDataset(dataset_files=file_name_auto,
num_parallel_workers=num_readers,
shuffle=False)
@@ -532,7 +532,7 @@ def test_case_09():
d1.save(file_name_auto)
d2 = ds.MindDataset(dataset_file=file_name_auto,
d2 = ds.MindDataset(dataset_files=file_name_auto,
num_parallel_workers=num_readers,
shuffle=False)