10:00 26/5 clean pylint

Yang 2020-05-22 14:16:07 +08:00
parent 93fc82b8f7
commit abca62f407
43 changed files with 217 additions and 239 deletions
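
The hunks below repeat a handful of pylint-driven fixes: unused loop variables become "_", comparisons with None use "is"/"is not", superfluous parentheses after assert and return are dropped, unused imports are removed, and import order is normalized. A minimal sketch of the style the files converge on (a hypothetical snippet for illustration only, loosely modeled on the helpers changed below, not taken verbatim from any file):

# Hypothetical illustration of the pylint-clean style this commit applies.
import os


def count_rows(dataset, expected):
    num_iter = 0
    # was: for item in dataset:   (loop variable was never used)
    for _ in dataset:
        num_iter += 1
    # was: assert (num_iter == expected)   (superfluous parentheses)
    assert num_iter == expected
    return num_iter


def make_cache_dir(target_directory=None):
    # was: if target_directory == None:   (compare with None using "is")
    if target_directory is None:
        target_directory = os.path.join(os.getcwd(), "data_cache")
    if not os.path.exists(target_directory):
        os.mkdir(target_directory)
    # was: return target_directory;   (stray semicolon)
    return target_directory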


@@ -34,7 +34,7 @@ def use_filereader(mindrecord):
 num_consumer=4,
 columns=columns_list)
 num_iter = 0
-for index, item in enumerate(reader.get_next()):
+for _, _ in enumerate(reader.get_next()):
 num_iter += 1
 print_log(num_iter)
 end = time.time()
@@ -48,7 +48,7 @@ def use_minddataset(mindrecord):
 columns_list=columns_list,
 num_parallel_workers=4)
 num_iter = 0
-for item in data_set.create_dict_iterator():
+for _ in data_set.create_dict_iterator():
 num_iter += 1
 print_log(num_iter)
 end = time.time()
@@ -64,7 +64,7 @@ def use_tfrecorddataset(tfrecord):
 shuffle=ds.Shuffle.GLOBAL)
 data_set = data_set.shuffle(10000)
 num_iter = 0
-for item in data_set.create_dict_iterator():
+for _ in data_set.create_dict_iterator():
 num_iter += 1
 print_log(num_iter)
 end = time.time()
@@ -87,7 +87,7 @@ def use_tensorflow_tfrecorddataset(tfrecord):
 num_parallel_reads=4)
 data_set = data_set.map(_parse_record, num_parallel_calls=4)
 num_iter = 0
-for item in data_set.__iter__():
+for _ in data_set.__iter__():
 num_iter += 1
 print_log(num_iter)
 end = time.time()
@@ -96,18 +96,18 @@ def use_tensorflow_tfrecorddataset(tfrecord):
 if __name__ == '__main__':
 # use MindDataset
-mindrecord = './imagenet.mindrecord00'
-use_minddataset(mindrecord)
+mindrecord_test = './imagenet.mindrecord00'
+use_minddataset(mindrecord_test)
 # use TFRecordDataset
-tfrecord = ['imagenet.tfrecord00', 'imagenet.tfrecord01', 'imagenet.tfrecord02', 'imagenet.tfrecord03',
+tfrecord_test = ['imagenet.tfrecord00', 'imagenet.tfrecord01', 'imagenet.tfrecord02', 'imagenet.tfrecord03',
 'imagenet.tfrecord04', 'imagenet.tfrecord05', 'imagenet.tfrecord06', 'imagenet.tfrecord07',
 'imagenet.tfrecord08', 'imagenet.tfrecord09', 'imagenet.tfrecord10', 'imagenet.tfrecord11',
 'imagenet.tfrecord12', 'imagenet.tfrecord13', 'imagenet.tfrecord14', 'imagenet.tfrecord15']
-use_tfrecorddataset(tfrecord)
+use_tfrecorddataset(tfrecord_test)
 # use TensorFlow TFRecordDataset
-use_tensorflow_tfrecorddataset(tfrecord)
+use_tensorflow_tfrecorddataset(tfrecord_test)
 # use FileReader
 # use_filereader(mindrecord)


@@ -29,7 +29,7 @@ def test_case_0():
 # apply dataset operations
 ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
-ds1 = ds1.map(input_column_names=col, output_column_names="out", operation=(lambda x: x + x))
+ds1 = ds1.map(input_columns=col, output_columns="out", operations=(lambda x: x + x))
 print("************** Output Tensor *****************")
 for data in ds1.create_dict_iterator(): # each data is a dictionary
@@ -49,7 +49,7 @@ def test_case_1():
 # apply dataset operations
 ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
-ds1 = ds1.map(input_column_names=col, output_column_names=["out0", "out1"], operation=(lambda x: (x, x + x)))
+ds1 = ds1.map(input_columns=col, output_columns=["out0", "out1"], operations=(lambda x: (x, x + x)))
 print("************** Output Tensor *****************")
 for data in ds1.create_dict_iterator(): # each data is a dictionary
@@ -72,7 +72,7 @@ def test_case_2():
 # apply dataset operations
 ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
-ds1 = ds1.map(input_column_names=col, output_column_names="out", operation=(lambda x, y: x + y))
+ds1 = ds1.map(input_columns=col, output_columns="out", operations=(lambda x, y: x + y))
 print("************** Output Tensor *****************")
 for data in ds1.create_dict_iterator(): # each data is a dictionary
@@ -93,8 +93,8 @@ def test_case_3():
 # apply dataset operations
 ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
-ds1 = ds1.map(input_column_names=col, output_column_names=["out0", "out1", "out2"],
-operation=(lambda x, y: (x, x + y, x + x + y)))
+ds1 = ds1.map(input_columns=col, output_columns=["out0", "out1", "out2"],
+operations=(lambda x, y: (x, x + y, x + x + y)))
 print("************** Output Tensor *****************")
 for data in ds1.create_dict_iterator(): # each data is a dictionary
@@ -119,8 +119,8 @@ def test_case_4():
 # apply dataset operations
 ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
-ds1 = ds1.map(input_column_names=col, output_column_names=["out0", "out1", "out2"], num_parallel_workers=4,
-operation=(lambda x, y: (x, x + y, x + x + y)))
+ds1 = ds1.map(input_columns=col, output_columns=["out0", "out1", "out2"], num_parallel_workers=4,
+operations=(lambda x, y: (x, x + y, x + x + y)))
 print("************** Output Tensor *****************")
 for data in ds1.create_dict_iterator(): # each data is a dictionary


@@ -22,11 +22,11 @@ def create_data_cache_dir():
 cwd = os.getcwd()
 target_directory = os.path.join(cwd, "data_cache")
 try:
-if not (os.path.exists(target_directory)):
+if not os.path.exists(target_directory):
 os.mkdir(target_directory)
 except OSError:
 print("Creation of the directory %s failed" % target_directory)
-return target_directory;
+return target_directory
 def download_and_uncompress(files, source_url, target_directory, is_tar=False):
@@ -53,13 +53,13 @@ def download_and_uncompress(files, source_url, target_directory, is_tar=False):
 def download_mnist(target_directory=None):
-if target_directory == None:
+if target_directory is None:
 target_directory = create_data_cache_dir()
 ##create mnst directory
 target_directory = os.path.join(target_directory, "mnist")
 try:
-if not (os.path.exists(target_directory)):
+if not os.path.exists(target_directory):
 os.mkdir(target_directory)
 except OSError:
 print("Creation of the directory %s failed" % target_directory)
@@ -78,14 +78,14 @@ CIFAR_URL = "https://www.cs.toronto.edu/~kriz/"
 def download_cifar(target_directory, files, directory_from_tar):
-if target_directory == None:
+if target_directory is None:
 target_directory = create_data_cache_dir()
 download_and_uncompress([files], CIFAR_URL, target_directory, is_tar=True)
 ##if target dir was specify move data from directory created by tar
 ##and put data into target dir
-if target_directory != None:
+if target_directory is not None:
 tar_dir_full_path = os.path.join(target_directory, directory_from_tar)
 all_files = os.path.join(tar_dir_full_path, "*")
 cmd = "mv " + all_files + " " + target_directory


@@ -12,10 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-import mindspore._c_dataengine as cde
 import numpy as np
+import mindspore._c_dataengine as cde
 def test_shape():
 x = [2, 3]


@@ -221,7 +221,7 @@ def test_apply_exception_case():
 try:
 data2 = data1.apply(dataset_fn)
 data3 = data1.apply(dataset_fn)
-for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
+for _, _ in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
 pass
 assert False
 except ValueError:


@@ -35,10 +35,10 @@ def test_case_dataset_cifar10():
 data1 = ds.Cifar10Dataset(DATA_DIR_10, 100)
 num_iter = 0
-for item in data1.create_dict_iterator():
+for _ in data1.create_dict_iterator():
 # in this example, each dictionary has keys "image" and "label"
 num_iter += 1
-assert (num_iter == 100)
+assert num_iter == 100
 def test_case_dataset_cifar100():
@@ -50,10 +50,10 @@ def test_case_dataset_cifar100():
 data1 = ds.Cifar100Dataset(DATA_DIR_100, 100)
 num_iter = 0
-for item in data1.create_dict_iterator():
+for _ in data1.create_dict_iterator():
 # in this example, each dictionary has keys "image" and "label"
 num_iter += 1
-assert (num_iter == 100)
+assert num_iter == 100
 if __name__ == '__main__':


@@ -15,10 +15,10 @@
 """
 Testing configuration manager
 """
+import os
 import filecmp
 import glob
 import numpy as np
-import os
 import mindspore.dataset as ds
 import mindspore.dataset.transforms.vision.c_transforms as vision
@@ -89,7 +89,7 @@ def test_pipeline():
 ds.serialize(data2, "testpipeline2.json")
 # check that the generated output is different
-assert (filecmp.cmp('testpipeline.json', 'testpipeline2.json'))
+assert filecmp.cmp('testpipeline.json', 'testpipeline2.json')
 # this test passes currently because our num_parallel_workers don't get updated.


@@ -33,9 +33,9 @@ def test_celeba_dataset_label():
 logger.info("----------attr--------")
 logger.info(item["attr"])
 for index in range(len(expect_labels[count])):
-assert (item["attr"][index] == expect_labels[count][index])
+assert item["attr"][index] == expect_labels[count][index]
 count = count + 1
-assert (count == 2)
+assert count == 2
 def test_celeba_dataset_op():
@@ -54,7 +54,7 @@ def test_celeba_dataset_op():
 logger.info("----------image--------")
 logger.info(item["image"])
 count = count + 1
-assert (count == 4)
+assert count == 4
 def test_celeba_dataset_ext():
@@ -69,9 +69,9 @@ def test_celeba_dataset_ext():
 logger.info("----------attr--------")
 logger.info(item["attr"])
 for index in range(len(expect_labels[count])):
-assert (item["attr"][index] == expect_labels[count][index])
+assert item["attr"][index] == expect_labels[count][index]
 count = count + 1
-assert (count == 1)
+assert count == 1
 def test_celeba_dataset_distribute():
@@ -83,7 +83,7 @@ def test_celeba_dataset_distribute():
 logger.info("----------attr--------")
 logger.info(item["attr"])
 count = count + 1
-assert (count == 1)
+assert count == 1
 if __name__ == '__main__':


@@ -35,7 +35,7 @@ def test_imagefolder_basic():
 num_iter += 1
 logger.info("Number of data in data1: {}".format(num_iter))
-assert (num_iter == 44)
+assert num_iter == 44
 def test_imagefolder_numsamples():
@@ -55,7 +55,7 @@ def test_imagefolder_numsamples():
 num_iter += 1
 logger.info("Number of data in data1: {}".format(num_iter))
-assert (num_iter == 10)
+assert num_iter == 10
 def test_imagefolder_numshards():
@@ -75,7 +75,7 @@ def test_imagefolder_numshards():
 num_iter += 1
 logger.info("Number of data in data1: {}".format(num_iter))
-assert (num_iter == 11)
+assert num_iter == 11
 def test_imagefolder_shardid():
@@ -95,7 +95,7 @@ def test_imagefolder_shardid():
 num_iter += 1
 logger.info("Number of data in data1: {}".format(num_iter))
-assert (num_iter == 11)
+assert num_iter == 11
 def test_imagefolder_noshuffle():
@@ -115,7 +115,7 @@ def test_imagefolder_noshuffle():
 num_iter += 1
 logger.info("Number of data in data1: {}".format(num_iter))
-assert (num_iter == 44)
+assert num_iter == 44
 def test_imagefolder_extrashuffle():
@@ -136,7 +136,7 @@ def test_imagefolder_extrashuffle():
 num_iter += 1
 logger.info("Number of data in data1: {}".format(num_iter))
-assert (num_iter == 88)
+assert num_iter == 88
 def test_imagefolder_classindex():
@@ -157,11 +157,11 @@ def test_imagefolder_classindex():
 # in this example, each dictionary has keys "image" and "label"
 logger.info("image is {}".format(item["image"]))
 logger.info("label is {}".format(item["label"]))
-assert (item["label"] == golden[num_iter])
+assert item["label"] == golden[num_iter]
 num_iter += 1
 logger.info("Number of data in data1: {}".format(num_iter))
-assert (num_iter == 22)
+assert num_iter == 22
 def test_imagefolder_negative_classindex():
@@ -182,11 +182,11 @@ def test_imagefolder_negative_classindex():
 # in this example, each dictionary has keys "image" and "label"
 logger.info("image is {}".format(item["image"]))
 logger.info("label is {}".format(item["label"]))
-assert (item["label"] == golden[num_iter])
+assert item["label"] == golden[num_iter]
 num_iter += 1
 logger.info("Number of data in data1: {}".format(num_iter))
-assert (num_iter == 22)
+assert num_iter == 22
 def test_imagefolder_extensions():
@@ -207,7 +207,7 @@ def test_imagefolder_extensions():
 num_iter += 1
 logger.info("Number of data in data1: {}".format(num_iter))
-assert (num_iter == 44)
+assert num_iter == 44
 def test_imagefolder_decode():
@@ -228,7 +228,7 @@ def test_imagefolder_decode():
 num_iter += 1
 logger.info("Number of data in data1: {}".format(num_iter))
-assert (num_iter == 44)
+assert num_iter == 44
 def test_sequential_sampler():
@@ -255,7 +255,7 @@ def test_sequential_sampler():
 num_iter += 1
 logger.info("Result: {}".format(result))
-assert (result == golden)
+assert result == golden
 def test_random_sampler():
@@ -276,7 +276,7 @@ def test_random_sampler():
 num_iter += 1
 logger.info("Number of data in data1: {}".format(num_iter))
-assert (num_iter == 44)
+assert num_iter == 44
 def test_distributed_sampler():
@@ -297,7 +297,7 @@ def test_distributed_sampler():
 num_iter += 1
 logger.info("Number of data in data1: {}".format(num_iter))
-assert (num_iter == 5)
+assert num_iter == 5
 def test_pk_sampler():
@@ -318,7 +318,7 @@ def test_pk_sampler():
 num_iter += 1
 logger.info("Number of data in data1: {}".format(num_iter))
-assert (num_iter == 12)
+assert num_iter == 12
 def test_subset_random_sampler():
@@ -340,7 +340,7 @@ def test_subset_random_sampler():
 num_iter += 1
 logger.info("Number of data in data1: {}".format(num_iter))
-assert (num_iter == 12)
+assert num_iter == 12
 def test_weighted_random_sampler():
@@ -362,7 +362,7 @@ def test_weighted_random_sampler():
 num_iter += 1
 logger.info("Number of data in data1: {}".format(num_iter))
-assert (num_iter == 11)
+assert num_iter == 11
 def test_imagefolder_rename():
@@ -382,7 +382,7 @@ def test_imagefolder_rename():
 num_iter += 1
 logger.info("Number of data in data1: {}".format(num_iter))
-assert (num_iter == 10)
+assert num_iter == 10
 data1 = data1.rename(input_columns=["image"], output_columns="image2")
@@ -394,7 +394,7 @@ def test_imagefolder_rename():
 num_iter += 1
 logger.info("Number of data in data1: {}".format(num_iter))
-assert (num_iter == 10)
+assert num_iter == 10
 def test_imagefolder_zip():
@@ -419,7 +419,7 @@ def test_imagefolder_zip():
 num_iter += 1
 logger.info("Number of data in data1: {}".format(num_iter))
-assert (num_iter == 10)
+assert num_iter == 10
 if __name__ == '__main__':


@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-import pytest
 import mindspore.dataset as ds
 import mindspore.dataset.transforms.c_transforms as data_trans
 import mindspore.dataset.transforms.vision.c_transforms as vision


@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-import pytest
 import mindspore.dataset as ds
 from mindspore import log as logger
@@ -30,7 +28,7 @@ def test_tf_file_normal():
 data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
 data1 = data1.repeat(1)
 num_iter = 0
-for item in data1.create_dict_iterator(): # each data is a dictionary
+for _ in data1.create_dict_iterator(): # each data is a dictionary
 num_iter += 1
 logger.info("Number of data in data1: {}".format(num_iter))


@@ -16,7 +16,6 @@ import numpy as np
 import mindspore.dataset as ds
 import mindspore.dataset.transforms.c_transforms as data_trans
-import mindspore.dataset.transforms.vision.c_transforms as vision
 from mindspore import log as logger
 DATA_FILE = "../data/dataset/testManifestData/test.manifest"
@@ -34,9 +33,9 @@ def test_manifest_dataset_train():
 cat_count = cat_count + 1
 elif item["label"].size == 1 and item["label"] == 1:
 dog_count = dog_count + 1
-assert (cat_count == 2)
-assert (dog_count == 1)
-assert (count == 4)
+assert cat_count == 2
+assert dog_count == 1
+assert count == 4
 def test_manifest_dataset_eval():
@@ -46,36 +45,36 @@ def test_manifest_dataset_eval():
 logger.info("item[image] is {}".format(item["image"]))
 count = count + 1
 if item["label"] != 0 and item["label"] != 1:
-assert (0)
-assert (count == 2)
+assert 0
+assert count == 2
 def test_manifest_dataset_class_index():
 class_indexing = {"dog": 11}
 data = ds.ManifestDataset(DATA_FILE, decode=True, class_indexing=class_indexing)
 out_class_indexing = data.get_class_indexing()
-assert (out_class_indexing == {"dog": 11})
+assert out_class_indexing == {"dog": 11}
 count = 0
 for item in data.create_dict_iterator():
 logger.info("item[image] is {}".format(item["image"]))
 count = count + 1
 if item["label"] != 11:
-assert (0)
-assert (count == 1)
+assert 0
+assert count == 1
 def test_manifest_dataset_get_class_index():
 data = ds.ManifestDataset(DATA_FILE, decode=True)
 class_indexing = data.get_class_indexing()
-assert (class_indexing == {'cat': 0, 'dog': 1, 'flower': 2})
+assert class_indexing == {'cat': 0, 'dog': 1, 'flower': 2}
 data = data.shuffle(4)
 class_indexing = data.get_class_indexing()
-assert (class_indexing == {'cat': 0, 'dog': 1, 'flower': 2})
+assert class_indexing == {'cat': 0, 'dog': 1, 'flower': 2}
 count = 0
 for item in data.create_dict_iterator():
 logger.info("item[image] is {}".format(item["image"]))
 count = count + 1
-assert (count == 4)
+assert count == 4
 def test_manifest_dataset_multi_label():
@@ -83,10 +82,10 @@ def test_manifest_dataset_multi_label():
 count = 0
 expect_label = [1, 0, 0, [0, 2]]
 for item in data.create_dict_iterator():
-assert (item["label"].tolist() == expect_label[count])
+assert item["label"].tolist() == expect_label[count]
 logger.info("item[image] is {}".format(item["image"]))
 count = count + 1
-assert (count == 4)
+assert count == 4
 def multi_label_hot(x):
@@ -109,7 +108,7 @@ def test_manifest_dataset_multi_label_onehot():
 data = data.batch(2)
 count = 0
 for item in data.create_dict_iterator():
-assert (item["label"].tolist() == expect_label[count])
+assert item["label"].tolist() == expect_label[count]
 logger.info("item[image] is {}".format(item["image"]))
 count = count + 1


@@ -27,7 +27,7 @@ def test_imagefolder_shardings(print_res=False):
 res = []
 for item in data1.create_dict_iterator(): # each data is a dictionary
 res.append(item["label"].item())
-if (print_res):
+if print_res:
 logger.info("labels of dataset: {}".format(res))
 return res
@@ -39,12 +39,12 @@ def test_imagefolder_shardings(print_res=False):
 assert (sharding_config(2, 0, 55, False, dict()) == [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3]) # 22 rows
 assert (sharding_config(2, 1, 55, False, dict()) == [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3]) # 22 rows
 # total 22 in dataset rows because of class indexing which takes only 2 folders
-assert (len(sharding_config(4, 0, None, True, {"class1": 111, "class2": 999})) == 6)
-assert (len(sharding_config(4, 2, 3, True, {"class1": 111, "class2": 999})) == 3)
+assert len(sharding_config(4, 0, None, True, {"class1": 111, "class2": 999})) == 6
+assert len(sharding_config(4, 2, 3, True, {"class1": 111, "class2": 999})) == 3
 # test with repeat
 assert (sharding_config(4, 0, 12, False, dict(), 3) == [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3] * 3)
 assert (sharding_config(4, 0, 5, False, dict(), 5) == [0, 0, 0, 1, 1] * 5)
-assert (len(sharding_config(5, 1, None, True, {"class1": 111, "class2": 999}, 4)) == 20)
+assert len(sharding_config(5, 1, None, True, {"class1": 111, "class2": 999}, 4)) == 20
 def test_tfrecord_shardings1(print_res=False):
@@ -176,8 +176,8 @@ def test_voc_shardings(print_res=False):
 # then takes the first 2 bc num_samples = 2
 assert (sharding_config(3, 2, 2, False, 4) == [2268, 607] * 4)
 # test that each epoch, each shard_worker returns a different sample
-assert (len(sharding_config(2, 0, None, True, 1)) == 5)
-assert (len(set(sharding_config(11, 0, None, True, 10))) > 1)
+assert len(sharding_config(2, 0, None, True, 1)) == 5
+assert len(set(sharding_config(11, 0, None, True, 10))) > 1
 def test_cifar10_shardings(print_res=False):
@@ -196,8 +196,8 @@ def test_cifar10_shardings(print_res=False):
 # 60000 rows in total. CIFAR reads everything in memory which would make each test case very slow
 # therefore, only 2 test cases for now.
-assert (sharding_config(10000, 9999, 7, False, 1) == [9])
-assert (sharding_config(10000, 0, 4, False, 3) == [0, 0, 0])
+assert sharding_config(10000, 9999, 7, False, 1) == [9]
+assert sharding_config(10000, 0, 4, False, 3) == [0, 0, 0]
 def test_cifar100_shardings(print_res=False):


@@ -27,7 +27,7 @@ def test_textline_dataset_one_file():
 for i in data.create_dict_iterator():
 logger.info("{}".format(i["text"]))
 count += 1
-assert (count == 3)
+assert count == 3
 def test_textline_dataset_all_file():
@@ -36,7 +36,7 @@ def test_textline_dataset_all_file():
 for i in data.create_dict_iterator():
 logger.info("{}".format(i["text"]))
 count += 1
-assert (count == 5)
+assert count == 5
 def test_textline_dataset_totext():
@@ -46,8 +46,8 @@ def test_textline_dataset_totext():
 line = ["This is a text file.", "Another file.",
 "Be happy every day.", "End of file.", "Good luck to everyone."]
 for i in data.create_dict_iterator():
-str = i["text"].item().decode("utf8")
-assert (str == line[count])
+strs = i["text"].item().decode("utf8")
+assert strs == line[count]
 count += 1
 assert (count == 5)
 # Restore configuration num_parallel_workers
@@ -57,17 +57,17 @@ def test_textline_dataset_totext():
 def test_textline_dataset_num_samples():
 data = ds.TextFileDataset(DATA_FILE, num_samples=2)
 count = 0
-for i in data.create_dict_iterator():
+for _ in data.create_dict_iterator():
 count += 1
-assert (count == 2)
+assert count == 2
 def test_textline_dataset_distribution():
 data = ds.TextFileDataset(DATA_ALL_FILE, num_shards=2, shard_id=1)
 count = 0
-for i in data.create_dict_iterator():
+for _ in data.create_dict_iterator():
 count += 1
-assert (count == 3)
+assert count == 3
 def test_textline_dataset_repeat():
@@ -78,16 +78,16 @@ def test_textline_dataset_repeat():
 "This is a text file.", "Be happy every day.", "Good luck to everyone.",
 "This is a text file.", "Be happy every day.", "Good luck to everyone."]
 for i in data.create_dict_iterator():
-str = i["text"].item().decode("utf8")
-assert (str == line[count])
+strs = i["text"].item().decode("utf8")
+assert strs == line[count]
 count += 1
-assert (count == 9)
+assert count == 9
 def test_textline_dataset_get_datasetsize():
 data = ds.TextFileDataset(DATA_FILE)
 size = data.get_dataset_size()
-assert (size == 3)
+assert size == 3
 if __name__ == "__main__":


@@ -15,9 +15,8 @@
 """
 Testing Decode op in DE
 """
-import cv2
 import numpy as np
-from util import diff_mse
+import cv2
 import mindspore.dataset as ds
 import mindspore.dataset.transforms.vision.c_transforms as vision


@@ -88,7 +88,7 @@ def test_filter_by_generator_with_repeat():
 ret_data.append(item["data"])
 assert num_iter == 44
 for i in range(4):
-for ii in range(len(expected_rs)):
+for ii, _ in enumerate(expected_rs):
 index = i * len(expected_rs) + ii
 assert ret_data[index] == expected_rs[ii]
@@ -106,7 +106,7 @@ def test_filter_by_generator_with_repeat_after():
 ret_data.append(item["data"])
 assert num_iter == 44
 for i in range(4):
-for ii in range(len(expected_rs)):
+for ii, _ in enumerate(expected_rs):
 index = i * len(expected_rs) + ii
 assert ret_data[index] == expected_rs[ii]
@@ -167,7 +167,7 @@ def test_filter_by_generator_with_shuffle():
 dataset_s = dataset.shuffle(4)
 dataset_f = dataset_s.filter(predicate=filter_func_shuffle, num_parallel_workers=4)
 num_iter = 0
-for item in dataset_f.create_dict_iterator():
+for _ in dataset_f.create_dict_iterator():
 num_iter += 1
 assert num_iter == 21
@@ -184,7 +184,7 @@ def test_filter_by_generator_with_shuffle_after():
 dataset_f = dataset.filter(predicate=filter_func_shuffle_after, num_parallel_workers=4)
 dataset_s = dataset_f.shuffle(4)
 num_iter = 0
-for item in dataset_s.create_dict_iterator():
+for _ in dataset_s.create_dict_iterator():
 num_iter += 1
 assert num_iter == 21
@@ -258,8 +258,7 @@ def filter_func_map(col1, col2):
 def filter_func_map_part(col1):
 if col1 < 3:
 return True
-else:
-return False
+return False
 def filter_func_map_all(col1, col2):
@@ -276,7 +275,7 @@ def func_map(data_col1, data_col2):
 def func_map_part(data_col1):
-return (data_col1)
+return data_col1
 # test with map
@@ -473,7 +472,6 @@ def test_filte_case_dataset_cifar10():
 ds.config.load('../data/dataset/declient_filter.cfg')
 dataset_c = ds.Cifar10Dataset(dataset_dir=DATA_DIR_10, num_samples=100000, shuffle=False)
 dataset_f1 = dataset_c.filter(input_columns=["image", "label"], predicate=filter_func_cifar, num_parallel_workers=1)
-num_iter = 0
 for item in dataset_f1.create_dict_iterator():
 # in this example, each dictionary has keys "image" and "label"
 assert item["label"] % 3 == 0


@@ -184,7 +184,7 @@ def test_case_6():
 de_types = [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.uint8, mstype.uint16, mstype.uint32,
 mstype.uint64, mstype.float32, mstype.float64]
-for i in range(len(np_types)):
+for i, _ in enumerate(np_types):
 type_tester_with_type_check(np_types[i], de_types[i])
@@ -219,7 +219,7 @@ def test_case_7():
 de_types = [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.uint8, mstype.uint16, mstype.uint32,
 mstype.uint64, mstype.float32, mstype.float64]
-for i in range(len(np_types)):
+for i, _ in enumerate(np_types):
 type_tester_with_type_check_2c(np_types[i], [None, de_types[i]])
@@ -526,7 +526,7 @@ def test_sequential_sampler():
 def test_random_sampler():
 source = [(np.array([x]),) for x in range(64)]
 ds1 = ds.GeneratorDataset(source, ["data"], shuffle=True)
-for data in ds1.create_dict_iterator(): # each data is a dictionary
+for _ in ds1.create_dict_iterator(): # each data is a dictionary
 pass
@@ -611,7 +611,7 @@ def test_schema():
 de_types = [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.uint8, mstype.uint16, mstype.uint32,
 mstype.uint64, mstype.float32, mstype.float64]
-for i in range(len(np_types)):
+for i, _ in enumerate(np_types):
 type_tester_with_type_check_2c_schema(np_types[i], [de_types[i], de_types[i]])
@@ -630,8 +630,7 @@ def manual_test_keyborad_interrupt():
 return 1024
 ds1 = ds.GeneratorDataset(MyDS(), ["data"], num_parallel_workers=4).repeat(2)
-i = 0
-for data in ds1.create_dict_iterator(): # each data is a dictionary
+for _ in ds1.create_dict_iterator(): # each data is a dictionary
 pass


@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-import copy
 import numpy as np
 import pytest


@@ -320,7 +320,7 @@ def test_cv_minddataset_issue_888(add_and_remove_cv_file):
 data = data.shuffle(2)
 data = data.repeat(9)
 num_iter = 0
-for item in data.create_dict_iterator():
+for _ in data.create_dict_iterator():
 num_iter += 1
 assert num_iter == 18
@@ -572,7 +572,7 @@ def test_cv_minddataset_reader_basic_tutorial_5_epoch(add_and_remove_cv_file):
 num_readers = 4
 data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
 assert data_set.get_dataset_size() == 10
-for epoch in range(5):
+for _ in range(5):
 num_iter = 0
 for data in data_set:
 logger.info("data is {}".format(data))
@@ -603,7 +603,7 @@ def test_cv_minddataset_reader_basic_tutorial_5_epoch_with_batch(add_and_remove_
 data_set = data_set.batch(2)
 assert data_set.get_dataset_size() == 5
-for epoch in range(5):
+for _ in range(5):
 num_iter = 0
 for data in data_set:
 logger.info("data is {}".format(data))


@@ -91,7 +91,7 @@ def test_invalid_mindrecord():
 with pytest.raises(Exception, match="MindRecordOp init failed"):
 data_set = ds.MindDataset('dummy.mindrecord', columns_list, num_readers)
 num_iter = 0
-for item in data_set.create_dict_iterator():
+for _ in data_set.create_dict_iterator():
 num_iter += 1
 assert num_iter == 0
 os.remove('dummy.mindrecord')
@@ -105,7 +105,7 @@ def test_minddataset_lack_db():
 with pytest.raises(Exception, match="MindRecordOp init failed"):
 data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers)
 num_iter = 0
-for item in data_set.create_dict_iterator():
+for _ in data_set.create_dict_iterator():
 num_iter += 1
 assert num_iter == 0
 os.remove(CV_FILE_NAME)
@@ -119,7 +119,7 @@ def test_cv_minddataset_pk_sample_error_class_column():
 with pytest.raises(Exception, match="MindRecordOp launch failed"):
 data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, sampler=sampler)
 num_iter = 0
-for item in data_set.create_dict_iterator():
+for _ in data_set.create_dict_iterator():
 num_iter += 1
 os.remove(CV_FILE_NAME)
 os.remove("{}.db".format(CV_FILE_NAME))


@@ -15,8 +15,8 @@
 """
 This is the test module for mindrecord
 """
-import numpy as np
 import os
+import numpy as np
 import mindspore.dataset as ds
 from mindspore import log as logger


@@ -15,16 +15,10 @@
 """
 This is the test module for mindrecord
 """
-import collections
-import json
-import numpy as np
 import os
 import pytest
-import re
-import string
 import mindspore.dataset as ds
-import mindspore.dataset.transforms.vision.c_transforms as vision
 from mindspore import log as logger
 from mindspore.dataset.transforms.vision import Inter
 from mindspore.dataset.text import to_str


@@ -49,7 +49,7 @@ def test_one_hot_op():
 label = data["label"]
 logger.info("label is {}".format(label))
 logger.info("golden_label is {}".format(golden_label))
-assert (label.all() == golden_label.all())
+assert label.all() == golden_label.all()
 logger.info("====test one hot op ok====")


@@ -13,7 +13,6 @@
 # limitations under the License.
 # ==============================================================================
-import matplotlib.pyplot as plt
 import numpy as np
 import mindspore.dataset as ds
@@ -50,6 +49,7 @@ def get_normalized(image_id):
 if num_iter == image_id:
 return normalize_np(image)
 num_iter += 1
+return None
 def test_normalize_op():


@@ -19,7 +19,6 @@ import numpy as np
 import mindspore.dataset as ds
 import mindspore.dataset.transforms.c_transforms as data_trans
-import mindspore.dataset.transforms.vision.c_transforms as vision
 from mindspore import log as logger
 DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]


@@ -15,7 +15,6 @@
 """
 Testing Pad op in DE
 """
-import matplotlib.pyplot as plt
 import numpy as np
 from util import diff_mse
@@ -118,7 +117,7 @@ def test_pad_grayscale():
 for shape1, shape2 in zip(dataset_shape_1, dataset_shape_2):
 # validate that the first two dimensions are the same
 # we have a little inconsistency here because the third dimension is 1 after py_vision.Grayscale
-assert (shape1[0:1] == shape2[0:1])
+assert shape1[0:1] == shape2[0:1]
 if __name__ == "__main__":


@@ -13,8 +13,8 @@
 # limitations under the License.
 # ==============================================================================
-import numpy as np
 import time
+import numpy as np
 import mindspore.dataset as ds
@@ -117,8 +117,7 @@ def batch_padding_performance_3d():
 data1 = data1.batch(batch_size=24, drop_remainder=True, pad_info=pad_info)
 start_time = time.time()
 num_batches = 0
-ret = []
-for data in data1.create_dict_iterator():
+for _ in data1.create_dict_iterator():
 num_batches += 1
 res = "total number of batch:" + str(num_batches) + " time elapsed:" + str(time.time() - start_time)
 # print(res)
@@ -134,7 +133,7 @@ def batch_padding_performance_1d():
 data1 = data1.batch(batch_size=24, drop_remainder=True, pad_info=pad_info)
 start_time = time.time()
 num_batches = 0
-for data in data1.create_dict_iterator():
+for _ in data1.create_dict_iterator():
 num_batches += 1
 res = "total number of batch:" + str(num_batches) + " time elapsed:" + str(time.time() - start_time)
 # print(res)
@@ -150,7 +149,7 @@ def batch_pyfunc_padding_3d():
 data1 = data1.batch(batch_size=24, drop_remainder=True)
 start_time = time.time()
 num_batches = 0
-for data in data1.create_dict_iterator():
+for _ in data1.create_dict_iterator():
 num_batches += 1
 res = "total number of batch:" + str(num_batches) + " time elapsed:" + str(time.time() - start_time)
 # print(res)
@@ -165,7 +164,7 @@ def batch_pyfunc_padding_1d():
 data1 = data1.batch(batch_size=24, drop_remainder=True)
 start_time = time.time()
 num_batches = 0
-for data in data1.create_dict_iterator():
+for _ in data1.create_dict_iterator():
 num_batches += 1
 res = "total number of batch:" + str(num_batches) + " time elapsed:" + str(time.time() - start_time)
 # print(res)
@@ -197,7 +196,7 @@ def test_pad_via_map():
 res_from_map = pad_map_config()
 res_from_batch = pad_batch_config()
 assert len(res_from_batch) == len(res_from_batch)
-for i in range(len(res_from_map)):
+for i, _ in enumerate(res_from_map):
 assert np.array_equal(res_from_map[i], res_from_batch[i])


@@ -15,8 +15,9 @@
 """
 Testing RandomCropAndResize op in DE
 """
-import cv2
 import numpy as np
+import cv2
 import mindspore.dataset.transforms.vision.c_transforms as c_vision
 import mindspore.dataset.transforms.vision.py_transforms as py_vision
 import mindspore.dataset.transforms.vision.utils as mode


@@ -15,9 +15,9 @@
 """
 Testing RandomCropDecodeResize op in DE
 """
-import cv2
 import matplotlib.pyplot as plt
 import numpy as np
+import cv2
 import mindspore.dataset as ds
 import mindspore.dataset.transforms.vision.c_transforms as vision


@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from pathlib import Path
 import mindspore.common.dtype as mstype
 import mindspore.dataset as ds
 from mindspore import log as logger
@@ -39,7 +37,7 @@ def test_randomdataset_basic1():
 num_iter += 1
 logger.info("Number of data in ds1: ", num_iter)
-assert (num_iter == 200)
+assert num_iter == 200
 # Another simple test
@@ -65,7 +63,7 @@ def test_randomdataset_basic2():
 num_iter += 1
 logger.info("Number of data in ds1: ", num_iter)
-assert (num_iter == 40)
+assert num_iter == 40
 if __name__ == '__main__':


@@ -15,9 +15,9 @@
 """
 Testing RandomRotation op in DE
 """
-import cv2
 import matplotlib.pyplot as plt
 import numpy as np
+import cv2
 import mindspore.dataset as ds
 import mindspore.dataset.transforms.vision.c_transforms as c_vision


@@ -34,7 +34,7 @@ def test_rename():
 num_iter = 0
-for i, item in enumerate(data.create_dict_iterator()):
+for _, item in enumerate(data.create_dict_iterator()):
 logger.info("item[mask] is {}".format(item["masks"]))
 np.testing.assert_equal(item["masks"], item["input_ids"])
 logger.info("item[seg_ids] is {}".format(item["seg_ids"]))


@@ -159,7 +159,7 @@ def test_rgb_hsv_pipeline():
 ori_img = data1["image"]
 cvt_img = data2["image"]
 assert_allclose(ori_img.flatten(), cvt_img.flatten(), rtol=1e-5, atol=0)
-assert (ori_img.shape == cvt_img.shape)
+assert ori_img.shape == cvt_img.shape
 if __name__ == "__main__":


@@ -57,7 +57,7 @@ def test_imagefolder(remove_json_files=True):
 # data1 should still work after saving.
 ds.serialize(data1, "imagenet_dataset_pipeline.json")
 ds1_dict = ds.serialize(data1)
-assert (validate_jsonfile("imagenet_dataset_pipeline.json") is True)
+assert validate_jsonfile("imagenet_dataset_pipeline.json") is True
 # Print the serialized pipeline to stdout
 ds.show(data1)
@@ -68,8 +68,8 @@ def test_imagefolder(remove_json_files=True):
 # Serialize the pipeline we just deserialized.
 # The content of the json file should be the same to the previous serialize.
 ds.serialize(data2, "imagenet_dataset_pipeline_1.json")
-assert (validate_jsonfile("imagenet_dataset_pipeline_1.json") is True)
-assert (filecmp.cmp('imagenet_dataset_pipeline.json', 'imagenet_dataset_pipeline_1.json'))
+assert validate_jsonfile("imagenet_dataset_pipeline_1.json") is True
+assert filecmp.cmp('imagenet_dataset_pipeline.json', 'imagenet_dataset_pipeline_1.json')
 # Deserialize the latest json file again
 data3 = ds.deserialize(json_filepath="imagenet_dataset_pipeline_1.json")
@@ -78,16 +78,16 @@ def test_imagefolder(remove_json_files=True):
 # Iterate and compare the data in the original pipeline (data1) against the deserialized pipeline (data2)
 for item1, item2, item3, item4 in zip(data1.create_dict_iterator(), data2.create_dict_iterator(),
 data3.create_dict_iterator(), data4.create_dict_iterator()):
-assert (np.array_equal(item1['image'], item2['image']))
-assert (np.array_equal(item1['image'], item3['image']))
-assert (np.array_equal(item1['label'], item2['label']))
-assert (np.array_equal(item1['label'], item3['label']))
-assert (np.array_equal(item3['image'], item4['image']))
-assert (np.array_equal(item3['label'], item4['label']))
+assert np.array_equal(item1['image'], item2['image'])
+assert np.array_equal(item1['image'], item3['image'])
+assert np.array_equal(item1['label'], item2['label'])
+assert np.array_equal(item1['label'], item3['label'])
+assert np.array_equal(item3['image'], item4['image'])
+assert np.array_equal(item3['label'], item4['label'])
 num_samples += 1
 logger.info("Number of data in data1: {}".format(num_samples))
-assert (num_samples == 6)
+assert num_samples == 6
 # Remove the generated json file
 if remove_json_files:
@@ -106,26 +106,26 @@ def test_mnist_dataset(remove_json_files=True):
 data1 = data1.batch(batch_size=10, drop_remainder=True)
 ds.serialize(data1, "mnist_dataset_pipeline.json")
-assert (validate_jsonfile("mnist_dataset_pipeline.json") is True)
+assert validate_jsonfile("mnist_dataset_pipeline.json") is True
 data2 = ds.deserialize(json_filepath="mnist_dataset_pipeline.json")
 ds.serialize(data2, "mnist_dataset_pipeline_1.json")
-assert (validate_jsonfile("mnist_dataset_pipeline_1.json") is True)
-assert (filecmp.cmp('mnist_dataset_pipeline.json', 'mnist_dataset_pipeline_1.json'))
+assert validate_jsonfile("mnist_dataset_pipeline_1.json") is True
+assert filecmp.cmp('mnist_dataset_pipeline.json', 'mnist_dataset_pipeline_1.json')
 data3 = ds.deserialize(json_filepath="mnist_dataset_pipeline_1.json")
 num = 0
 for data1, data2, data3 in zip(data1.create_dict_iterator(), data2.create_dict_iterator(),
 data3.create_dict_iterator()):
-assert (np.array_equal(data1['image'], data2['image']))
-assert (np.array_equal(data1['image'], data3['image']))
-assert (np.array_equal(data1['label'], data2['label']))
-assert (np.array_equal(data1['label'], data3['label']))
+assert np.array_equal(data1['image'], data2['image'])
+assert np.array_equal(data1['image'], data3['image'])
+assert np.array_equal(data1['label'], data2['label'])
+assert np.array_equal(data1['label'], data3['label'])
 num += 1
 logger.info("mnist total num samples is {}".format(str(num)))
-assert (num == 10)
+assert num == 10
 if remove_json_files:
 delete_json_files()
@@ -146,13 +146,13 @@ def test_zip_dataset(remove_json_files=True):
 "column_1d", "column_2d", "column_3d", "column_binary"])
 data3 = ds.zip((data1, data2))
 ds.serialize(data3, "zip_dataset_pipeline.json")
-assert (validate_jsonfile("zip_dataset_pipeline.json") is True)
-assert (validate_jsonfile("zip_dataset_pipeline_typo.json") is False)
+assert validate_jsonfile("zip_dataset_pipeline.json") is True
+assert validate_jsonfile("zip_dataset_pipeline_typo.json") is False
 data4 = ds.deserialize(json_filepath="zip_dataset_pipeline.json")
 ds.serialize(data4, "zip_dataset_pipeline_1.json")
-assert (validate_jsonfile("zip_dataset_pipeline_1.json") is True)
-assert (filecmp.cmp('zip_dataset_pipeline.json', 'zip_dataset_pipeline_1.json'))
+assert validate_jsonfile("zip_dataset_pipeline_1.json") is True
+assert filecmp.cmp('zip_dataset_pipeline.json', 'zip_dataset_pipeline_1.json')
 rows = 0
 for d0, d3, d4 in zip(ds0, data3, data4):
@@ -165,7 +165,7 @@ def test_zip_dataset(remove_json_files=True):
 assert np.array_equal(t1, d4[offset + num_cols])
 offset += 1
 rows += 1
-assert (rows == 12)
+assert rows == 12
 if remove_json_files:
 delete_json_files()
@@ -197,7 +197,7 @@ def test_random_crop():
 for item1, item1_1, item2 in zip(data1.create_dict_iterator(), data1_1.create_dict_iterator(),
 data2.create_dict_iterator()):
-assert (np.array_equal(item1['image'], item1_1['image']))
+assert np.array_equal(item1['image'], item1_1['image'])
 image2 = item2["image"]
@@ -250,7 +250,7 @@ def test_minddataset(add_and_remove_cv_file):
 data = get_data(CV_DIR_NAME)
 assert data_set.get_dataset_size() == 5
num_iter = 0 num_iter = 0
for item in data_set.create_dict_iterator(): for _ in data_set.create_dict_iterator():
num_iter += 1 num_iter += 1
assert num_iter == 5 assert num_iter == 5
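The edits above are representative of the whole commit: loops that only count rows bind each row to a variable they never read, and several assert statements carry redundant parentheses. A minimal plain-Python sketch of the before/after pattern (the pylint message names, unused-variable and superfluous-parens, are assumed here; the commit message only says "clean pylint"):

# Unused loop variable: only the count matters, so the row is bound to "_".
num_iter = 0
for _ in range(5):
    num_iter += 1
# Redundant parentheses around the asserted expression are dropped:
# "assert (num_iter == 5)" becomes the plain statement below.
assert num_iter == 5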

View File

@ -120,7 +120,7 @@ def test_shuffle_05():
def test_shuffle_06(): def test_shuffle_06():
""" """
Test shuffle: with set seed, both datasets Test shuffle: with set seed, both datasets
""" """
logger.info("test_shuffle_06") logger.info("test_shuffle_06")
# define parameters # define parameters

View File

@ -16,7 +16,6 @@ import numpy as np
import mindspore.dataset as ds import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as vision import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore import log as logger
DATA_DIR_TF2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] DATA_DIR_TF2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR_TF2 = "../data/dataset/test_tf_file_3_images/datasetSchema.json" SCHEMA_DIR_TF2 = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
@ -36,7 +35,7 @@ def test_tf_skip():
data1 = data1.skip(2) data1 = data1.skip(2)
num_iter = 0 num_iter = 0
for item in data1.create_dict_iterator(): for _ in data1.create_dict_iterator():
num_iter += 1 num_iter += 1
assert num_iter == 1 assert num_iter == 1
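Two smaller recurring fixes show up in the surrounding files: unused imports (the logger import above, and an import time in the next file) are removed outright, and one-element tuples yielded by generators gain explicit parentheses. The parentheses are purely cosmetic, since the trailing comma already makes a tuple; a quick plain-Python check:

import numpy as np

def gen_implicit():
    for i in range(3):
        yield np.array([i]),        # implicit one-element tuple

def gen_explicit():
    for i in range(3):
        yield (np.array([i]),)      # same value, parentheses made explicit

for a, b in zip(gen_implicit(), gen_explicit()):
    assert isinstance(a, tuple) and isinstance(b, tuple)
    assert np.array_equal(a[0], b[0])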

View File

@ -14,7 +14,6 @@
# ============================================================================== # ==============================================================================
import numpy as np import numpy as np
import time
import mindspore.dataset as ds import mindspore.dataset as ds
from mindspore import log as logger from mindspore import log as logger
@ -22,7 +21,7 @@ from mindspore import log as logger
def gen(): def gen():
for i in range(100): for i in range(100):
yield np.array(i), yield (np.array(i),)
class Augment: class Augment:
@ -38,7 +37,7 @@ class Augment:
def test_simple_sync_wait(): def test_simple_sync_wait():
""" """
Test simple sync wait: test sync in dataset pipeline Test simple sync wait: test sync in dataset pipeline
""" """
logger.info("test_simple_sync_wait") logger.info("test_simple_sync_wait")
batch_size = 4 batch_size = 4
@ -51,7 +50,7 @@ def test_simple_sync_wait():
count = 0 count = 0
for data in dataset.create_dict_iterator(): for data in dataset.create_dict_iterator():
assert (data["input"][0] == count) assert data["input"][0] == count
count += batch_size count += batch_size
data = {"loss": count} data = {"loss": count}
dataset.sync_update(condition_name="policy", data=data) dataset.sync_update(condition_name="policy", data=data)
@ -59,7 +58,7 @@ def test_simple_sync_wait():
def test_simple_shuffle_sync(): def test_simple_shuffle_sync():
""" """
Test simple shuffle sync: test shuffle before sync Test simple shuffle sync: test shuffle before sync
""" """
logger.info("test_simple_shuffle_sync") logger.info("test_simple_shuffle_sync")
shuffle_size = 4 shuffle_size = 4
@ -83,7 +82,7 @@ def test_simple_shuffle_sync():
def test_two_sync(): def test_two_sync():
""" """
Test two sync: dataset pipeline with two sync_operators Test two sync: dataset pipeline with two sync_operators
""" """
logger.info("test_two_sync") logger.info("test_two_sync")
batch_size = 6 batch_size = 6
@ -111,7 +110,7 @@ def test_two_sync():
def test_sync_epoch(): def test_sync_epoch():
""" """
Test sync wait with epochs: test sync with epochs in dataset pipeline Test sync wait with epochs: test sync with epochs in dataset pipeline
""" """
logger.info("test_sync_epoch") logger.info("test_sync_epoch")
batch_size = 30 batch_size = 30
@ -122,11 +121,11 @@ def test_sync_epoch():
dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
dataset = dataset.batch(batch_size, drop_remainder=True) dataset = dataset.batch(batch_size, drop_remainder=True)
for epochs in range(3): for _ in range(3):
aug.update({"loss": 0}) aug.update({"loss": 0})
count = 0 count = 0
for data in dataset.create_dict_iterator(): for data in dataset.create_dict_iterator():
assert (data["input"][0] == count) assert data["input"][0] == count
count += batch_size count += batch_size
data = {"loss": count} data = {"loss": count}
dataset.sync_update(condition_name="policy", data=data) dataset.sync_update(condition_name="policy", data=data)
@ -134,7 +133,7 @@ def test_sync_epoch():
def test_multiple_iterators(): def test_multiple_iterators():
""" """
Test sync wait with multiple iterators: will start multiple iterators Test sync wait with multiple iterators: will start multiple iterators
""" """
logger.info("test_sync_epoch") logger.info("test_sync_epoch")
batch_size = 30 batch_size = 30
@ -153,7 +152,7 @@ def test_multiple_iterators():
dataset2 = dataset2.batch(batch_size, drop_remainder=True) dataset2 = dataset2.batch(batch_size, drop_remainder=True)
for item1, item2 in zip(dataset.create_dict_iterator(), dataset2.create_dict_iterator()): for item1, item2 in zip(dataset.create_dict_iterator(), dataset2.create_dict_iterator()):
assert (item1["input"][0] == item2["input"][0]) assert item1["input"][0] == item2["input"][0]
data1 = {"loss": item1["input"][0]} data1 = {"loss": item1["input"][0]}
data2 = {"loss": item2["input"][0]} data2 = {"loss": item2["input"][0]}
dataset.sync_update(condition_name="policy", data=data1) dataset.sync_update(condition_name="policy", data=data1)
@ -162,7 +161,7 @@ def test_multiple_iterators():
def test_sync_exception_01(): def test_sync_exception_01():
""" """
Test sync: with shuffle in sync mode Test sync: with shuffle in sync mode
""" """
logger.info("test_sync_exception_01") logger.info("test_sync_exception_01")
shuffle_size = 4 shuffle_size = 4
@ -183,7 +182,7 @@ def test_sync_exception_01():
def test_sync_exception_02(): def test_sync_exception_02():
""" """
Test sync: with duplicated condition name Test sync: with duplicated condition name
""" """
logger.info("test_sync_exception_02") logger.info("test_sync_exception_02")
batch_size = 6 batch_size = 6
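All of the sync tests above exercise the same handshake: a sync point registered on the pipeline blocks the iterator until sync_update is called with fresh data for the registered condition name, and the callback (aug.update) folds that data back into the preprocessing step. A minimal sketch of one pass of that loop; the sync_wait registration itself sits in the elided part of the file, so its exact arguments here are an assumption:

import numpy as np
import mindspore.dataset as ds

def gen():
    for i in range(100):
        yield (np.array(i),)

class Augment:
    """Callback holder: update() receives the dict passed to sync_update()."""
    def __init__(self, loss):
        self.loss = loss
    def preprocess(self, x):
        return x
    def update(self, data):
        self.loss = data["loss"]

batch_size = 4
aug = Augment(0)
dataset = ds.GeneratorDataset(gen, ["input"])
# Assumed registration: block after each batch until sync_update releases it.
dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
dataset = dataset.batch(batch_size, drop_remainder=True)

count = 0
for data in dataset.create_dict_iterator():
    assert data["input"][0] == count
    count += batch_size
    # Releasing the condition lets the pipeline produce the next batch.
    dataset.sync_update(condition_name="policy", data={"loss": count})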

View File

@ -21,13 +21,13 @@ from mindspore import log as logger
# In generator dataset: Number of rows is 3, its value is 0, 1, 2 # In generator dataset: Number of rows is 3, its value is 0, 1, 2
def generator(): def generator():
for i in range(3): for i in range(3):
yield np.array([i]), yield (np.array([i]),)
# In generator dataset: Number of rows is 10, its value is 0, 1, 2 ... 9 # In generator dataset: Number of rows is 10, its value is 0, 1, 2 ... 9
def generator_10(): def generator_10():
for i in range(10): for i in range(10):
yield np.array([i]), yield (np.array([i]),)
def filter_func_ge(data): def filter_func_ge(data):
@ -47,8 +47,8 @@ def test_take_01():
data1 = data1.repeat(2) data1 = data1.repeat(2)
# Here i refers to index, d refers to data element # Here i refers to index, d refers to data element
for i, d in enumerate(data1): for _, d in enumerate(data1):
assert 0 == d[0][0] assert d[0][0] == 0
assert sum([1 for _ in data1]) == 2 assert sum([1 for _ in data1]) == 2
@ -97,7 +97,7 @@ def test_take_04():
data1 = data1.take(4) data1 = data1.take(4)
data1 = data1.repeat(2) data1 = data1.repeat(2)
# Here i refers to index, d refers to data element # Here i refers to index, d refers to data element
for i, d in enumerate(data1): for i, d in enumerate(data1):
assert i % 3 == d[0][0] assert i % 3 == d[0][0]
@ -113,7 +113,7 @@ def test_take_05():
data1 = data1.take(2) data1 = data1.take(2)
# Here i refers to index, d refers to data element # Here i refers to index, d refers to data element
for i, d in enumerate(data1): for i, d in enumerate(data1):
assert i == d[0][0] assert i == d[0][0]
@ -130,7 +130,7 @@ def test_take_06():
data1 = data1.repeat(2) data1 = data1.repeat(2)
data1 = data1.take(4) data1 = data1.take(4)
# Here i refers to index, d refers to data element # Here i refers to index, d refers to data element
for i, d in enumerate(data1): for i, d in enumerate(data1):
assert i % 3 == d[0][0] assert i % 3 == d[0][0]
@ -171,7 +171,7 @@ def test_take_09():
data1 = data1.repeat(2) data1 = data1.repeat(2)
data1 = data1.take(-1) data1 = data1.take(-1)
# Here i refers to index, d refers to data element # Here i refers to index, d refers to data element
for i, d in enumerate(data1): for i, d in enumerate(data1):
assert i % 3 == d[0][0] assert i % 3 == d[0][0]
@ -188,7 +188,7 @@ def test_take_10():
data1 = data1.take(-1) data1 = data1.take(-1)
data1 = data1.repeat(2) data1 = data1.repeat(2)
# Here i refers to index, d refers to data element # Here i refers to index, d refers to data element
for i, d in enumerate(data1): for i, d in enumerate(data1):
assert i % 3 == d[0][0] assert i % 3 == d[0][0]
@ -206,7 +206,7 @@ def test_take_11():
data1 = data1.repeat(2) data1 = data1.repeat(2)
data1 = data1.take(-1) data1 = data1.take(-1)
# Here i refers to index, d refers to data element # Here i refers to index, d refers to data element
for i, d in enumerate(data1): for i, d in enumerate(data1):
assert 2 * (i % 2) == d[0][0] assert 2 * (i % 2) == d[0][0]
@ -224,9 +224,9 @@ def test_take_12():
data1 = data1.batch(2) data1 = data1.batch(2)
data1 = data1.repeat(2) data1 = data1.repeat(2)
# Here i refers to index, d refers to data element # Here i refers to index, d refers to data element
for i, d in enumerate(data1): for _, d in enumerate(data1):
assert 0 == d[0][0] assert d[0][0] == 0
assert sum([1 for _ in data1]) == 2 assert sum([1 for _ in data1]) == 2
@ -243,9 +243,9 @@ def test_take_13():
data1 = data1.batch(2) data1 = data1.batch(2)
data1 = data1.repeat(2) data1 = data1.repeat(2)
# Here i refers to index, d refers to data element # Here i refers to index, d refers to data element
for i, d in enumerate(data1): for _, d in enumerate(data1):
assert 2 == d[0][0] assert d[0][0] == 2
assert sum([1 for _ in data1]) == 2 assert sum([1 for _ in data1]) == 2
@ -262,9 +262,9 @@ def test_take_14():
data1 = data1.skip(1) data1 = data1.skip(1)
data1 = data1.repeat(2) data1 = data1.repeat(2)
# Here i refers to index, d refers to data element # Here i refers to index, d refers to data element
for i, d in enumerate(data1): for _, d in enumerate(data1):
assert 2 == d[0][0] assert d[0][0] == 2
assert sum([1 for _ in data1]) == 2 assert sum([1 for _ in data1]) == 2
@ -279,7 +279,7 @@ def test_take_15():
data1 = data1.take(6) data1 = data1.take(6)
data1 = data1.skip(2) data1 = data1.skip(2)
# Here i refers to index, d refers to data element # Here i refers to index, d refers to data element
for i, d in enumerate(data1): for i, d in enumerate(data1):
assert (i + 2) == d[0][0] assert (i + 2) == d[0][0]
@ -296,7 +296,7 @@ def test_take_16():
data1 = data1.skip(3) data1 = data1.skip(3)
data1 = data1.take(5) data1 = data1.take(5)
# Here i refers to index, d refers to data element # Here i refers to index, d refers to data element
for i, d in enumerate(data1): for i, d in enumerate(data1):
assert (i + 3) == d[0][0] assert (i + 3) == d[0][0]
@ -313,7 +313,7 @@ def test_take_17():
data1 = data1.take(8) data1 = data1.take(8)
data1 = data1.filter(predicate=filter_func_ge, num_parallel_workers=4) data1 = data1.filter(predicate=filter_func_ge, num_parallel_workers=4)
# Here i refers to index, d refers to data element # Here i refers to index, d refers to data element
for i, d in enumerate(data1): for i, d in enumerate(data1):
assert i == d[0][0] assert i == d[0][0]
@ -334,9 +334,9 @@ def test_take_18():
data1 = data1.batch(2) data1 = data1.batch(2)
data1 = data1.repeat(2) data1 = data1.repeat(2)
# Here i refers to index, d refers to data element # Here i refers to index, d refers to data element
for i, d in enumerate(data1): for _, d in enumerate(data1):
assert 2 == d[0][0] assert d[0][0] == 2
assert sum([1 for _ in data1]) == 2 assert sum([1 for _ in data1]) == 2
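The assertions in these take/skip tests all encode how the operators compose over the 3-row generator (values 0, 1, 2): repeating before taking wraps the values modulo 3, while taking or skipping first shrinks the dataset that then gets repeated. A plain-Python model of that ordering, assuming take/skip/repeat behave like slicing and concatenation over the row stream (itertools stands in for the dataset engine):

from itertools import chain, islice

rows = [0, 1, 2]                      # the 3-row generator dataset

# repeat(2) then take(4): repetition happens first, so values wrap modulo 3.
repeat_then_take = list(islice(chain(rows, rows), 4))
assert repeat_then_take == [0, 1, 2, 0]
assert all(i % 3 == d for i, d in enumerate(repeat_then_take))

# take(2) then repeat(2): only the first two rows exist before repeating.
take_then_repeat = list(chain(islice(rows, 2), islice(rows, 2)))
assert take_then_repeat == [0, 1, 0, 1]

# take(6) then skip(2): skipping drops the leading rows of whatever was taken.
take_then_skip = list(islice(rows, 6))[2:]
assert all((i + 2) == d for i, d in enumerate(take_then_skip))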

View File

@ -33,7 +33,7 @@ def test_case_tf_shape():
for data in ds1.create_dict_iterator(): for data in ds1.create_dict_iterator():
logger.info(data) logger.info(data)
output_shape = ds1.output_shapes() output_shape = ds1.output_shapes()
assert (len(output_shape[-1]) == 1) assert len(output_shape[-1]) == 1
def test_case_tf_read_all_dataset(): def test_case_tf_read_all_dataset():
@ -41,7 +41,7 @@ def test_case_tf_read_all_dataset():
ds1 = ds.TFRecordDataset(FILES, schema_file) ds1 = ds.TFRecordDataset(FILES, schema_file)
assert ds1.get_dataset_size() == 12 assert ds1.get_dataset_size() == 12
count = 0 count = 0
for data in ds1.create_tuple_iterator(): for _ in ds1.create_tuple_iterator():
count += 1 count += 1
assert count == 12 assert count == 12
@ -51,7 +51,7 @@ def test_case_num_samples():
ds1 = ds.TFRecordDataset(FILES, schema_file, num_samples=8) ds1 = ds.TFRecordDataset(FILES, schema_file, num_samples=8)
assert ds1.get_dataset_size() == 8 assert ds1.get_dataset_size() == 8
count = 0 count = 0
for data in ds1.create_dict_iterator(): for _ in ds1.create_dict_iterator():
count += 1 count += 1
assert count == 8 assert count == 8
@ -61,7 +61,7 @@ def test_case_num_samples2():
ds1 = ds.TFRecordDataset(FILES, schema_file) ds1 = ds.TFRecordDataset(FILES, schema_file)
assert ds1.get_dataset_size() == 7 assert ds1.get_dataset_size() == 7
count = 0 count = 0
for data in ds1.create_dict_iterator(): for _ in ds1.create_dict_iterator():
count += 1 count += 1
assert count == 7 assert count == 7
@ -70,7 +70,7 @@ def test_case_tf_shape_2():
ds1 = ds.TFRecordDataset(FILES, SCHEMA_FILE) ds1 = ds.TFRecordDataset(FILES, SCHEMA_FILE)
ds1 = ds1.batch(2) ds1 = ds1.batch(2)
output_shape = ds1.output_shapes() output_shape = ds1.output_shapes()
assert (len(output_shape[-1]) == 2) assert len(output_shape[-1]) == 2
def test_case_tf_file(): def test_case_tf_file():
@ -175,10 +175,10 @@ def test_tf_record_shard():
assert len(worker1_res) == 48 assert len(worker1_res) == 48
assert len(worker1_res) == len(worker2_res) assert len(worker1_res) == len(worker2_res)
# check criteria 1 # check criteria 1
for i in range(len(worker1_res)): for i, _ in enumerate(worker1_res):
assert (worker1_res[i] != worker2_res[i]) assert worker1_res[i] != worker2_res[i]
# check criteria 2 # check criteria 2
assert (set(worker2_res) == set(worker1_res)) assert set(worker2_res) == set(worker1_res)
def test_tf_shard_equal_rows(): def test_tf_shard_equal_rows():
@ -197,16 +197,16 @@ def test_tf_shard_equal_rows():
worker2_res = get_res(3, 1, 2) worker2_res = get_res(3, 1, 2)
worker3_res = get_res(3, 2, 2) worker3_res = get_res(3, 2, 2)
# check criteria 1 # check criteria 1
for i in range(len(worker1_res)): for i, _ in enumerate(worker1_res):
assert (worker1_res[i] != worker2_res[i]) assert worker1_res[i] != worker2_res[i]
assert (worker2_res[i] != worker3_res[i]) assert worker2_res[i] != worker3_res[i]
# Confirm each worker gets same number of rows # Confirm each worker gets same number of rows
assert len(worker1_res) == 28 assert len(worker1_res) == 28
assert len(worker1_res) == len(worker2_res) assert len(worker1_res) == len(worker2_res)
assert len(worker2_res) == len(worker3_res) assert len(worker2_res) == len(worker3_res)
worker4_res = get_res(1, 0, 1) worker4_res = get_res(1, 0, 1)
assert (len(worker4_res) == 40) assert len(worker4_res) == 40
def test_case_tf_file_no_schema_columns_list(): def test_case_tf_file_no_schema_columns_list():
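Another recurring fix in this file and the next one replaces index-only loops over range(len(...)) with enumerate (pylint's consider-using-enumerate, assumed to be the trigger); where only the index is used, the value slot is bound to "_":

worker1_res = [3, 1, 4, 1]
worker2_res = [5, 9, 2, 6]
# was: for i in range(len(worker1_res)):
for i, _ in enumerate(worker1_res):
    assert worker1_res[i] != worker2_res[i]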

View File

@ -59,7 +59,7 @@ def test_batch_corner_cases():
# to a pyfunc which makes a deep copy of the row # to a pyfunc which makes a deep copy of the row
def test_variable_size_batch(): def test_variable_size_batch():
def check_res(arr1, arr2): def check_res(arr1, arr2):
for ind in range(len(arr1)): for ind, _ in enumerate(arr1):
if not np.array_equal(arr1[ind], np.array(arr2[ind])): if not np.array_equal(arr1[ind], np.array(arr2[ind])):
return False return False
return len(arr1) == len(arr2) return len(arr1) == len(arr2)
@ -143,7 +143,7 @@ def test_variable_size_batch():
def test_basic_batch_map(): def test_basic_batch_map():
def check_res(arr1, arr2): def check_res(arr1, arr2):
for ind in range(len(arr1)): for ind, _ in enumerate(arr1):
if not np.array_equal(arr1[ind], np.array(arr2[ind])): if not np.array_equal(arr1[ind], np.array(arr2[ind])):
return False return False
return len(arr1) == len(arr2) return len(arr1) == len(arr2)
@ -176,7 +176,7 @@ def test_basic_batch_map():
def test_batch_multi_col_map(): def test_batch_multi_col_map():
def check_res(arr1, arr2): def check_res(arr1, arr2):
for ind in range(len(arr1)): for ind, _ in enumerate(arr1):
if not np.array_equal(arr1[ind], np.array(arr2[ind])): if not np.array_equal(arr1[ind], np.array(arr2[ind])):
return False return False
return len(arr1) == len(arr2) return len(arr1) == len(arr2)
@ -224,7 +224,7 @@ def test_batch_multi_col_map():
def test_var_batch_multi_col_map(): def test_var_batch_multi_col_map():
def check_res(arr1, arr2): def check_res(arr1, arr2):
for ind in range(len(arr1)): for ind, _ in enumerate(arr1):
if not np.array_equal(arr1[ind], np.array(arr2[ind])): if not np.array_equal(arr1[ind], np.array(arr2[ind])):
return False return False
return len(arr1) == len(arr2) return len(arr1) == len(arr2)
@ -269,7 +269,7 @@ def test_var_batch_var_resize():
return ([np.copy(c[0:s, 0:s, :]) for c in col],) return ([np.copy(c[0:s, 0:s, :]) for c in col],)
def add_one(batchInfo): def add_one(batchInfo):
return (batchInfo.get_batch_num() + 1) return batchInfo.get_batch_num() + 1
data1 = ds.ImageFolderDatasetV2("../data/dataset/testPK/data/", num_parallel_workers=4, decode=True) data1 = ds.ImageFolderDatasetV2("../data/dataset/testPK/data/", num_parallel_workers=4, decode=True)
data1 = data1.batch(batch_size=add_one, drop_remainder=True, input_columns=["image"], per_batch_map=np_psedo_resize) data1 = data1.batch(batch_size=add_one, drop_remainder=True, input_columns=["image"], per_batch_map=np_psedo_resize)
@ -303,7 +303,7 @@ def test_exception():
data2 = ds.GeneratorDataset((lambda: gen(100)), ["num"]).batch(4, input_columns=["num"], per_batch_map=bad_map_func) data2 = ds.GeneratorDataset((lambda: gen(100)), ["num"]).batch(4, input_columns=["num"], per_batch_map=bad_map_func)
try: try:
for item in data2.create_dict_iterator(): for _ in data2.create_dict_iterator():
pass pass
assert False assert False
except RuntimeError: except RuntimeError:
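The variable-batch tests above lean on two batch() features the rest of the commit does not touch: batch_size given as a callable of BatchInfo (add_one above grows each batch by one row) and per_batch_map rewriting the columns of every batch. A minimal sketch, assuming per_batch_map receives one list per input column followed by the BatchInfo object, which is the convention these tests appear to follow (invert_sign is a made-up stand-in for np_psedo_resize):

import numpy as np
import mindspore.dataset as ds

def gen():
    for i in range(10):
        yield (np.array([i]),)

def add_one(batch_info):
    # batch 0 gets 1 row, batch 1 gets 2 rows, and so on.
    return batch_info.get_batch_num() + 1

def invert_sign(col, batch_info):
    # One list of arrays per input column comes in; a tuple of lists goes out.
    return ([-x for x in col],)

data = ds.GeneratorDataset(gen, ["num"])
data = data.batch(batch_size=add_one, drop_remainder=True,
                  input_columns=["num"], per_batch_map=invert_sign)
for item in data.create_dict_iterator():
    print(item["num"])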

View File

@ -13,9 +13,9 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""test mindrecord base""" """test mindrecord base"""
import numpy as np
import os import os
import uuid import uuid
import numpy as np
from utils import get_data, get_nlp_data from utils import get_data, get_nlp_data
from mindspore import log as logger from mindspore import log as logger

View File

@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""test write multiple images""" """test write multiple images"""
import numpy as np
import os import os
import numpy as np
from utils import get_two_bytes_data, get_multi_bytes_data from utils import get_two_bytes_data, get_multi_bytes_data
from mindspore import log as logger from mindspore import log as logger

View File

@ -14,9 +14,9 @@
"""test mnist to mindrecord tool""" """test mnist to mindrecord tool"""
import gzip import gzip
import os import os
import numpy as np
import cv2 import cv2
import numpy as np
import pytest import pytest
from mindspore import log as logger from mindspore import log as logger
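The final three files only reorder imports. The grouping pylint expects (wrong-import-order, assumed to be the message involved) is standard library first, then third-party packages, then local modules, which is exactly the order the mnist-to-mindrecord test ends up with:

# Standard library
import gzip
import os

# Third-party
import cv2
import numpy as np
import pytest

from mindspore import log as logger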