forked from mindspore-Ecosystem/mindspore
!3067 Cleanup dataset UT: Remove unneeded tf data files and tests
Merge pull request !3067 from cathwong/ckw_dataset_ut_cleanup6
commit ba0143402c
@@ -51,7 +51,7 @@ TEST_F(MindDataTestRenameOp, TestRenameOpDefault) {
   auto my_tree = std::make_shared<ExecutionTree>();
   // Creating TFReaderOp

-  std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images_1/train-0000-of-0001.data";
+  std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images/train-0000-of-0001.data";
   std::shared_ptr<TFReaderOp> my_tfreader_op;
   rc = TFReaderOp::Builder()
            .SetDatasetFilesList({dataset_path})
@@ -58,7 +58,7 @@ TEST_F(MindDataTestZipOp, MindDataTestZipOpDefault) {
   auto my_tree = std::make_shared<ExecutionTree>();
   // Creating TFReaderOp

-  std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images_1/train-0000-of-0001.data";
+  std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images/train-0000-of-0001.data";
   std::string dataset_path2 = datasets_root_path_ + "/testBatchDataset/test.data";
   std::shared_ptr<TFReaderOp> my_tfreader_op;
   rc = TFReaderOp::Builder()
@@ -142,7 +142,7 @@ TEST_F(MindDataTestZipOp, MindDataTestZipOpRepeat) {
   MS_LOG(INFO) << "UT test TestZipRepeat.";
   auto my_tree = std::make_shared<ExecutionTree>();

-  std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images_1/train-0000-of-0001.data";
+  std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images/train-0000-of-0001.data";
   std::string dataset_path2 = datasets_root_path_ + "/testBatchDataset/test.data";
   std::shared_ptr<TFReaderOp> my_tfreader_op;
   rc = TFReaderOp::Builder()
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,11 +0,0 @@
-{
-    "datasetType": "TF",
-    "numRows": 3,
-    "columns": {
-        "label": {
-            "type": "int64",
-            "rank": 1,
-            "t_impl": "flex"
-        }
-    }
-}
Binary file not shown.
@@ -1,11 +0,0 @@
-{
-    "datasetType": "TF",
-    "numRows": 3,
-    "columns": {
-        "image": {
-            "type": "uint8",
-            "rank": 1,
-            "t_impl": "cvmat"
-        }
-    }
-}
Binary file not shown.
@@ -1,204 +0,0 @@
-# Copyright 2019 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-import mindspore.dataset as ds
-import mindspore.dataset.transforms.c_transforms as data_trans
-import mindspore.dataset.transforms.vision.c_transforms as vision
-from mindspore import log as logger
-
-DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
-SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
-
-
-def test_case_repeat():
-    """
-    a simple repeat operation.
-    """
-    logger.info("Test Simple Repeat")
-    # define parameters
-    repeat_count = 2
-
-    # apply dataset operations
-    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
-    data1 = data1.repeat(repeat_count)
-
-    num_iter = 0
-    for item in data1.create_dict_iterator():  # each data is a dictionary
-        # in this example, each dictionary has keys "image" and "label"
-        logger.info("image is: {}".format(item["image"]))
-        logger.info("label is: {}".format(item["label"]))
-        num_iter += 1
-
-    logger.info("Number of data in data1: {}".format(num_iter))
-
-
-def test_case_shuffle():
-    """
-    a simple shuffle operation.
-    """
-    logger.info("Test Simple Shuffle")
-    # define parameters
-    buffer_size = 8
-    seed = 10
-
-    # apply dataset operations
-    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
-    ds.config.set_seed(seed)
-    data1 = data1.shuffle(buffer_size=buffer_size)
-
-    for item in data1.create_dict_iterator():
-        logger.info("image is: {}".format(item["image"]))
-        logger.info("label is: {}".format(item["label"]))
-
-
-def test_case_0():
-    """
-    Test Repeat then Shuffle
-    """
-    logger.info("Test Repeat then Shuffle")
-    # define parameters
-    repeat_count = 2
-    buffer_size = 7
-    seed = 9
-
-    # apply dataset operations
-    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
-    data1 = data1.repeat(repeat_count)
-    ds.config.set_seed(seed)
-    data1 = data1.shuffle(buffer_size=buffer_size)
-
-    num_iter = 0
-    for item in data1.create_dict_iterator():  # each data is a dictionary
-        # in this example, each dictionary has keys "image" and "label"
-        logger.info("image is: {}".format(item["image"]))
-        logger.info("label is: {}".format(item["label"]))
-        num_iter += 1
-
-    logger.info("Number of data in data1: {}".format(num_iter))
-
-
-def test_case_0_reverse():
-    """
-    Test Shuffle then Repeat
-    """
-    logger.info("Test Shuffle then Repeat")
-    # define parameters
-    repeat_count = 2
-    buffer_size = 10
-    seed = 9
-
-    # apply dataset operations
-    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
-    ds.config.set_seed(seed)
-    data1 = data1.shuffle(buffer_size=buffer_size)
-    data1 = data1.repeat(repeat_count)
-
-    num_iter = 0
-    for item in data1.create_dict_iterator():  # each data is a dictionary
-        # in this example, each dictionary has keys "image" and "label"
-        logger.info("image is: {}".format(item["image"]))
-        logger.info("label is: {}".format(item["label"]))
-        num_iter += 1
-
-    logger.info("Number of data in data1: {}".format(num_iter))
-
-
-def test_case_3():
-    """
-    Test Map
-    """
-    logger.info("Test Map Rescale and Resize, then Shuffle")
-    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
-    # define data augmentation parameters
-    rescale = 1.0 / 255.0
-    shift = 0.0
-    resize_height, resize_width = 224, 224
-
-    # define map operations
-    decode_op = vision.Decode()
-    rescale_op = vision.Rescale(rescale, shift)
-    # resize_op = vision.Resize(resize_height, resize_width,
-    #             InterpolationMode.DE_INTER_LINEAR)  # Bilinear mode
-    resize_op = vision.Resize((resize_height, resize_width))
-
-    # apply map operations on images
-    data1 = data1.map(input_columns=["image"], operations=decode_op)
-    data1 = data1.map(input_columns=["image"], operations=rescale_op)
-    data1 = data1.map(input_columns=["image"], operations=resize_op)
-
-    # # apply ont-hot encoding on labels
-    num_classes = 4
-    one_hot_encode = data_trans.OneHot(num_classes)  # num_classes is input argument
-    data1 = data1.map(input_columns=["label"], operations=one_hot_encode)
-    #
-    # # apply Datasets
-    buffer_size = 100
-    seed = 10
-    batch_size = 2
-    ds.config.set_seed(seed)
-    data1 = data1.shuffle(buffer_size=buffer_size)  # 10000 as in imageNet train script
-    data1 = data1.batch(batch_size, drop_remainder=True)
-
-    num_iter = 0
-    for item in data1.create_dict_iterator():  # each data is a dictionary
-        # in this example, each dictionary has keys "image" and "label"
-        logger.info("image is: {}".format(item["image"]))
-        logger.info("label is: {}".format(item["label"]))
-        num_iter += 1
-
-    logger.info("Number of data in data1: {}".format(num_iter))
-
-
-if __name__ == '__main__':
-    logger.info('===========now test Repeat============')
-    # logger.info('Simple Repeat')
-    test_case_repeat()
-    logger.info('\n')
-
-    logger.info('===========now test Shuffle===========')
-    # logger.info('Simple Shuffle')
-    test_case_shuffle()
-    logger.info('\n')
-
-    # Note: cannot work with different shapes, hence not for image
-    # logger.info('===========now test Batch=============')
-    # # logger.info('Simple Batch')
-    # test_case_batch()
-    # logger.info('\n')
-
-    logger.info('===========now test case 0============')
-    # logger.info('Repeat then Shuffle')
-    test_case_0()
-    logger.info('\n')
-
-    logger.info('===========now test case 0 reverse============')
-    # # logger.info('Shuffle then Repeat')
-    test_case_0_reverse()
-    logger.info('\n')
-
-    # logger.info('===========now test case 1============')
-    # # logger.info('Repeat with Batch')
-    # test_case_1()
-    # logger.info('\n')
-
-    # logger.info('===========now test case 2============')
-    # # logger.info('Batch with Shuffle')
-    # test_case_2()
-    # logger.info('\n')
-
-    # for image augmentation only
-    logger.info('===========now test case 3============')
-    logger.info('Map then Shuffle')
-    test_case_3()
-    logger.info('\n')
@@ -1,40 +0,0 @@
-# Copyright 2019 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-import mindspore.dataset as ds
-from mindspore import log as logger
-
-DATA_DIR = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data",
-            "../data/dataset/test_tf_file_3_images2/train-0000-of-0002.data",
-            "../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data",
-            "../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"]
-
-SCHEMA_DIR = "../data/dataset/test_tf_file_3_images2/datasetSchema.json"
-
-
-def test_tf_file_normal():
-    # apply dataset operations
-    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
-    data1 = data1.repeat(1)
-    num_iter = 0
-    for _ in data1.create_dict_iterator():  # each data is a dictionary
-        num_iter += 1
-
-    logger.info("Number of data in data1: {}".format(num_iter))
-    assert num_iter == 12
-
-
-if __name__ == '__main__':
-    logger.info('=======test normal=======')
-    test_tf_file_normal()
@@ -13,12 +13,13 @@
 # limitations under the License.
 # ==============================================================================
 """
-Testing the one_hot op in DE
+Testing the OneHot Op
 """
 import numpy as np
 
 import mindspore.dataset as ds
 import mindspore.dataset.transforms.c_transforms as data_trans
+import mindspore.dataset.transforms.vision.c_transforms as c_vision
 from mindspore import log as logger
 from util import diff_mse
 
@@ -37,15 +38,15 @@ def one_hot(index, depth):
 
 def test_one_hot():
     """
-    Test one_hot
+    Test OneHot Tensor Operator
     """
-    logger.info("Test one_hot")
+    logger.info("test_one_hot")
 
     depth = 10
 
     # First dataset
     data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
-    one_hot_op = data_trans.OneHot(depth)
+    one_hot_op = data_trans.OneHot(num_classes=depth)
     data1 = data1.map(input_columns=["label"], operations=one_hot_op, columns_order=["label"])
 
     # Second dataset
@@ -58,8 +59,54 @@ def test_one_hot():
         label2 = one_hot(item2["label"][0], depth)
         mse = diff_mse(label1, label2)
         logger.info("DE one_hot: {}, Numpy one_hot: {}, diff: {}".format(label1, label2, mse))
+        assert mse == 0
         num_iter += 1
+    assert num_iter == 3
+
+
+def test_one_hot_post_aug():
+    """
+    Test One Hot Encoding after Multiple Data Augmentation Operators
+    """
+    logger.info("test_one_hot_post_aug")
+    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
+
+    # Define data augmentation parameters
+    rescale = 1.0 / 255.0
+    shift = 0.0
+    resize_height, resize_width = 224, 224
+
+    # Define map operations
+    decode_op = c_vision.Decode()
+    rescale_op = c_vision.Rescale(rescale, shift)
+    resize_op = c_vision.Resize((resize_height, resize_width))
+
+    # Apply map operations on images
+    data1 = data1.map(input_columns=["image"], operations=decode_op)
+    data1 = data1.map(input_columns=["image"], operations=rescale_op)
+    data1 = data1.map(input_columns=["image"], operations=resize_op)
+
+    # Apply one-hot encoding on labels
+    depth = 4
+    one_hot_encode = data_trans.OneHot(depth)
+    data1 = data1.map(input_columns=["label"], operations=one_hot_encode)
+
+    # Apply datasets ops
+    buffer_size = 100
+    seed = 10
+    batch_size = 2
+    ds.config.set_seed(seed)
+    data1 = data1.shuffle(buffer_size=buffer_size)
+    data1 = data1.batch(batch_size, drop_remainder=True)
+
+    num_iter = 0
+    for item in data1.create_dict_iterator():
+        logger.info("image is: {}".format(item["image"]))
+        logger.info("label is: {}".format(item["label"]))
+        num_iter += 1
+
+    assert num_iter == 1
 
 
 if __name__ == "__main__":
     test_one_hot()
+    test_one_hot_post_aug()
@@ -12,25 +12,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 """
 Test Repeat Op
 """
 import numpy as np
-from util import save_and_check
 
 import mindspore.dataset as ds
 import mindspore.dataset.transforms.vision.c_transforms as vision
 from mindspore import log as logger
+from util import save_and_check_dict
 
 DATA_DIR_TF = ["../data/dataset/testTFTestAllTypes/test.data"]
 SCHEMA_DIR_TF = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
 COLUMNS_TF = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float",
               "col_sint16", "col_sint32", "col_sint64"]
+GENERATE_GOLDEN = False
+
+IMG_DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
+IMG_SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
-
-DATA_DIR_TF2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
-SCHEMA_DIR_TF2 = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
-
-GENERATE_GOLDEN = False
 
 
 def test_tf_repeat_01():
     """
@@ -39,14 +38,13 @@ def test_tf_repeat_01():
     logger.info("Test Simple Repeat")
     # define parameters
     repeat_count = 2
-    parameters = {"params": {'repeat_count': repeat_count}}
 
     # apply dataset operations
     data1 = ds.TFRecordDataset(DATA_DIR_TF, SCHEMA_DIR_TF, shuffle=False)
     data1 = data1.repeat(repeat_count)
 
     filename = "repeat_result.npz"
-    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
+    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
 
 
 def test_tf_repeat_02():
@@ -99,14 +97,13 @@ def test_tf_repeat_04():
     logger.info("Test Simple Repeat Column List")
     # define parameters
     repeat_count = 2
-    parameters = {"params": {'repeat_count': repeat_count}}
     columns_list = ["col_sint64", "col_sint32"]
     # apply dataset operations
     data1 = ds.TFRecordDataset(DATA_DIR_TF, SCHEMA_DIR_TF, columns_list=columns_list, shuffle=False)
     data1 = data1.repeat(repeat_count)
 
     filename = "repeat_list_result.npz"
-    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
+    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
 
 
 def generator():
@@ -115,6 +112,7 @@ def generator():
 
 
 def test_nested_repeat1():
+    logger.info("test_nested_repeat1")
     data = ds.GeneratorDataset(generator, ["data"])
     data = data.repeat(2)
     data = data.repeat(3)
@@ -126,6 +124,7 @@ def test_nested_repeat1():
 
 
 def test_nested_repeat2():
+    logger.info("test_nested_repeat2")
     data = ds.GeneratorDataset(generator, ["data"])
     data = data.repeat(1)
     data = data.repeat(1)
@@ -137,6 +136,7 @@ def test_nested_repeat2():
 
 
 def test_nested_repeat3():
+    logger.info("test_nested_repeat3")
     data = ds.GeneratorDataset(generator, ["data"])
     data = data.repeat(1)
     data = data.repeat(2)
@@ -148,6 +148,7 @@ def test_nested_repeat3():
 
 
 def test_nested_repeat4():
+    logger.info("test_nested_repeat4")
    data = ds.GeneratorDataset(generator, ["data"])
     data = data.repeat(2)
     data = data.repeat(1)
@@ -159,6 +160,7 @@ def test_nested_repeat4():
 
 
 def test_nested_repeat5():
+    logger.info("test_nested_repeat5")
     data = ds.GeneratorDataset(generator, ["data"])
     data = data.batch(3)
     data = data.repeat(2)
@@ -171,6 +173,7 @@ def test_nested_repeat5():
 
 
 def test_nested_repeat6():
+    logger.info("test_nested_repeat6")
     data = ds.GeneratorDataset(generator, ["data"])
     data = data.repeat(2)
     data = data.batch(3)
@@ -183,6 +186,7 @@ def test_nested_repeat6():
 
 
 def test_nested_repeat7():
+    logger.info("test_nested_repeat7")
     data = ds.GeneratorDataset(generator, ["data"])
     data = data.repeat(2)
     data = data.repeat(3)
@@ -195,6 +199,7 @@ def test_nested_repeat7():
 
 
 def test_nested_repeat8():
+    logger.info("test_nested_repeat8")
     data = ds.GeneratorDataset(generator, ["data"])
     data = data.batch(2, drop_remainder=False)
     data = data.repeat(2)
@@ -210,6 +215,7 @@ def test_nested_repeat8():
 
 
 def test_nested_repeat9():
+    logger.info("test_nested_repeat9")
     data = ds.GeneratorDataset(generator, ["data"])
     data = data.repeat()
     data = data.repeat(3)
@@ -221,6 +227,7 @@ def test_nested_repeat9():
 
 
 def test_nested_repeat10():
+    logger.info("test_nested_repeat10")
     data = ds.GeneratorDataset(generator, ["data"])
     data = data.repeat(3)
     data = data.repeat()
@@ -232,6 +239,7 @@ def test_nested_repeat10():
 
 
 def test_nested_repeat11():
+    logger.info("test_nested_repeat11")
     data = ds.GeneratorDataset(generator, ["data"])
     data = data.repeat(2)
     data = data.repeat(3)
@@ -12,21 +12,30 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 """
 Test TFRecordDataset Ops
 """
 import numpy as np
+import pytest
-from util import save_and_check
 
+import mindspore.common.dtype as mstype
 import mindspore.dataset as ds
 from mindspore import log as logger
+from util import save_and_check_dict
 
 FILES = ["../data/dataset/testTFTestAllTypes/test.data"]
 DATASET_ROOT = "../data/dataset/testTFTestAllTypes/"
 SCHEMA_FILE = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
+DATA_FILES2 = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data",
+               "../data/dataset/test_tf_file_3_images2/train-0000-of-0002.data",
+               "../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data",
+               "../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"]
+SCHEMA_FILE2 = "../data/dataset/test_tf_file_3_images2/datasetSchema.json"
 GENERATE_GOLDEN = False
 
 
-def test_case_tf_shape():
+def test_tfrecord_shape():
+    logger.info("test_tfrecord_shape")
     schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaRank0.json"
     ds1 = ds.TFRecordDataset(FILES, schema_file)
     ds1 = ds1.batch(2)
@@ -36,7 +45,8 @@ def test_case_tf_shape():
     assert len(output_shape[-1]) == 1
 
 
-def test_case_tf_read_all_dataset():
+def test_tfrecord_read_all_dataset():
+    logger.info("test_tfrecord_read_all_dataset")
     schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaNoRow.json"
     ds1 = ds.TFRecordDataset(FILES, schema_file)
     assert ds1.get_dataset_size() == 12
@@ -46,7 +56,8 @@ def test_case_tf_read_all_dataset():
     assert count == 12
 
 
-def test_case_num_samples():
+def test_tfrecord_num_samples():
+    logger.info("test_tfrecord_num_samples")
     schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema7Rows.json"
     ds1 = ds.TFRecordDataset(FILES, schema_file, num_samples=8)
     assert ds1.get_dataset_size() == 8
@@ -56,7 +67,8 @@ def test_case_num_samples():
     assert count == 8
 
 
-def test_case_num_samples2():
+def test_tfrecord_num_samples2():
+    logger.info("test_tfrecord_num_samples2")
     schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema7Rows.json"
     ds1 = ds.TFRecordDataset(FILES, schema_file)
     assert ds1.get_dataset_size() == 7
@@ -66,42 +78,41 @@ def test_case_num_samples2():
     assert count == 7
 
 
-def test_case_tf_shape_2():
+def test_tfrecord_shape2():
+    logger.info("test_tfrecord_shape2")
     ds1 = ds.TFRecordDataset(FILES, SCHEMA_FILE)
     ds1 = ds1.batch(2)
     output_shape = ds1.output_shapes()
     assert len(output_shape[-1]) == 2
 
 
-def test_case_tf_file():
-    logger.info("reading data from: {}".format(FILES[0]))
-    parameters = {"params": {}}
+def test_tfrecord_files_basic():
+    logger.info("test_tfrecord_files_basic")
 
     data = ds.TFRecordDataset(FILES, SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
-    filename = "tfreader_result.npz"
-    save_and_check(data, parameters, filename, generate_golden=GENERATE_GOLDEN)
+    filename = "tfrecord_files_basic.npz"
+    save_and_check_dict(data, filename, generate_golden=GENERATE_GOLDEN)
 
 
-def test_case_tf_file_no_schema():
-    logger.info("reading data from: {}".format(FILES[0]))
-    parameters = {"params": {}}
+def test_tfrecord_no_schema():
+    logger.info("test_tfrecord_no_schema")
 
     data = ds.TFRecordDataset(FILES, shuffle=ds.Shuffle.FILES)
-    filename = "tf_file_no_schema.npz"
-    save_and_check(data, parameters, filename, generate_golden=GENERATE_GOLDEN)
+    filename = "tfrecord_no_schema.npz"
+    save_and_check_dict(data, filename, generate_golden=GENERATE_GOLDEN)
 
 
-def test_case_tf_file_pad():
-    logger.info("reading data from: {}".format(FILES[0]))
-    parameters = {"params": {}}
+def test_tfrecord_pad():
+    logger.info("test_tfrecord_pad")
 
     schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaPadBytes10.json"
     data = ds.TFRecordDataset(FILES, schema_file, shuffle=ds.Shuffle.FILES)
-    filename = "tf_file_padBytes10.npz"
-    save_and_check(data, parameters, filename, generate_golden=GENERATE_GOLDEN)
+    filename = "tfrecord_pad_bytes10.npz"
+    save_and_check_dict(data, filename, generate_golden=GENERATE_GOLDEN)
 
 
-def test_tf_files():
+def test_tfrecord_read_files():
+    logger.info("test_tfrecord_read_files")
     pattern = DATASET_ROOT + "/test.data"
     data = ds.TFRecordDataset(pattern, SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
     assert sum([1 for _ in data]) == 12
@@ -123,7 +134,19 @@ def test_tf_files():
     assert sum([1 for _ in data]) == 24
 
 
-def test_tf_record_schema():
+def test_tfrecord_multi_files():
+    logger.info("test_tfrecord_multi_files")
+    data1 = ds.TFRecordDataset(DATA_FILES2, SCHEMA_FILE2, shuffle=False)
+    data1 = data1.repeat(1)
+    num_iter = 0
+    for _ in data1.create_dict_iterator():
+        num_iter += 1
+
+    assert num_iter == 12
+
+
+def test_tfrecord_schema():
+    logger.info("test_tfrecord_schema")
     schema = ds.Schema()
     schema.add_column('col_1d', de_type=mstype.int64, shape=[2])
     schema.add_column('col_2d', de_type=mstype.int64, shape=[2, 2])
@@ -142,7 +165,8 @@ def test_tf_record_schema():
     assert np.array_equal(t1, t2)
 
 
-def test_tf_record_shuffle():
+def test_tfrecord_shuffle():
+    logger.info("test_tfrecord_shuffle")
     ds.config.set_seed(1)
     data1 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.GLOBAL)
     data2 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
@@ -153,7 +177,8 @@ def test_tf_record_shuffle():
     assert np.array_equal(t1, t2)
 
 
-def test_tf_record_shard():
+def test_tfrecord_shard():
+    logger.info("test_tfrecord_shard")
     tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data",
                 "../data/dataset/tf_file_dataset/test3.data", "../data/dataset/tf_file_dataset/test4.data"]
 
@@ -181,7 +206,8 @@ def test_tf_record_shard():
     assert set(worker2_res) == set(worker1_res)
 
 
-def test_tf_shard_equal_rows():
+def test_tfrecord_shard_equal_rows():
+    logger.info("test_tfrecord_shard_equal_rows")
     tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data",
                 "../data/dataset/tf_file_dataset/test3.data", "../data/dataset/tf_file_dataset/test4.data"]
 
@@ -209,7 +235,8 @@ def test_tf_shard_equal_rows():
     assert len(worker4_res) == 40
 
 
-def test_case_tf_file_no_schema_columns_list():
+def test_tfrecord_no_schema_columns_list():
+    logger.info("test_tfrecord_no_schema_columns_list")
     data = ds.TFRecordDataset(FILES, shuffle=False, columns_list=["col_sint16"])
     row = data.create_dict_iterator().get_next()
     assert row["col_sint16"] == [-32768]
@@ -219,7 +246,8 @@ def test_case_tf_file_no_schema_columns_list():
     assert "col_sint32" in str(info.value)
 
 
-def test_tf_record_schema_columns_list():
+def test_tfrecord_schema_columns_list():
+    logger.info("test_tfrecord_schema_columns_list")
     schema = ds.Schema()
     schema.add_column('col_1d', de_type=mstype.int64, shape=[2])
     schema.add_column('col_2d', de_type=mstype.int64, shape=[2, 2])
@@ -238,7 +266,8 @@ def test_tf_record_schema_columns_list():
     assert "col_sint32" in str(info.value)
 
 
-def test_case_invalid_files():
+def test_tfrecord_invalid_files():
+    logger.info("test_tfrecord_invalid_files")
     valid_file = "../data/dataset/testTFTestAllTypes/test.data"
     invalid_file = "../data/dataset/testTFTestAllTypes/invalidFile.txt"
     files = [invalid_file, valid_file, SCHEMA_FILE]
@@ -266,19 +295,20 @@ def test_case_invalid_files():
 
 
 if __name__ == '__main__':
-    test_case_tf_shape()
-    test_case_tf_read_all_dataset()
-    test_case_num_samples()
-    test_case_num_samples2()
-    test_case_tf_shape_2()
-    test_case_tf_file()
-    test_case_tf_file_no_schema()
-    test_case_tf_file_pad()
-    test_tf_files()
-    test_tf_record_schema()
-    test_tf_record_shuffle()
-    test_tf_record_shard()
-    test_tf_shard_equal_rows()
-    test_case_tf_file_no_schema_columns_list()
-    test_tf_record_schema_columns_list()
-    test_case_invalid_files()
+    test_tfrecord_shape()
+    test_tfrecord_read_all_dataset()
+    test_tfrecord_num_samples()
+    test_tfrecord_num_samples2()
+    test_tfrecord_shape2()
+    test_tfrecord_files_basic()
+    test_tfrecord_no_schema()
+    test_tfrecord_pad()
+    test_tfrecord_read_files()
+    test_tfrecord_multi_files()
+    test_tfrecord_schema()
+    test_tfrecord_shuffle()
+    test_tfrecord_shard()
+    test_tfrecord_shard_equal_rows()
+    test_tfrecord_no_schema_columns_list()
+    test_tfrecord_schema_columns_list()
+    test_tfrecord_invalid_files()