Cleanup dataset UT: restore config support

This commit is contained in:
Cathy Wong 2020-05-25 15:10:03 -04:00
parent 1fdb3aea5a
commit f891e1755c
4 changed files with 106 additions and 4 deletions

View File

@ -30,6 +30,14 @@ SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
def test_basic():
"""
Test basic configuration functions
"""
# Save original configuration values
num_parallel_workers_original = ds.config.get_num_parallel_workers()
prefetch_size_original = ds.config.get_prefetch_size()
seed_original = ds.config.get_seed()
ds.config.load('../data/dataset/declient.cfg')
# assert ds.config.get_rows_per_buffer() == 32
@ -50,6 +58,11 @@ def test_basic():
assert ds.config.get_prefetch_size() == 4
assert ds.config.get_seed() == 5
# Restore original configuration values
ds.config.set_num_parallel_workers(num_parallel_workers_original)
ds.config.set_prefetch_size(prefetch_size_original)
ds.config.set_seed(seed_original)
def test_get_seed():
"""
@ -62,6 +75,9 @@ def test_pipeline():
"""
Test that our configuration pipeline works when we set parameters at different locations in dataset code
"""
# Save original configuration values
num_parallel_workers_original = ds.config.get_num_parallel_workers()
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
ds.config.set_num_parallel_workers(2)
data1 = data1.map(input_columns=["image"], operations=[vision.Decode(True)])
@ -85,6 +101,9 @@ def test_pipeline():
except IOError:
logger.info("Error while deleting: {}".format(f))
# Restore original configuration values
ds.config.set_num_parallel_workers(num_parallel_workers_original)
def test_deterministic_run_fail():
"""
@ -92,6 +111,10 @@ def test_deterministic_run_fail():
"""
logger.info("test_deterministic_run_fail")
# Save original configuration values
num_parallel_workers_original = ds.config.get_num_parallel_workers()
seed_original = ds.config.get_seed()
# when we set the seed all operations within our dataset should be deterministic
ds.config.set_seed(0)
ds.config.set_num_parallel_workers(1)
@ -120,12 +143,21 @@ def test_deterministic_run_fail():
logger.info("Got an exception in DE: {}".format(str(e)))
assert "Array" in str(e)
# Restore original configuration values
ds.config.set_num_parallel_workers(num_parallel_workers_original)
ds.config.set_seed(seed_original)
def test_deterministic_run_pass():
"""
Test deterministic run with with setting the seed
"""
logger.info("test_deterministic_run_pass")
# Save original configuration values
num_parallel_workers_original = ds.config.get_num_parallel_workers()
seed_original = ds.config.get_seed()
ds.config.set_seed(0)
ds.config.set_num_parallel_workers(1)
@ -152,13 +184,23 @@ def test_deterministic_run_pass():
logger.info("Got an exception in DE: {}".format(str(e)))
assert "Array" in str(e)
# Restore original configuration values
ds.config.set_num_parallel_workers(num_parallel_workers_original)
ds.config.set_seed(seed_original)
def test_seed_undeterministic():
"""
Test seed with num parallel workers in c, this test is expected to fail some of the time
"""
logger.info("test_seed_undeterministic")
# Save original configuration values
num_parallel_workers_original = ds.config.get_num_parallel_workers()
seed_original = ds.config.get_seed()
ds.config.set_seed(0)
ds.config.set_num_parallel_workers(1)
# First dataset
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
@ -178,6 +220,10 @@ def test_seed_undeterministic():
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
np.testing.assert_equal(item1["image"], item2["image"])
# Restore original configuration values
ds.config.set_num_parallel_workers(num_parallel_workers_original)
ds.config.set_seed(seed_original)
def test_deterministic_run_distribution():
"""
@ -185,6 +231,10 @@ def test_deterministic_run_distribution():
"""
logger.info("test_deterministic_run_distribution")
# Save original configuration values
num_parallel_workers_original = ds.config.get_num_parallel_workers()
seed_original = ds.config.get_seed()
# when we set the seed all operations within our dataset should be deterministic
ds.config.set_seed(0)
ds.config.set_num_parallel_workers(1)
@ -206,12 +256,21 @@ def test_deterministic_run_distribution():
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
np.testing.assert_equal(item1["image"], item2["image"])
# Restore original configuration values
ds.config.set_num_parallel_workers(num_parallel_workers_original)
ds.config.set_seed(seed_original)
def test_deterministic_python_seed():
"""
Test deterministic execution with seed in python
"""
logger.info("deterministic_random_crop_op_python_2")
# Save original configuration values
num_parallel_workers_original = ds.config.get_num_parallel_workers()
seed_original = ds.config.get_seed()
ds.config.set_seed(0)
ds.config.set_num_parallel_workers(1)
@ -242,12 +301,20 @@ def test_deterministic_python_seed():
np.testing.assert_equal(data1_output, data2_output)
# Restore original configuration values
ds.config.set_num_parallel_workers(num_parallel_workers_original)
ds.config.set_seed(seed_original)
def test_deterministic_python_seed_multi_thread():
"""
Test deterministic execution with seed in python, this fails with multi-thread pyfunc run
"""
logger.info("deterministic_random_crop_op_python_2")
# Save original configuration values
seed_original = ds.config.get_seed()
ds.config.set_seed(0)
# when we set the seed all operations within our dataset should be deterministic
# First dataset
@ -282,6 +349,9 @@ def test_deterministic_python_seed_multi_thread():
logger.info("Got an exception in DE: {}".format(str(e)))
assert "Array" in str(e)
# Restore original configuration values
ds.config.set_seed(seed_original)
if __name__ == '__main__':
test_basic()

View File

@ -14,6 +14,8 @@
# ==============================================================================
import mindspore.dataset as ds
from mindspore import log as logger
from util import config_get_set_num_parallel_workers
DATA_FILE = "../data/dataset/testTextFileDataset/1.txt"
DATA_ALL_FILE = "../data/dataset/testTextFileDataset/*"
@ -38,7 +40,7 @@ def test_textline_dataset_all_file():
def test_textline_dataset_totext():
ds.config.set_num_parallel_workers(4)
original_num_parallel_workers = config_get_set_num_parallel_workers(4)
data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=False)
count = 0
line = ["This is a text file.", "Another file.",
@ -48,6 +50,8 @@ def test_textline_dataset_totext():
assert (str == line[count])
count += 1
assert (count == 5)
# Restore configuration num_parallel_workers
ds.config.set_num_parallel_workers(original_num_parallel_workers)
def test_textline_dataset_num_samples():

View File

@ -14,6 +14,8 @@
# ==============================================================================
import pytest
import mindspore.dataset as ds
from util import config_get_set_num_parallel_workers
# test5trainimgs.json contains 5 images whose un-decoded shape is [83554, 54214, 65512, 54214, 64631]
# the label of each image is [0,0,0,1,1] each image can be uniquely identified
@ -80,7 +82,7 @@ def test_unmappable_split():
text_file_dataset_path = "../data/dataset/testTextFileDataset/*"
text_file_data = ["This is a text file.", "Another file.", "Be happy every day.",
"End of file.", "Good luck to everyone."]
ds.config.set_num_parallel_workers(4)
original_num_parallel_workers = config_get_set_num_parallel_workers(4)
d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
s1, s2 = d.split([4, 1], randomize=False)
@ -122,6 +124,9 @@ def test_unmappable_split():
assert s1_output == text_file_data[0:2]
assert s2_output == text_file_data[2:]
# Restore configuration num_parallel_workers
ds.config.set_num_parallel_workers(original_num_parallel_workers)
def test_mappable_invalid_input():
d = ds.ManifestDataset(manifest_file)

View File

@ -15,11 +15,11 @@
import hashlib
import json
import os
import matplotlib.pyplot as plt
import numpy as np
import os
# import jsbeautifier
import mindspore.dataset as ds
from mindspore import log as logger
# These are the column names defined in the testTFTestAllTypes dataset
@ -221,3 +221,26 @@ def visualize(image_original, image_transformed):
plt.title("Transformed image")
plt.show()
def config_get_set_seed(seed_new):
"""
Get and return the original configuration seed value.
Set the new configuration seed value.
"""
seed_original = ds.config.get_seed()
ds.config.set_seed(seed_new)
logger.info("seed: original = {} new = {} ".format(seed_original, seed_new))
return seed_original
def config_get_set_num_parallel_workers(num_parallel_workers_new):
"""
Get and return the original configuration num_parallel_workers value.
Set the new configuration num_parallel_workers value.
"""
num_parallel_workers_original = ds.config.get_num_parallel_workers()
ds.config.set_num_parallel_workers(num_parallel_workers_new)
logger.info("num_parallel_workers: original = {} new = {} ".format(num_parallel_workers_original,
num_parallel_workers_new))
return num_parallel_workers_original