From f891e1755cf9ea1506add6a2026956eb0fc97daa Mon Sep 17 00:00:00 2001
From: Cathy Wong
Date: Mon, 25 May 2020 15:10:03 -0400
Subject: [PATCH] Cleanup dataset UT: restore config support

---
 tests/ut/python/dataset/test_config.py      | 70 +++++++++++++++++++
 .../dataset/test_datasets_textfileop.py     |  6 +-
 tests/ut/python/dataset/test_split.py       |  7 +-
 tests/ut/python/dataset/util.py             | 27 ++++++-
 4 files changed, 106 insertions(+), 4 deletions(-)

diff --git a/tests/ut/python/dataset/test_config.py b/tests/ut/python/dataset/test_config.py
index f17b830edd3..c5413c07db9 100644
--- a/tests/ut/python/dataset/test_config.py
+++ b/tests/ut/python/dataset/test_config.py
@@ -30,6 +30,14 @@ SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
 
 
 def test_basic():
+    """
+    Test basic configuration functions
+    """
+    # Save original configuration values
+    num_parallel_workers_original = ds.config.get_num_parallel_workers()
+    prefetch_size_original = ds.config.get_prefetch_size()
+    seed_original = ds.config.get_seed()
+
     ds.config.load('../data/dataset/declient.cfg')
 
     # assert ds.config.get_rows_per_buffer() == 32
@@ -50,6 +58,11 @@ def test_basic():
     assert ds.config.get_prefetch_size() == 4
     assert ds.config.get_seed() == 5
 
+    # Restore original configuration values
+    ds.config.set_num_parallel_workers(num_parallel_workers_original)
+    ds.config.set_prefetch_size(prefetch_size_original)
+    ds.config.set_seed(seed_original)
+
 
 def test_get_seed():
     """
@@ -62,6 +75,9 @@ def test_pipeline():
     """
     Test that our configuration pipeline works when we set parameters at different locations in dataset code
     """
+    # Save original configuration values
+    num_parallel_workers_original = ds.config.get_num_parallel_workers()
+
     data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
     ds.config.set_num_parallel_workers(2)
     data1 = data1.map(input_columns=["image"], operations=[vision.Decode(True)])
@@ -85,6 +101,9 @@
     except IOError:
         logger.info("Error while deleting: {}".format(f))
 
+    # Restore original configuration values
+    ds.config.set_num_parallel_workers(num_parallel_workers_original)
+
 
 def test_deterministic_run_fail():
     """
@@ -92,6 +111,10 @@
     """
     logger.info("test_deterministic_run_fail")
 
+    # Save original configuration values
+    num_parallel_workers_original = ds.config.get_num_parallel_workers()
+    seed_original = ds.config.get_seed()
+
     # when we set the seed all operations within our dataset should be deterministic
     ds.config.set_seed(0)
     ds.config.set_num_parallel_workers(1)
@@ -120,12 +143,21 @@
         logger.info("Got an exception in DE: {}".format(str(e)))
         assert "Array" in str(e)
 
+    # Restore original configuration values
+    ds.config.set_num_parallel_workers(num_parallel_workers_original)
+    ds.config.set_seed(seed_original)
+
 
 def test_deterministic_run_pass():
     """
     Test deterministic run with setting the seed
     """
     logger.info("test_deterministic_run_pass")
+
+    # Save original configuration values
+    num_parallel_workers_original = ds.config.get_num_parallel_workers()
+    seed_original = ds.config.get_seed()
+
     ds.config.set_seed(0)
     ds.config.set_num_parallel_workers(1)
 
@@ -152,13 +184,23 @@
         logger.info("Got an exception in DE: {}".format(str(e)))
         assert "Array" in str(e)
 
+    # Restore original configuration values
+    ds.config.set_num_parallel_workers(num_parallel_workers_original)
+    ds.config.set_seed(seed_original)
+
 
 def test_seed_undeterministic():
     """
     Test seed with num parallel workers in c, this test is expected to fail some of the time
     """
     logger.info("test_seed_undeterministic")
+
+    # Save original configuration values
+    num_parallel_workers_original = ds.config.get_num_parallel_workers()
+    seed_original = ds.config.get_seed()
+
     ds.config.set_seed(0)
+    ds.config.set_num_parallel_workers(1)
 
     # First dataset
     data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
@@ -178,6 +220,10 @@
     for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
         np.testing.assert_equal(item1["image"], item2["image"])
 
+    # Restore original configuration values
+    ds.config.set_num_parallel_workers(num_parallel_workers_original)
+    ds.config.set_seed(seed_original)
+
 
 def test_deterministic_run_distribution():
     """
@@ -185,6 +231,10 @@
     """
     logger.info("test_deterministic_run_distribution")
 
+    # Save original configuration values
+    num_parallel_workers_original = ds.config.get_num_parallel_workers()
+    seed_original = ds.config.get_seed()
+
     # when we set the seed all operations within our dataset should be deterministic
     ds.config.set_seed(0)
     ds.config.set_num_parallel_workers(1)
@@ -206,12 +256,21 @@
     for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
         np.testing.assert_equal(item1["image"], item2["image"])
 
+    # Restore original configuration values
+    ds.config.set_num_parallel_workers(num_parallel_workers_original)
+    ds.config.set_seed(seed_original)
+
 
 def test_deterministic_python_seed():
     """
     Test deterministic execution with seed in python
     """
     logger.info("deterministic_random_crop_op_python_2")
+
+    # Save original configuration values
+    num_parallel_workers_original = ds.config.get_num_parallel_workers()
+    seed_original = ds.config.get_seed()
+
     ds.config.set_seed(0)
     ds.config.set_num_parallel_workers(1)
 
@@ -242,12 +301,20 @@
 
     np.testing.assert_equal(data1_output, data2_output)
 
+    # Restore original configuration values
+    ds.config.set_num_parallel_workers(num_parallel_workers_original)
+    ds.config.set_seed(seed_original)
+
 
 def test_deterministic_python_seed_multi_thread():
     """
     Test deterministic execution with seed in python, this fails with multi-thread pyfunc run
     """
     logger.info("deterministic_random_crop_op_python_2")
+
+    # Save original configuration values
+    seed_original = ds.config.get_seed()
+
     ds.config.set_seed(0)
     # when we set the seed all operations within our dataset should be deterministic
     # First dataset
@@ -282,6 +349,9 @@
         logger.info("Got an exception in DE: {}".format(str(e)))
         assert "Array" in str(e)
 
+    # Restore original configuration values
+    ds.config.set_seed(seed_original)
+
 
 if __name__ == '__main__':
     test_basic()
diff --git a/tests/ut/python/dataset/test_datasets_textfileop.py b/tests/ut/python/dataset/test_datasets_textfileop.py
index d9b5d83a25f..6ba7594799f 100644
--- a/tests/ut/python/dataset/test_datasets_textfileop.py
+++ b/tests/ut/python/dataset/test_datasets_textfileop.py
@@ -14,6 +14,8 @@
 # ==============================================================================
 import mindspore.dataset as ds
 from mindspore import log as logger
+from util import config_get_set_num_parallel_workers
+
 
 DATA_FILE = "../data/dataset/testTextFileDataset/1.txt"
 DATA_ALL_FILE = "../data/dataset/testTextFileDataset/*"
@@ -38,7 +40,7 @@ def test_textline_dataset_all_file():
 
 
 def test_textline_dataset_totext():
-    ds.config.set_num_parallel_workers(4)
+    original_num_parallel_workers = config_get_set_num_parallel_workers(4)
     data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=False)
     count = 0
     line = ["This is a text file.", "Another file.",
@@ -48,6 +50,8 @@ def test_textline_dataset_totext():
         assert (str == line[count])
         count += 1
     assert (count == 5)
+    # Restore configuration num_parallel_workers
+    ds.config.set_num_parallel_workers(original_num_parallel_workers)
 
 
 def test_textline_dataset_num_samples():
diff --git a/tests/ut/python/dataset/test_split.py b/tests/ut/python/dataset/test_split.py
index f2ff8b64971..fa28b491815 100644
--- a/tests/ut/python/dataset/test_split.py
+++ b/tests/ut/python/dataset/test_split.py
@@ -14,6 +14,8 @@
 # ==============================================================================
 import pytest
 import mindspore.dataset as ds
+from util import config_get_set_num_parallel_workers
+
 
 # test5trainimgs.json contains 5 images whose un-decoded shape is [83554, 54214, 65512, 54214, 64631]
 # the label of each image is [0,0,0,1,1] each image can be uniquely identified
@@ -80,7 +82,7 @@ def test_unmappable_split():
     text_file_dataset_path = "../data/dataset/testTextFileDataset/*"
     text_file_data = ["This is a text file.", "Another file.", "Be happy every day.",
                      "End of file.", "Good luck to everyone."]
-    ds.config.set_num_parallel_workers(4)
+    original_num_parallel_workers = config_get_set_num_parallel_workers(4)
 
     d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
     s1, s2 = d.split([4, 1], randomize=False)
@@ -122,6 +124,9 @@
     assert s1_output == text_file_data[0:2]
     assert s2_output == text_file_data[2:]
 
+    # Restore configuration num_parallel_workers
+    ds.config.set_num_parallel_workers(original_num_parallel_workers)
+
 
 def test_mappable_invalid_input():
     d = ds.ManifestDataset(manifest_file)
diff --git a/tests/ut/python/dataset/util.py b/tests/ut/python/dataset/util.py
index 3212a11dc32..feb1e7b4061 100644
--- a/tests/ut/python/dataset/util.py
+++ b/tests/ut/python/dataset/util.py
@@ -15,11 +15,11 @@
 import hashlib
 import json
+import os
 import matplotlib.pyplot as plt
 import numpy as np
-import os
-
 # import jsbeautifier
+import mindspore.dataset as ds
 from mindspore import log as logger
 
 
 # These are the column names defined in the testTFTestAllTypes dataset
@@ -221,3 +221,26 @@
 
     plt.title("Transformed image")
     plt.show()
+
+
+def config_get_set_seed(seed_new):
+    """
+    Get and return the original configuration seed value.
+    Set the new configuration seed value.
+    """
+    seed_original = ds.config.get_seed()
+    ds.config.set_seed(seed_new)
+    logger.info("seed: original = {} new = {} ".format(seed_original, seed_new))
+    return seed_original
+
+
+def config_get_set_num_parallel_workers(num_parallel_workers_new):
+    """
+    Get and return the original configuration num_parallel_workers value.
+    Set the new configuration num_parallel_workers value.
+    """
+    num_parallel_workers_original = ds.config.get_num_parallel_workers()
+    ds.config.set_num_parallel_workers(num_parallel_workers_new)
+    logger.info("num_parallel_workers: original = {} new = {} ".format(num_parallel_workers_original,
+                                                                       num_parallel_workers_new))
+    return num_parallel_workers_original
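
Usage note (illustrative sketch, not part of the patch): the two helpers added to
util.py wrap the save-then-set half of the save/set/restore pattern that the test
changes above apply by hand. A hypothetical test using them could look like the
following; the test name and the elided pipeline body are placeholders, while
config_get_set_seed, config_get_set_num_parallel_workers, and the ds.config
getters/setters are exactly the ones shown in the diff.

    import mindspore.dataset as ds
    from util import config_get_set_seed, config_get_set_num_parallel_workers

    def test_sample_with_config_restore():
        # Save the original global values and apply test-specific settings,
        # one call per configuration knob
        seed_original = config_get_set_seed(0)
        num_parallel_workers_original = config_get_set_num_parallel_workers(1)

        # ... build and iterate the deterministic pipeline under test ...

        # Restore the originals so the global config does not leak into
        # whatever test runs next
        ds.config.set_seed(seed_original)
        ds.config.set_num_parallel_workers(num_parallel_workers_original)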