forked from mindspore-Ecosystem/mindspore
!21755 Fix batch usability problem
Merge pull request !21755 from xiefangqi/md_fix_batch_usability
This commit is contained in:
commit
f99effd430
|
@ -2151,6 +2151,9 @@ class BatchDataset(Dataset):
|
|||
Per iterator bootstrap callback.
|
||||
"""
|
||||
if self.python_multiprocessing:
|
||||
if self.per_batch_map is None:
|
||||
logger.warning("per_batch_map is None so python_multiprocessing does not work.")
|
||||
return
|
||||
arg_q_list = []
|
||||
res_q_list = []
|
||||
|
||||
|
|
|
@ -238,6 +238,23 @@ def test_batch_12():
|
|||
save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
|
||||
|
||||
|
||||
def test_batch_13():
|
||||
"""
|
||||
Test batch: python_multiprocessing is True and does not work for per_batch_map is None
|
||||
"""
|
||||
logger.info("test_batch_12")
|
||||
# define parameters
|
||||
batch_size = True
|
||||
|
||||
# apply dataset operations
|
||||
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
|
||||
data1 = data1.batch(batch_size=batch_size, python_multiprocessing=True)
|
||||
|
||||
assert sum([1 for _ in data1]) == 12
|
||||
filename = "batch_12_result.npz"
|
||||
save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
|
||||
|
||||
|
||||
def test_batch_exception_01():
|
||||
"""
|
||||
Test batch exception: num_parallel_workers=0
|
||||
|
@ -493,6 +510,7 @@ if __name__ == '__main__':
|
|||
test_batch_10()
|
||||
test_batch_11()
|
||||
test_batch_12()
|
||||
test_batch_13()
|
||||
test_batch_exception_01()
|
||||
test_batch_exception_02()
|
||||
test_batch_exception_03()
|
||||
|
|
Loading…
Reference in New Issue