diff --git a/tests/ut/python/dataset/test_config.py b/tests/ut/python/dataset/test_config.py index 6783eea2fdc..8a76df559b6 100644 --- a/tests/ut/python/dataset/test_config.py +++ b/tests/ut/python/dataset/test_config.py @@ -24,6 +24,7 @@ import mindspore.dataset as ds import mindspore.dataset.transforms.vision.c_transforms as c_vision import mindspore.dataset.transforms.vision.py_transforms as py_vision from mindspore import log as logger +from util import dataset_equal DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" @@ -139,8 +140,7 @@ def test_deterministic_run_fail(): data2 = data2.map(input_columns=["image"], operations=random_crop_op) try: - for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()): - np.testing.assert_equal(item1["image"], item2["image"]) + dataset_equal(data1, data2, 0) except Exception as e: # two datasets split the number out of the sequence a @@ -181,8 +181,7 @@ def test_seed_undeterministic(): random_crop_op2 = c_vision.RandomCrop([512, 512], [200, 200, 200, 200]) data2 = data2.map(input_columns=["image"], operations=random_crop_op2) try: - for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()): - np.testing.assert_equal(item1["image"], item2["image"]) + dataset_equal(data1, data2, 0) except Exception as e: # two datasets both use numbers from the generated sequence "a" logger.info("Got an exception in DE: {}".format(str(e))) @@ -221,8 +220,7 @@ def test_seed_deterministic(): random_crop_op2 = c_vision.RandomCrop([512, 512], [200, 200, 200, 200]) data2 = data2.map(input_columns=["image"], operations=random_crop_op2) - for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()): - np.testing.assert_equal(item1["image"], item2["image"]) + dataset_equal(data1, data2, 0) # Restore original configuration values 
ds.config.set_num_parallel_workers(num_parallel_workers_original) @@ -257,8 +255,7 @@ def test_deterministic_run_distribution(): random_horizontal_flip_op2 = c_vision.RandomHorizontalFlip(0.1) data2 = data2.map(input_columns=["image"], operations=random_horizontal_flip_op2) - for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()): - np.testing.assert_equal(item1["image"], item2["image"]) + dataset_equal(data1, data2, 0) # Restore original configuration values ds.config.set_num_parallel_workers(num_parallel_workers_original) diff --git a/tests/ut/python/dataset/test_onehot_op.py b/tests/ut/python/dataset/test_onehot_op.py index 44d98b0ae0a..9020663b06c 100644 --- a/tests/ut/python/dataset/test_onehot_op.py +++ b/tests/ut/python/dataset/test_onehot_op.py @@ -21,7 +21,7 @@ import mindspore.dataset as ds import mindspore.dataset.transforms.c_transforms as data_trans import mindspore.dataset.transforms.vision.c_transforms as c_vision from mindspore import log as logger -from util import diff_mse +from util import dataset_equal_with_function DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" @@ -52,16 +52,7 @@ def test_one_hot(): # Second dataset data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["label"], shuffle=False) - num_iter = 0 - for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()): - assert len(item1) == len(item2) - label1 = item1["label"] - label2 = one_hot(item2["label"][0], depth) - mse = diff_mse(label1, label2) - logger.info("DE one_hot: {}, Numpy one_hot: {}, diff: {}".format(label1, label2, mse)) - assert mse == 0 - num_iter += 1 - assert num_iter == 3 + assert dataset_equal_with_function(data1, data2, 0, one_hot, depth) def test_one_hot_post_aug(): """ diff --git a/tests/ut/python/dataset/util.py b/tests/ut/python/dataset/util.py index 74009dbd053..533b353d838 100644 --- 
a/tests/ut/python/dataset/util.py
+++ b/tests/ut/python/dataset/util.py
@@ -16,6 +16,7 @@
 import hashlib
 import json
 import os
+import itertools
 from enum import Enum
 import matplotlib.pyplot as plt
 import matplotlib.patches as patches
@@ -397,3 +398,53 @@ def check_bad_bbox(data, test_op, invalid_bbox_type, expected_error):
     except RuntimeError as error:
         logger.info("Got an exception in DE: {}".format(str(error)))
         assert expected_error in str(error)
+
+
+def dataset_equal(data1, data2, mse_threshold):
+    """
+    Check whether two datasets hold the same data, column by column.
+
+    The verdict is returned, not asserted, so callers have to check the
+    returned value themselves (e.g. with ``assert``).
+
+    Args:
+        data1: First dataset to iterate.
+        data2: Second dataset to iterate.
+        mse_threshold: Largest diff_mse value still treated as equal.
+
+    Returns:
+        bool, True if every pair of columns is within mse_threshold.
+    """
+    return dataset_equal_with_function(data1, data2, mse_threshold, lambda column: column)
+
+
+def dataset_equal_with_function(data_unchanged, data_target, mse_threshold, foo, *foo_args):
+    """
+    Check whether two datasets are equal once foo is applied to the second.
+
+    Args:
+        data_unchanged: Dataset whose columns are compared as they are.
+        data_target: Dataset whose columns are passed through foo first.
+        mse_threshold: Largest diff_mse value still treated as equal.
+        foo: Function applied to each data_target column BEFORE the compare.
+        foo_args: Extra positional arguments passed to foo after the column.
+
+    Returns:
+        bool, True if every pair of columns is within mse_threshold.
+    """
+    if data_unchanged.get_dataset_size() != data_target.get_dataset_size():
+        return False
+    # zip_longest pads the shorter side; a private sentinel makes that
+    # padding detectable, so a row or column count mismatch is reported as
+    # "not equal" instead of crashing on a None.
+    missing = object()
+    for item1, item2 in itertools.zip_longest(data_unchanged, data_target, fillvalue=missing):
+        if item1 is missing or item2 is missing:
+            return False
+        for column1, column2 in itertools.zip_longest(item1, item2, fillvalue=missing):
+            if column1 is missing or column2 is missing:
+                return False
+            # note the function is to be applied to the second dataset
+            if diff_mse(column1, foo(column2, *foo_args)) > mse_threshold:
+                return False
+    return True