add function

pr comment fix

revert graphengine
This commit is contained in:
tony_liu2 2020-08-14 14:15:51 -04:00
parent 2cd99c2829
commit 70bfd506a1
3 changed files with 45 additions and 19 deletions

View File

@ -24,6 +24,7 @@ import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as c_vision
import mindspore.dataset.transforms.vision.py_transforms as py_vision
from mindspore import log as logger
from util import dataset_equal
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
@ -139,8 +140,7 @@ def test_deterministic_run_fail():
data2 = data2.map(input_columns=["image"], operations=random_crop_op)
try:
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
np.testing.assert_equal(item1["image"], item2["image"])
dataset_equal(data1, data2, 0)
except Exception as e:
# two datasets split the number out of the sequence a
@ -181,8 +181,7 @@ def test_seed_undeterministic():
random_crop_op2 = c_vision.RandomCrop([512, 512], [200, 200, 200, 200])
data2 = data2.map(input_columns=["image"], operations=random_crop_op2)
try:
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
np.testing.assert_equal(item1["image"], item2["image"])
dataset_equal(data1, data2, 0)
except Exception as e:
# two datasets both use numbers from the generated sequence "a"
logger.info("Got an exception in DE: {}".format(str(e)))
@ -221,8 +220,7 @@ def test_seed_deterministic():
random_crop_op2 = c_vision.RandomCrop([512, 512], [200, 200, 200, 200])
data2 = data2.map(input_columns=["image"], operations=random_crop_op2)
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
np.testing.assert_equal(item1["image"], item2["image"])
dataset_equal(data1, data2, 0)
# Restore original configuration values
ds.config.set_num_parallel_workers(num_parallel_workers_original)
@ -257,8 +255,7 @@ def test_deterministic_run_distribution():
random_horizontal_flip_op2 = c_vision.RandomHorizontalFlip(0.1)
data2 = data2.map(input_columns=["image"], operations=random_horizontal_flip_op2)
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
np.testing.assert_equal(item1["image"], item2["image"])
dataset_equal(data1, data2, 0)
# Restore original configuration values
ds.config.set_num_parallel_workers(num_parallel_workers_original)

View File

@ -21,7 +21,7 @@ import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as data_trans
import mindspore.dataset.transforms.vision.c_transforms as c_vision
from mindspore import log as logger
from util import diff_mse
from util import dataset_equal_with_function
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
@ -52,16 +52,7 @@ def test_one_hot():
# Second dataset
data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["label"], shuffle=False)
num_iter = 0
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
assert len(item1) == len(item2)
label1 = item1["label"]
label2 = one_hot(item2["label"][0], depth)
mse = diff_mse(label1, label2)
logger.info("DE one_hot: {}, Numpy one_hot: {}, diff: {}".format(label1, label2, mse))
assert mse == 0
num_iter += 1
assert num_iter == 3
assert dataset_equal_with_function(data1, data2, 0, one_hot, depth)
def test_one_hot_post_aug():
"""

View File

@ -16,6 +16,7 @@
import hashlib
import json
import os
import itertools
from enum import Enum
import matplotlib.pyplot as plt
import matplotlib.patches as patches
@ -397,3 +398,40 @@ def check_bad_bbox(data, test_op, invalid_bbox_type, expected_error):
except RuntimeError as error:
logger.info("Got an exception in DE: {}".format(str(error)))
assert expected_error in str(error)
#return true if datasets are equal
def dataset_equal(data1, data2, mse_threshold):
if data1.get_dataset_size() != data2.get_dataset_size():
return False
equal = True
for item1, item2 in itertools.zip_longest(data1, data2):
for column1, column2 in itertools.zip_longest(item1, item2):
mse = diff_mse(column1, column2)
if mse > mse_threshold:
equal = False
break
if not equal:
break
return equal
# return true if datasets are equal after modification to target
# params: data_unchanged - dataset kept unchanged
# data_target - dataset to be modified by foo
# mse_threshold - maximum allowable value of mse
# foo - function applied to data_target columns BEFORE compare
# foo_args - arguments passed into foo
def dataset_equal_with_function(data_unchanged, data_target, mse_threshold, foo, *foo_args):
if data_unchanged.get_dataset_size() != data_target.get_dataset_size():
return False
equal = True
for item1, item2 in itertools.zip_longest(data_unchanged, data_target):
for column1, column2 in itertools.zip_longest(item1, item2):
# note the function is to be applied to the second dataset
column2 = foo(column2, *foo_args)
mse = diff_mse(column1, column2)
if mse > mse_threshold:
equal = False
break
if not equal:
break
return equal