forked from mindspore-Ecosystem/mindspore
!1140 Cleanup dataset UT: resolve skipped test units
Merge pull request !1140 from cathwong/ckw_dataset_ut_unskip1
This commit is contained in:
commit
1501e20ec2
|
@ -12,10 +12,10 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
from util import save_and_check
|
||||
|
||||
import mindspore.dataset as ds
|
||||
from mindspore import log as logger
|
||||
from util import save_and_check
|
||||
|
||||
|
||||
DATA_DIR = ["../data/dataset/testTFTestAllTypes/test.data"]
|
||||
SCHEMA_DIR = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
|
||||
|
@ -24,7 +24,7 @@ COLUMNS = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float",
|
|||
GENERATE_GOLDEN = False
|
||||
|
||||
|
||||
def skip_test_case_0():
|
||||
def test_2ops_repeat_shuffle():
|
||||
"""
|
||||
Test Repeat then Shuffle
|
||||
"""
|
||||
|
@ -43,11 +43,11 @@ def skip_test_case_0():
|
|||
ds.config.set_seed(seed)
|
||||
data1 = data1.shuffle(buffer_size=buffer_size)
|
||||
|
||||
filename = "test_case_0_result.npz"
|
||||
filename = "test_2ops_repeat_shuffle.npz"
|
||||
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
|
||||
|
||||
|
||||
def skip_test_case_0_reverse():
|
||||
def skip_test_2ops_shuffle_repeat():
|
||||
"""
|
||||
Test Shuffle then Repeat
|
||||
"""
|
||||
|
@ -67,11 +67,11 @@ def skip_test_case_0_reverse():
|
|||
data1 = data1.shuffle(buffer_size=buffer_size)
|
||||
data1 = data1.repeat(repeat_count)
|
||||
|
||||
filename = "test_case_0_reverse_result.npz"
|
||||
filename = "test_2ops_shuffle_repeat.npz"
|
||||
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
|
||||
|
||||
|
||||
def test_case_1():
|
||||
def test_2ops_repeat_batch():
|
||||
"""
|
||||
Test Repeat then Batch
|
||||
"""
|
||||
|
@ -87,11 +87,11 @@ def test_case_1():
|
|||
data1 = data1.repeat(repeat_count)
|
||||
data1 = data1.batch(batch_size, drop_remainder=True)
|
||||
|
||||
filename = "test_case_1_result.npz"
|
||||
filename = "test_2ops_repeat_batch.npz"
|
||||
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
|
||||
|
||||
|
||||
def test_case_1_reverse():
|
||||
def test_2ops_batch_repeat():
|
||||
"""
|
||||
Test Batch then Repeat
|
||||
"""
|
||||
|
@ -107,11 +107,11 @@ def test_case_1_reverse():
|
|||
data1 = data1.batch(batch_size, drop_remainder=True)
|
||||
data1 = data1.repeat(repeat_count)
|
||||
|
||||
filename = "test_case_1_reverse_result.npz"
|
||||
filename = "test_2ops_batch_repeat.npz"
|
||||
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
|
||||
|
||||
|
||||
def test_case_2():
|
||||
def test_2ops_batch_shuffle():
|
||||
"""
|
||||
Test Batch then Shuffle
|
||||
"""
|
||||
|
@ -130,11 +130,11 @@ def test_case_2():
|
|||
ds.config.set_seed(seed)
|
||||
data1 = data1.shuffle(buffer_size=buffer_size)
|
||||
|
||||
filename = "test_case_2_result.npz"
|
||||
filename = "test_2ops_batch_shuffle.npz"
|
||||
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
|
||||
|
||||
|
||||
def test_case_2_reverse():
|
||||
def test_2ops_shuffle_batch():
|
||||
"""
|
||||
Test Shuffle then Batch
|
||||
"""
|
||||
|
@ -153,5 +153,14 @@ def test_case_2_reverse():
|
|||
data1 = data1.shuffle(buffer_size=buffer_size)
|
||||
data1 = data1.batch(batch_size, drop_remainder=True)
|
||||
|
||||
filename = "test_case_2_reverse_result.npz"
|
||||
filename = "test_2ops_shuffle_batch.npz"
|
||||
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_2ops_repeat_shuffle()
|
||||
#test_2ops_shuffle_repeat()
|
||||
test_2ops_repeat_batch()
|
||||
test_2ops_batch_repeat()
|
||||
test_2ops_batch_shuffle()
|
||||
test_2ops_shuffle_batch()
|
||||
|
|
|
@ -12,41 +12,54 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import mindspore.dataset.transforms.vision.c_transforms as vision
|
||||
import pytest
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.vision.c_transforms as vision
|
||||
from mindspore import log as logger
|
||||
|
||||
|
||||
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
|
||||
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
|
||||
|
||||
|
||||
def skip_test_exception():
|
||||
def test_exception_01():
|
||||
"""
|
||||
Test single exception with invalid input
|
||||
"""
|
||||
logger.info("test_exception_01")
|
||||
ds.config.set_num_parallel_workers(1)
|
||||
data = ds.TFRecordDataset(DATA_DIR, columns_list=["image"])
|
||||
data = data.map(input_columns=["image"], operations=vision.Resize(100, 100))
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
data.create_tuple_iterator().get_next()
|
||||
assert "The shape size 1 of input tensor is invalid" in str(info.value)
|
||||
with pytest.raises(ValueError) as info:
|
||||
data = data.map(input_columns=["image"], operations=vision.Resize(100, 100))
|
||||
assert "Invalid interpolation mode." in str(info.value)
|
||||
|
||||
|
||||
def test_sample_exception():
|
||||
def test_exception_02():
|
||||
"""
|
||||
Test multiple exceptions with invalid input
|
||||
"""
|
||||
logger.info("test_exception_02")
|
||||
num_samples = 0
|
||||
with pytest.raises(ValueError) as info:
|
||||
data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples)
|
||||
assert "num_samples must be greater than 0" in str(info.value)
|
||||
|
||||
num_samples = -1
|
||||
with pytest.raises(ValueError) as info:
|
||||
data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples)
|
||||
assert "num_samples must be greater than 0" in str(info.value)
|
||||
|
||||
num_samples = 1
|
||||
data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples)
|
||||
data = data.map(input_columns=["image"], operations=vision.Decode())
|
||||
data = data.map(input_columns=["image"], operations=vision.Resize((100, 100)))
|
||||
# Confirm 1 sample in dataset
|
||||
assert sum([1 for _ in data]) == 1
|
||||
num_iters = 0
|
||||
for item in data.create_dict_iterator():
|
||||
for _ in data.create_dict_iterator():
|
||||
num_iters += 1
|
||||
assert num_iters == 1
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_exception()
|
||||
test_exception_01()
|
||||
test_exception_02()
|
||||
|
|
|
@ -261,11 +261,18 @@ def test_case_invalid_files():
|
|||
|
||||
if __name__ == '__main__':
|
||||
test_case_tf_shape()
|
||||
test_case_tf_read_all_dataset()
|
||||
test_case_num_samples()
|
||||
test_case_num_samples2()
|
||||
test_case_tf_shape_2()
|
||||
test_case_tf_file()
|
||||
test_case_tf_file_no_schema()
|
||||
test_case_tf_file_pad()
|
||||
test_tf_files()
|
||||
test_tf_record_schema()
|
||||
test_tf_record_shuffle()
|
||||
#test_tf_record_shard()
|
||||
test_tf_shard_equal_rows()
|
||||
test_case_tf_file_no_schema_columns_list()
|
||||
test_tf_record_schema_columns_list()
|
||||
test_case_invalid_files()
|
||||
|
|
|
@ -12,10 +12,11 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
from util import save_and_check_dict, save_and_check_md5
|
||||
from mindspore import log as logger
|
||||
|
||||
import mindspore.dataset as ds
|
||||
from mindspore import log as logger
|
||||
from util import save_and_check_dict, save_and_check_md5
|
||||
|
||||
|
||||
|
||||
# Dataset in DIR_1 has 5 rows and 5 columns
|
||||
DATA_DIR_1 = ["../data/dataset/testTFBert5Rows1/5TFDatas.data"]
|
||||
|
@ -147,7 +148,7 @@ def test_zip_exception_01():
|
|||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
|
||||
|
||||
def skip_test_zip_exception_02():
|
||||
def test_zip_exception_02():
|
||||
"""
|
||||
Test zip: zip datasets with duplicate column name
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue