forked from mindspore-Ecosystem/mindspore
!4477 add dataset compare function to utils
Merge pull request !4477 from tony_liu2/staging
This commit is contained in:
commit
5453b40311
|
@ -24,6 +24,7 @@ import mindspore.dataset as ds
|
|||
import mindspore.dataset.transforms.vision.c_transforms as c_vision
|
||||
import mindspore.dataset.transforms.vision.py_transforms as py_vision
|
||||
from mindspore import log as logger
|
||||
from util import dataset_equal
|
||||
|
||||
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
|
||||
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
|
||||
|
@ -139,8 +140,7 @@ def test_deterministic_run_fail():
|
|||
data2 = data2.map(input_columns=["image"], operations=random_crop_op)
|
||||
|
||||
try:
|
||||
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
|
||||
np.testing.assert_equal(item1["image"], item2["image"])
|
||||
dataset_equal(data1, data2, 0)
|
||||
|
||||
except Exception as e:
|
||||
# two datasets split the number out of the sequence a
|
||||
|
@ -181,8 +181,7 @@ def test_seed_undeterministic():
|
|||
random_crop_op2 = c_vision.RandomCrop([512, 512], [200, 200, 200, 200])
|
||||
data2 = data2.map(input_columns=["image"], operations=random_crop_op2)
|
||||
try:
|
||||
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
|
||||
np.testing.assert_equal(item1["image"], item2["image"])
|
||||
dataset_equal(data1, data2, 0)
|
||||
except Exception as e:
|
||||
# two datasets both use numbers from the generated sequence "a"
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
|
@ -221,8 +220,7 @@ def test_seed_deterministic():
|
|||
random_crop_op2 = c_vision.RandomCrop([512, 512], [200, 200, 200, 200])
|
||||
data2 = data2.map(input_columns=["image"], operations=random_crop_op2)
|
||||
|
||||
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
|
||||
np.testing.assert_equal(item1["image"], item2["image"])
|
||||
dataset_equal(data1, data2, 0)
|
||||
|
||||
# Restore original configuration values
|
||||
ds.config.set_num_parallel_workers(num_parallel_workers_original)
|
||||
|
@ -257,8 +255,7 @@ def test_deterministic_run_distribution():
|
|||
random_horizontal_flip_op2 = c_vision.RandomHorizontalFlip(0.1)
|
||||
data2 = data2.map(input_columns=["image"], operations=random_horizontal_flip_op2)
|
||||
|
||||
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
|
||||
np.testing.assert_equal(item1["image"], item2["image"])
|
||||
dataset_equal(data1, data2, 0)
|
||||
|
||||
# Restore original configuration values
|
||||
ds.config.set_num_parallel_workers(num_parallel_workers_original)
|
||||
|
|
|
@ -21,7 +21,7 @@ import mindspore.dataset as ds
|
|||
import mindspore.dataset.transforms.c_transforms as data_trans
|
||||
import mindspore.dataset.transforms.vision.c_transforms as c_vision
|
||||
from mindspore import log as logger
|
||||
from util import diff_mse
|
||||
from util import dataset_equal_with_function
|
||||
|
||||
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
|
||||
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
|
||||
|
@ -52,16 +52,7 @@ def test_one_hot():
|
|||
# Second dataset
|
||||
data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["label"], shuffle=False)
|
||||
|
||||
num_iter = 0
|
||||
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
|
||||
assert len(item1) == len(item2)
|
||||
label1 = item1["label"]
|
||||
label2 = one_hot(item2["label"][0], depth)
|
||||
mse = diff_mse(label1, label2)
|
||||
logger.info("DE one_hot: {}, Numpy one_hot: {}, diff: {}".format(label1, label2, mse))
|
||||
assert mse == 0
|
||||
num_iter += 1
|
||||
assert num_iter == 3
|
||||
assert dataset_equal_with_function(data1, data2, 0, one_hot, depth)
|
||||
|
||||
def test_one_hot_post_aug():
|
||||
"""
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import itertools
|
||||
from enum import Enum
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.patches as patches
|
||||
|
@ -397,3 +398,40 @@ def check_bad_bbox(data, test_op, invalid_bbox_type, expected_error):
|
|||
except RuntimeError as error:
|
||||
logger.info("Got an exception in DE: {}".format(str(error)))
|
||||
assert expected_error in str(error)
|
||||
|
||||
#return true if datasets are equal
|
||||
def dataset_equal(data1, data2, mse_threshold):
|
||||
if data1.get_dataset_size() != data2.get_dataset_size():
|
||||
return False
|
||||
equal = True
|
||||
for item1, item2 in itertools.zip_longest(data1, data2):
|
||||
for column1, column2 in itertools.zip_longest(item1, item2):
|
||||
mse = diff_mse(column1, column2)
|
||||
if mse > mse_threshold:
|
||||
equal = False
|
||||
break
|
||||
if not equal:
|
||||
break
|
||||
return equal
|
||||
|
||||
# return true if datasets are equal after modification to target
|
||||
# params: data_unchanged - dataset kept unchanged
|
||||
# data_target - dataset to be modified by foo
|
||||
# mse_threshold - maximum allowable value of mse
|
||||
# foo - function applied to data_target columns BEFORE compare
|
||||
# foo_args - arguments passed into foo
|
||||
def dataset_equal_with_function(data_unchanged, data_target, mse_threshold, foo, *foo_args):
|
||||
if data_unchanged.get_dataset_size() != data_target.get_dataset_size():
|
||||
return False
|
||||
equal = True
|
||||
for item1, item2 in itertools.zip_longest(data_unchanged, data_target):
|
||||
for column1, column2 in itertools.zip_longest(item1, item2):
|
||||
# note the function is to be applied to the second dataset
|
||||
column2 = foo(column2, *foo_args)
|
||||
mse = diff_mse(column1, column2)
|
||||
if mse > mse_threshold:
|
||||
equal = False
|
||||
break
|
||||
if not equal:
|
||||
break
|
||||
return equal
|
||||
|
|
Loading…
Reference in New Issue