forked from mindspore-Ecosystem/mindspore
Cleanup dataset UT: restore config support
This commit is contained in:
parent
1fdb3aea5a
commit
f891e1755c
|
@ -30,6 +30,14 @@ SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
|
|||
|
||||
|
||||
def test_basic():
|
||||
"""
|
||||
Test basic configuration functions
|
||||
"""
|
||||
# Save original configuration values
|
||||
num_parallel_workers_original = ds.config.get_num_parallel_workers()
|
||||
prefetch_size_original = ds.config.get_prefetch_size()
|
||||
seed_original = ds.config.get_seed()
|
||||
|
||||
ds.config.load('../data/dataset/declient.cfg')
|
||||
|
||||
# assert ds.config.get_rows_per_buffer() == 32
|
||||
|
@ -50,6 +58,11 @@ def test_basic():
|
|||
assert ds.config.get_prefetch_size() == 4
|
||||
assert ds.config.get_seed() == 5
|
||||
|
||||
# Restore original configuration values
|
||||
ds.config.set_num_parallel_workers(num_parallel_workers_original)
|
||||
ds.config.set_prefetch_size(prefetch_size_original)
|
||||
ds.config.set_seed(seed_original)
|
||||
|
||||
|
||||
def test_get_seed():
|
||||
"""
|
||||
|
@ -62,6 +75,9 @@ def test_pipeline():
|
|||
"""
|
||||
Test that our configuration pipeline works when we set parameters at different locations in dataset code
|
||||
"""
|
||||
# Save original configuration values
|
||||
num_parallel_workers_original = ds.config.get_num_parallel_workers()
|
||||
|
||||
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
|
||||
ds.config.set_num_parallel_workers(2)
|
||||
data1 = data1.map(input_columns=["image"], operations=[vision.Decode(True)])
|
||||
|
@ -85,6 +101,9 @@ def test_pipeline():
|
|||
except IOError:
|
||||
logger.info("Error while deleting: {}".format(f))
|
||||
|
||||
# Restore original configuration values
|
||||
ds.config.set_num_parallel_workers(num_parallel_workers_original)
|
||||
|
||||
|
||||
def test_deterministic_run_fail():
|
||||
"""
|
||||
|
@ -92,6 +111,10 @@ def test_deterministic_run_fail():
|
|||
"""
|
||||
logger.info("test_deterministic_run_fail")
|
||||
|
||||
# Save original configuration values
|
||||
num_parallel_workers_original = ds.config.get_num_parallel_workers()
|
||||
seed_original = ds.config.get_seed()
|
||||
|
||||
# when we set the seed all operations within our dataset should be deterministic
|
||||
ds.config.set_seed(0)
|
||||
ds.config.set_num_parallel_workers(1)
|
||||
|
@ -120,12 +143,21 @@ def test_deterministic_run_fail():
|
|||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "Array" in str(e)
|
||||
|
||||
# Restore original configuration values
|
||||
ds.config.set_num_parallel_workers(num_parallel_workers_original)
|
||||
ds.config.set_seed(seed_original)
|
||||
|
||||
|
||||
def test_deterministic_run_pass():
|
||||
"""
|
||||
Test deterministic run with with setting the seed
|
||||
"""
|
||||
logger.info("test_deterministic_run_pass")
|
||||
|
||||
# Save original configuration values
|
||||
num_parallel_workers_original = ds.config.get_num_parallel_workers()
|
||||
seed_original = ds.config.get_seed()
|
||||
|
||||
ds.config.set_seed(0)
|
||||
ds.config.set_num_parallel_workers(1)
|
||||
|
||||
|
@ -152,13 +184,23 @@ def test_deterministic_run_pass():
|
|||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "Array" in str(e)
|
||||
|
||||
# Restore original configuration values
|
||||
ds.config.set_num_parallel_workers(num_parallel_workers_original)
|
||||
ds.config.set_seed(seed_original)
|
||||
|
||||
|
||||
def test_seed_undeterministic():
|
||||
"""
|
||||
Test seed with num parallel workers in c, this test is expected to fail some of the time
|
||||
"""
|
||||
logger.info("test_seed_undeterministic")
|
||||
|
||||
# Save original configuration values
|
||||
num_parallel_workers_original = ds.config.get_num_parallel_workers()
|
||||
seed_original = ds.config.get_seed()
|
||||
|
||||
ds.config.set_seed(0)
|
||||
ds.config.set_num_parallel_workers(1)
|
||||
|
||||
# First dataset
|
||||
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
|
||||
|
@ -178,6 +220,10 @@ def test_seed_undeterministic():
|
|||
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
|
||||
np.testing.assert_equal(item1["image"], item2["image"])
|
||||
|
||||
# Restore original configuration values
|
||||
ds.config.set_num_parallel_workers(num_parallel_workers_original)
|
||||
ds.config.set_seed(seed_original)
|
||||
|
||||
|
||||
def test_deterministic_run_distribution():
|
||||
"""
|
||||
|
@ -185,6 +231,10 @@ def test_deterministic_run_distribution():
|
|||
"""
|
||||
logger.info("test_deterministic_run_distribution")
|
||||
|
||||
# Save original configuration values
|
||||
num_parallel_workers_original = ds.config.get_num_parallel_workers()
|
||||
seed_original = ds.config.get_seed()
|
||||
|
||||
# when we set the seed all operations within our dataset should be deterministic
|
||||
ds.config.set_seed(0)
|
||||
ds.config.set_num_parallel_workers(1)
|
||||
|
@ -206,12 +256,21 @@ def test_deterministic_run_distribution():
|
|||
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
|
||||
np.testing.assert_equal(item1["image"], item2["image"])
|
||||
|
||||
# Restore original configuration values
|
||||
ds.config.set_num_parallel_workers(num_parallel_workers_original)
|
||||
ds.config.set_seed(seed_original)
|
||||
|
||||
|
||||
def test_deterministic_python_seed():
|
||||
"""
|
||||
Test deterministic execution with seed in python
|
||||
"""
|
||||
logger.info("deterministic_random_crop_op_python_2")
|
||||
|
||||
# Save original configuration values
|
||||
num_parallel_workers_original = ds.config.get_num_parallel_workers()
|
||||
seed_original = ds.config.get_seed()
|
||||
|
||||
ds.config.set_seed(0)
|
||||
ds.config.set_num_parallel_workers(1)
|
||||
|
||||
|
@ -242,12 +301,20 @@ def test_deterministic_python_seed():
|
|||
|
||||
np.testing.assert_equal(data1_output, data2_output)
|
||||
|
||||
# Restore original configuration values
|
||||
ds.config.set_num_parallel_workers(num_parallel_workers_original)
|
||||
ds.config.set_seed(seed_original)
|
||||
|
||||
|
||||
def test_deterministic_python_seed_multi_thread():
|
||||
"""
|
||||
Test deterministic execution with seed in python, this fails with multi-thread pyfunc run
|
||||
"""
|
||||
logger.info("deterministic_random_crop_op_python_2")
|
||||
|
||||
# Save original configuration values
|
||||
seed_original = ds.config.get_seed()
|
||||
|
||||
ds.config.set_seed(0)
|
||||
# when we set the seed all operations within our dataset should be deterministic
|
||||
# First dataset
|
||||
|
@ -282,6 +349,9 @@ def test_deterministic_python_seed_multi_thread():
|
|||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "Array" in str(e)
|
||||
|
||||
# Restore original configuration values
|
||||
ds.config.set_seed(seed_original)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_basic()
|
||||
|
|
|
@ -14,6 +14,8 @@
|
|||
# ==============================================================================
|
||||
import mindspore.dataset as ds
|
||||
from mindspore import log as logger
|
||||
from util import config_get_set_num_parallel_workers
|
||||
|
||||
|
||||
DATA_FILE = "../data/dataset/testTextFileDataset/1.txt"
|
||||
DATA_ALL_FILE = "../data/dataset/testTextFileDataset/*"
|
||||
|
@ -38,7 +40,7 @@ def test_textline_dataset_all_file():
|
|||
|
||||
|
||||
def test_textline_dataset_totext():
|
||||
ds.config.set_num_parallel_workers(4)
|
||||
original_num_parallel_workers = config_get_set_num_parallel_workers(4)
|
||||
data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=False)
|
||||
count = 0
|
||||
line = ["This is a text file.", "Another file.",
|
||||
|
@ -48,6 +50,8 @@ def test_textline_dataset_totext():
|
|||
assert (str == line[count])
|
||||
count += 1
|
||||
assert (count == 5)
|
||||
# Restore configuration num_parallel_workers
|
||||
ds.config.set_num_parallel_workers(original_num_parallel_workers)
|
||||
|
||||
|
||||
def test_textline_dataset_num_samples():
|
||||
|
|
|
@ -14,6 +14,8 @@
|
|||
# ==============================================================================
|
||||
import pytest
|
||||
import mindspore.dataset as ds
|
||||
from util import config_get_set_num_parallel_workers
|
||||
|
||||
|
||||
# test5trainimgs.json contains 5 images whose un-decoded shape is [83554, 54214, 65512, 54214, 64631]
|
||||
# the label of each image is [0,0,0,1,1] each image can be uniquely identified
|
||||
|
@ -80,7 +82,7 @@ def test_unmappable_split():
|
|||
text_file_dataset_path = "../data/dataset/testTextFileDataset/*"
|
||||
text_file_data = ["This is a text file.", "Another file.", "Be happy every day.",
|
||||
"End of file.", "Good luck to everyone."]
|
||||
ds.config.set_num_parallel_workers(4)
|
||||
original_num_parallel_workers = config_get_set_num_parallel_workers(4)
|
||||
d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
|
||||
s1, s2 = d.split([4, 1], randomize=False)
|
||||
|
||||
|
@ -122,6 +124,9 @@ def test_unmappable_split():
|
|||
|
||||
assert s1_output == text_file_data[0:2]
|
||||
assert s2_output == text_file_data[2:]
|
||||
# Restore configuration num_parallel_workers
|
||||
ds.config.set_num_parallel_workers(original_num_parallel_workers)
|
||||
|
||||
|
||||
def test_mappable_invalid_input():
|
||||
d = ds.ManifestDataset(manifest_file)
|
||||
|
|
|
@ -15,11 +15,11 @@
|
|||
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
# import jsbeautifier
|
||||
import mindspore.dataset as ds
|
||||
from mindspore import log as logger
|
||||
|
||||
# These are the column names defined in the testTFTestAllTypes dataset
|
||||
|
@ -221,3 +221,26 @@ def visualize(image_original, image_transformed):
|
|||
plt.title("Transformed image")
|
||||
|
||||
plt.show()
|
||||
|
||||
|
||||
def config_get_set_seed(seed_new):
    """
    Swap the dataset configuration seed for a new value.

    Records the seed currently configured, installs seed_new in its
    place, and hands the previous value back so the caller can restore
    it when the test finishes.
    """
    previous_seed = ds.config.get_seed()
    ds.config.set_seed(seed_new)
    logger.info("seed: original = {} new = {} ".format(previous_seed, seed_new))
    return previous_seed
|
||||
|
||||
|
||||
def config_get_set_num_parallel_workers(num_parallel_workers_new):
    """
    Swap the dataset configuration num_parallel_workers for a new value.

    Records the worker count currently configured, installs
    num_parallel_workers_new in its place, and hands the previous value
    back so the caller can restore it when the test finishes.
    """
    previous_workers = ds.config.get_num_parallel_workers()
    ds.config.set_num_parallel_workers(num_parallel_workers_new)
    logger.info(
        "num_parallel_workers: original = {} new = {} ".format(previous_workers,
                                                               num_parallel_workers_new))
    return previous_workers
|
||||
|
|
Loading…
Reference in New Issue