!21755 Fix batch usability problem

Merge pull request !21755 from xiefangqi/md_fix_batch_usability
This commit is contained in:
i-robot 2021-08-13 07:21:49 +00:00 committed by Gitee
commit f99effd430
2 changed files with 21 additions and 0 deletions

View File

@ -2151,6 +2151,9 @@ class BatchDataset(Dataset):
Per iterator bootstrap callback.
"""
if self.python_multiprocessing:
if self.per_batch_map is None:
logger.warning("per_batch_map is None so python_multiprocessing does not work.")
return
arg_q_list = []
res_q_list = []

View File

@ -238,6 +238,23 @@ def test_batch_12():
save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
def test_batch_13():
"""
Test batch: python_multiprocessing is True and does not work for per_batch_map is None
"""
logger.info("test_batch_12")
# define parameters
batch_size = True
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
data1 = data1.batch(batch_size=batch_size, python_multiprocessing=True)
assert sum([1 for _ in data1]) == 12
filename = "batch_12_result.npz"
save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
def test_batch_exception_01():
"""
Test batch exception: num_parallel_workers=0
@ -493,6 +510,7 @@ if __name__ == '__main__':
test_batch_10()
test_batch_11()
test_batch_12()
test_batch_13()
test_batch_exception_01()
test_batch_exception_02()
test_batch_exception_03()