forked from mindspore-Ecosystem/mindspore
!4477 add dataset compare function to utils
Merge pull request !4477 from tony_liu2/staging
This commit is contained in:
commit
5453b40311
|
@ -24,6 +24,7 @@ import mindspore.dataset as ds
|
||||||
import mindspore.dataset.transforms.vision.c_transforms as c_vision
|
import mindspore.dataset.transforms.vision.c_transforms as c_vision
|
||||||
import mindspore.dataset.transforms.vision.py_transforms as py_vision
|
import mindspore.dataset.transforms.vision.py_transforms as py_vision
|
||||||
from mindspore import log as logger
|
from mindspore import log as logger
|
||||||
|
from util import dataset_equal
|
||||||
|
|
||||||
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
|
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
|
||||||
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
|
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
|
||||||
|
@ -139,8 +140,7 @@ def test_deterministic_run_fail():
|
||||||
data2 = data2.map(input_columns=["image"], operations=random_crop_op)
|
data2 = data2.map(input_columns=["image"], operations=random_crop_op)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
|
dataset_equal(data1, data2, 0)
|
||||||
np.testing.assert_equal(item1["image"], item2["image"])
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# two datasets split the number out of the sequence a
|
# two datasets split the number out of the sequence a
|
||||||
|
@ -181,8 +181,7 @@ def test_seed_undeterministic():
|
||||||
random_crop_op2 = c_vision.RandomCrop([512, 512], [200, 200, 200, 200])
|
random_crop_op2 = c_vision.RandomCrop([512, 512], [200, 200, 200, 200])
|
||||||
data2 = data2.map(input_columns=["image"], operations=random_crop_op2)
|
data2 = data2.map(input_columns=["image"], operations=random_crop_op2)
|
||||||
try:
|
try:
|
||||||
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
|
dataset_equal(data1, data2, 0)
|
||||||
np.testing.assert_equal(item1["image"], item2["image"])
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# two datasets both use numbers from the generated sequence "a"
|
# two datasets both use numbers from the generated sequence "a"
|
||||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||||
|
@ -221,8 +220,7 @@ def test_seed_deterministic():
|
||||||
random_crop_op2 = c_vision.RandomCrop([512, 512], [200, 200, 200, 200])
|
random_crop_op2 = c_vision.RandomCrop([512, 512], [200, 200, 200, 200])
|
||||||
data2 = data2.map(input_columns=["image"], operations=random_crop_op2)
|
data2 = data2.map(input_columns=["image"], operations=random_crop_op2)
|
||||||
|
|
||||||
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
|
dataset_equal(data1, data2, 0)
|
||||||
np.testing.assert_equal(item1["image"], item2["image"])
|
|
||||||
|
|
||||||
# Restore original configuration values
|
# Restore original configuration values
|
||||||
ds.config.set_num_parallel_workers(num_parallel_workers_original)
|
ds.config.set_num_parallel_workers(num_parallel_workers_original)
|
||||||
|
@ -257,8 +255,7 @@ def test_deterministic_run_distribution():
|
||||||
random_horizontal_flip_op2 = c_vision.RandomHorizontalFlip(0.1)
|
random_horizontal_flip_op2 = c_vision.RandomHorizontalFlip(0.1)
|
||||||
data2 = data2.map(input_columns=["image"], operations=random_horizontal_flip_op2)
|
data2 = data2.map(input_columns=["image"], operations=random_horizontal_flip_op2)
|
||||||
|
|
||||||
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
|
dataset_equal(data1, data2, 0)
|
||||||
np.testing.assert_equal(item1["image"], item2["image"])
|
|
||||||
|
|
||||||
# Restore original configuration values
|
# Restore original configuration values
|
||||||
ds.config.set_num_parallel_workers(num_parallel_workers_original)
|
ds.config.set_num_parallel_workers(num_parallel_workers_original)
|
||||||
|
|
|
@ -21,7 +21,7 @@ import mindspore.dataset as ds
|
||||||
import mindspore.dataset.transforms.c_transforms as data_trans
|
import mindspore.dataset.transforms.c_transforms as data_trans
|
||||||
import mindspore.dataset.transforms.vision.c_transforms as c_vision
|
import mindspore.dataset.transforms.vision.c_transforms as c_vision
|
||||||
from mindspore import log as logger
|
from mindspore import log as logger
|
||||||
from util import diff_mse
|
from util import dataset_equal_with_function
|
||||||
|
|
||||||
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
|
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
|
||||||
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
|
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
|
||||||
|
@ -52,16 +52,7 @@ def test_one_hot():
|
||||||
# Second dataset
|
# Second dataset
|
||||||
data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["label"], shuffle=False)
|
data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["label"], shuffle=False)
|
||||||
|
|
||||||
num_iter = 0
|
assert dataset_equal_with_function(data1, data2, 0, one_hot, depth)
|
||||||
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
|
|
||||||
assert len(item1) == len(item2)
|
|
||||||
label1 = item1["label"]
|
|
||||||
label2 = one_hot(item2["label"][0], depth)
|
|
||||||
mse = diff_mse(label1, label2)
|
|
||||||
logger.info("DE one_hot: {}, Numpy one_hot: {}, diff: {}".format(label1, label2, mse))
|
|
||||||
assert mse == 0
|
|
||||||
num_iter += 1
|
|
||||||
assert num_iter == 3
|
|
||||||
|
|
||||||
def test_one_hot_post_aug():
|
def test_one_hot_post_aug():
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import itertools
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import matplotlib.patches as patches
|
import matplotlib.patches as patches
|
||||||
|
@ -397,3 +398,40 @@ def check_bad_bbox(data, test_op, invalid_bbox_type, expected_error):
|
||||||
except RuntimeError as error:
|
except RuntimeError as error:
|
||||||
logger.info("Got an exception in DE: {}".format(str(error)))
|
logger.info("Got an exception in DE: {}".format(str(error)))
|
||||||
assert expected_error in str(error)
|
assert expected_error in str(error)
|
||||||
|
|
||||||
|
#return true if datasets are equal
|
||||||
|
def dataset_equal(data1, data2, mse_threshold):
|
||||||
|
if data1.get_dataset_size() != data2.get_dataset_size():
|
||||||
|
return False
|
||||||
|
equal = True
|
||||||
|
for item1, item2 in itertools.zip_longest(data1, data2):
|
||||||
|
for column1, column2 in itertools.zip_longest(item1, item2):
|
||||||
|
mse = diff_mse(column1, column2)
|
||||||
|
if mse > mse_threshold:
|
||||||
|
equal = False
|
||||||
|
break
|
||||||
|
if not equal:
|
||||||
|
break
|
||||||
|
return equal
|
||||||
|
|
||||||
|
# return true if datasets are equal after modification to target
|
||||||
|
# params: data_unchanged - dataset kept unchanged
|
||||||
|
# data_target - dataset to be modified by foo
|
||||||
|
# mse_threshold - maximum allowable value of mse
|
||||||
|
# foo - function applied to data_target columns BEFORE compare
|
||||||
|
# foo_args - arguments passed into foo
|
||||||
|
def dataset_equal_with_function(data_unchanged, data_target, mse_threshold, foo, *foo_args):
|
||||||
|
if data_unchanged.get_dataset_size() != data_target.get_dataset_size():
|
||||||
|
return False
|
||||||
|
equal = True
|
||||||
|
for item1, item2 in itertools.zip_longest(data_unchanged, data_target):
|
||||||
|
for column1, column2 in itertools.zip_longest(item1, item2):
|
||||||
|
# note the function is to be applied to the second dataset
|
||||||
|
column2 = foo(column2, *foo_args)
|
||||||
|
mse = diff_mse(column1, column2)
|
||||||
|
if mse > mse_threshold:
|
||||||
|
equal = False
|
||||||
|
break
|
||||||
|
if not equal:
|
||||||
|
break
|
||||||
|
return equal
|
||||||
|
|
Loading…
Reference in New Issue