forked from mindspore-Ecosystem/mindspore
10:00 26/5 clean pylint
This commit is contained in:
parent
93fc82b8f7
commit
abca62f407
|
@ -34,7 +34,7 @@ def use_filereader(mindrecord):
|
|||
num_consumer=4,
|
||||
columns=columns_list)
|
||||
num_iter = 0
|
||||
for index, item in enumerate(reader.get_next()):
|
||||
for _, _ in enumerate(reader.get_next()):
|
||||
num_iter += 1
|
||||
print_log(num_iter)
|
||||
end = time.time()
|
||||
|
@ -48,7 +48,7 @@ def use_minddataset(mindrecord):
|
|||
columns_list=columns_list,
|
||||
num_parallel_workers=4)
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
for _ in data_set.create_dict_iterator():
|
||||
num_iter += 1
|
||||
print_log(num_iter)
|
||||
end = time.time()
|
||||
|
@ -64,7 +64,7 @@ def use_tfrecorddataset(tfrecord):
|
|||
shuffle=ds.Shuffle.GLOBAL)
|
||||
data_set = data_set.shuffle(10000)
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
for _ in data_set.create_dict_iterator():
|
||||
num_iter += 1
|
||||
print_log(num_iter)
|
||||
end = time.time()
|
||||
|
@ -87,7 +87,7 @@ def use_tensorflow_tfrecorddataset(tfrecord):
|
|||
num_parallel_reads=4)
|
||||
data_set = data_set.map(_parse_record, num_parallel_calls=4)
|
||||
num_iter = 0
|
||||
for item in data_set.__iter__():
|
||||
for _ in data_set.__iter__():
|
||||
num_iter += 1
|
||||
print_log(num_iter)
|
||||
end = time.time()
|
||||
|
@ -96,18 +96,18 @@ def use_tensorflow_tfrecorddataset(tfrecord):
|
|||
|
||||
if __name__ == '__main__':
|
||||
# use MindDataset
|
||||
mindrecord = './imagenet.mindrecord00'
|
||||
use_minddataset(mindrecord)
|
||||
mindrecord_test = './imagenet.mindrecord00'
|
||||
use_minddataset(mindrecord_test)
|
||||
|
||||
# use TFRecordDataset
|
||||
tfrecord = ['imagenet.tfrecord00', 'imagenet.tfrecord01', 'imagenet.tfrecord02', 'imagenet.tfrecord03',
|
||||
'imagenet.tfrecord04', 'imagenet.tfrecord05', 'imagenet.tfrecord06', 'imagenet.tfrecord07',
|
||||
'imagenet.tfrecord08', 'imagenet.tfrecord09', 'imagenet.tfrecord10', 'imagenet.tfrecord11',
|
||||
'imagenet.tfrecord12', 'imagenet.tfrecord13', 'imagenet.tfrecord14', 'imagenet.tfrecord15']
|
||||
use_tfrecorddataset(tfrecord)
|
||||
tfrecord_test = ['imagenet.tfrecord00', 'imagenet.tfrecord01', 'imagenet.tfrecord02', 'imagenet.tfrecord03',
|
||||
'imagenet.tfrecord04', 'imagenet.tfrecord05', 'imagenet.tfrecord06', 'imagenet.tfrecord07',
|
||||
'imagenet.tfrecord08', 'imagenet.tfrecord09', 'imagenet.tfrecord10', 'imagenet.tfrecord11',
|
||||
'imagenet.tfrecord12', 'imagenet.tfrecord13', 'imagenet.tfrecord14', 'imagenet.tfrecord15']
|
||||
use_tfrecorddataset(tfrecord_test)
|
||||
|
||||
# use TensorFlow TFRecordDataset
|
||||
use_tensorflow_tfrecorddataset(tfrecord)
|
||||
use_tensorflow_tfrecorddataset(tfrecord_test)
|
||||
|
||||
# use FileReader
|
||||
# use_filereader(mindrecord)
|
||||
|
|
|
@ -29,7 +29,7 @@ def test_case_0():
|
|||
# apply dataset operations
|
||||
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
|
||||
|
||||
ds1 = ds1.map(input_column_names=col, output_column_names="out", operation=(lambda x: x + x))
|
||||
ds1 = ds1.map(input_columns=col, output_columns="out", operations=(lambda x: x + x))
|
||||
|
||||
print("************** Output Tensor *****************")
|
||||
for data in ds1.create_dict_iterator(): # each data is a dictionary
|
||||
|
@ -49,7 +49,7 @@ def test_case_1():
|
|||
# apply dataset operations
|
||||
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
|
||||
|
||||
ds1 = ds1.map(input_column_names=col, output_column_names=["out0", "out1"], operation=(lambda x: (x, x + x)))
|
||||
ds1 = ds1.map(input_columns=col, output_columns=["out0", "out1"], operations=(lambda x: (x, x + x)))
|
||||
|
||||
print("************** Output Tensor *****************")
|
||||
for data in ds1.create_dict_iterator(): # each data is a dictionary
|
||||
|
@ -72,7 +72,7 @@ def test_case_2():
|
|||
# apply dataset operations
|
||||
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
|
||||
|
||||
ds1 = ds1.map(input_column_names=col, output_column_names="out", operation=(lambda x, y: x + y))
|
||||
ds1 = ds1.map(input_columns=col, output_columns="out", operations=(lambda x, y: x + y))
|
||||
|
||||
print("************** Output Tensor *****************")
|
||||
for data in ds1.create_dict_iterator(): # each data is a dictionary
|
||||
|
@ -93,8 +93,8 @@ def test_case_3():
|
|||
# apply dataset operations
|
||||
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
|
||||
|
||||
ds1 = ds1.map(input_column_names=col, output_column_names=["out0", "out1", "out2"],
|
||||
operation=(lambda x, y: (x, x + y, x + x + y)))
|
||||
ds1 = ds1.map(input_columns=col, output_columns=["out0", "out1", "out2"],
|
||||
operations=(lambda x, y: (x, x + y, x + x + y)))
|
||||
|
||||
print("************** Output Tensor *****************")
|
||||
for data in ds1.create_dict_iterator(): # each data is a dictionary
|
||||
|
@ -119,8 +119,8 @@ def test_case_4():
|
|||
# apply dataset operations
|
||||
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
|
||||
|
||||
ds1 = ds1.map(input_column_names=col, output_column_names=["out0", "out1", "out2"], num_parallel_workers=4,
|
||||
operation=(lambda x, y: (x, x + y, x + x + y)))
|
||||
ds1 = ds1.map(input_columns=col, output_columns=["out0", "out1", "out2"], num_parallel_workers=4,
|
||||
operations=(lambda x, y: (x, x + y, x + x + y)))
|
||||
|
||||
print("************** Output Tensor *****************")
|
||||
for data in ds1.create_dict_iterator(): # each data is a dictionary
|
||||
|
|
|
@ -22,11 +22,11 @@ def create_data_cache_dir():
|
|||
cwd = os.getcwd()
|
||||
target_directory = os.path.join(cwd, "data_cache")
|
||||
try:
|
||||
if not (os.path.exists(target_directory)):
|
||||
if not os.path.exists(target_directory):
|
||||
os.mkdir(target_directory)
|
||||
except OSError:
|
||||
print("Creation of the directory %s failed" % target_directory)
|
||||
return target_directory;
|
||||
return target_directory
|
||||
|
||||
|
||||
def download_and_uncompress(files, source_url, target_directory, is_tar=False):
|
||||
|
@ -53,13 +53,13 @@ def download_and_uncompress(files, source_url, target_directory, is_tar=False):
|
|||
|
||||
|
||||
def download_mnist(target_directory=None):
|
||||
if target_directory == None:
|
||||
if target_directory is None:
|
||||
target_directory = create_data_cache_dir()
|
||||
|
||||
##create mnst directory
|
||||
target_directory = os.path.join(target_directory, "mnist")
|
||||
try:
|
||||
if not (os.path.exists(target_directory)):
|
||||
if not os.path.exists(target_directory):
|
||||
os.mkdir(target_directory)
|
||||
except OSError:
|
||||
print("Creation of the directory %s failed" % target_directory)
|
||||
|
@ -78,14 +78,14 @@ CIFAR_URL = "https://www.cs.toronto.edu/~kriz/"
|
|||
|
||||
|
||||
def download_cifar(target_directory, files, directory_from_tar):
|
||||
if target_directory == None:
|
||||
if target_directory is None:
|
||||
target_directory = create_data_cache_dir()
|
||||
|
||||
download_and_uncompress([files], CIFAR_URL, target_directory, is_tar=True)
|
||||
|
||||
##if target dir was specify move data from directory created by tar
|
||||
##and put data into target dir
|
||||
if target_directory != None:
|
||||
if target_directory is not None:
|
||||
tar_dir_full_path = os.path.join(target_directory, directory_from_tar)
|
||||
all_files = os.path.join(tar_dir_full_path, "*")
|
||||
cmd = "mv " + all_files + " " + target_directory
|
||||
|
|
|
@ -12,10 +12,10 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import mindspore._c_dataengine as cde
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mindspore._c_dataengine as cde
|
||||
|
||||
|
||||
def test_shape():
|
||||
x = [2, 3]
|
||||
|
|
|
@ -221,7 +221,7 @@ def test_apply_exception_case():
|
|||
try:
|
||||
data2 = data1.apply(dataset_fn)
|
||||
data3 = data1.apply(dataset_fn)
|
||||
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
|
||||
for _, _ in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
|
||||
pass
|
||||
assert False
|
||||
except ValueError:
|
||||
|
|
|
@ -35,10 +35,10 @@ def test_case_dataset_cifar10():
|
|||
data1 = ds.Cifar10Dataset(DATA_DIR_10, 100)
|
||||
|
||||
num_iter = 0
|
||||
for item in data1.create_dict_iterator():
|
||||
for _ in data1.create_dict_iterator():
|
||||
# in this example, each dictionary has keys "image" and "label"
|
||||
num_iter += 1
|
||||
assert (num_iter == 100)
|
||||
assert num_iter == 100
|
||||
|
||||
|
||||
def test_case_dataset_cifar100():
|
||||
|
@ -50,10 +50,10 @@ def test_case_dataset_cifar100():
|
|||
data1 = ds.Cifar100Dataset(DATA_DIR_100, 100)
|
||||
|
||||
num_iter = 0
|
||||
for item in data1.create_dict_iterator():
|
||||
for _ in data1.create_dict_iterator():
|
||||
# in this example, each dictionary has keys "image" and "label"
|
||||
num_iter += 1
|
||||
assert (num_iter == 100)
|
||||
assert num_iter == 100
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -15,10 +15,10 @@
|
|||
"""
|
||||
Testing configuration manager
|
||||
"""
|
||||
import os
|
||||
import filecmp
|
||||
import glob
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.vision.c_transforms as vision
|
||||
|
@ -89,7 +89,7 @@ def test_pipeline():
|
|||
ds.serialize(data2, "testpipeline2.json")
|
||||
|
||||
# check that the generated output is different
|
||||
assert (filecmp.cmp('testpipeline.json', 'testpipeline2.json'))
|
||||
assert filecmp.cmp('testpipeline.json', 'testpipeline2.json')
|
||||
|
||||
# this test passes currently because our num_parallel_workers don't get updated.
|
||||
|
||||
|
|
|
@ -33,9 +33,9 @@ def test_celeba_dataset_label():
|
|||
logger.info("----------attr--------")
|
||||
logger.info(item["attr"])
|
||||
for index in range(len(expect_labels[count])):
|
||||
assert (item["attr"][index] == expect_labels[count][index])
|
||||
assert item["attr"][index] == expect_labels[count][index]
|
||||
count = count + 1
|
||||
assert (count == 2)
|
||||
assert count == 2
|
||||
|
||||
|
||||
def test_celeba_dataset_op():
|
||||
|
@ -54,7 +54,7 @@ def test_celeba_dataset_op():
|
|||
logger.info("----------image--------")
|
||||
logger.info(item["image"])
|
||||
count = count + 1
|
||||
assert (count == 4)
|
||||
assert count == 4
|
||||
|
||||
|
||||
def test_celeba_dataset_ext():
|
||||
|
@ -69,9 +69,9 @@ def test_celeba_dataset_ext():
|
|||
logger.info("----------attr--------")
|
||||
logger.info(item["attr"])
|
||||
for index in range(len(expect_labels[count])):
|
||||
assert (item["attr"][index] == expect_labels[count][index])
|
||||
assert item["attr"][index] == expect_labels[count][index]
|
||||
count = count + 1
|
||||
assert (count == 1)
|
||||
assert count == 1
|
||||
|
||||
|
||||
def test_celeba_dataset_distribute():
|
||||
|
@ -83,7 +83,7 @@ def test_celeba_dataset_distribute():
|
|||
logger.info("----------attr--------")
|
||||
logger.info(item["attr"])
|
||||
count = count + 1
|
||||
assert (count == 1)
|
||||
assert count == 1
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -35,7 +35,7 @@ def test_imagefolder_basic():
|
|||
num_iter += 1
|
||||
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert (num_iter == 44)
|
||||
assert num_iter == 44
|
||||
|
||||
|
||||
def test_imagefolder_numsamples():
|
||||
|
@ -55,7 +55,7 @@ def test_imagefolder_numsamples():
|
|||
num_iter += 1
|
||||
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert (num_iter == 10)
|
||||
assert num_iter == 10
|
||||
|
||||
|
||||
def test_imagefolder_numshards():
|
||||
|
@ -75,7 +75,7 @@ def test_imagefolder_numshards():
|
|||
num_iter += 1
|
||||
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert (num_iter == 11)
|
||||
assert num_iter == 11
|
||||
|
||||
|
||||
def test_imagefolder_shardid():
|
||||
|
@ -95,7 +95,7 @@ def test_imagefolder_shardid():
|
|||
num_iter += 1
|
||||
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert (num_iter == 11)
|
||||
assert num_iter == 11
|
||||
|
||||
|
||||
def test_imagefolder_noshuffle():
|
||||
|
@ -115,7 +115,7 @@ def test_imagefolder_noshuffle():
|
|||
num_iter += 1
|
||||
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert (num_iter == 44)
|
||||
assert num_iter == 44
|
||||
|
||||
|
||||
def test_imagefolder_extrashuffle():
|
||||
|
@ -136,7 +136,7 @@ def test_imagefolder_extrashuffle():
|
|||
num_iter += 1
|
||||
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert (num_iter == 88)
|
||||
assert num_iter == 88
|
||||
|
||||
|
||||
def test_imagefolder_classindex():
|
||||
|
@ -157,11 +157,11 @@ def test_imagefolder_classindex():
|
|||
# in this example, each dictionary has keys "image" and "label"
|
||||
logger.info("image is {}".format(item["image"]))
|
||||
logger.info("label is {}".format(item["label"]))
|
||||
assert (item["label"] == golden[num_iter])
|
||||
assert item["label"] == golden[num_iter]
|
||||
num_iter += 1
|
||||
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert (num_iter == 22)
|
||||
assert num_iter == 22
|
||||
|
||||
|
||||
def test_imagefolder_negative_classindex():
|
||||
|
@ -182,11 +182,11 @@ def test_imagefolder_negative_classindex():
|
|||
# in this example, each dictionary has keys "image" and "label"
|
||||
logger.info("image is {}".format(item["image"]))
|
||||
logger.info("label is {}".format(item["label"]))
|
||||
assert (item["label"] == golden[num_iter])
|
||||
assert item["label"] == golden[num_iter]
|
||||
num_iter += 1
|
||||
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert (num_iter == 22)
|
||||
assert num_iter == 22
|
||||
|
||||
|
||||
def test_imagefolder_extensions():
|
||||
|
@ -207,7 +207,7 @@ def test_imagefolder_extensions():
|
|||
num_iter += 1
|
||||
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert (num_iter == 44)
|
||||
assert num_iter == 44
|
||||
|
||||
|
||||
def test_imagefolder_decode():
|
||||
|
@ -228,7 +228,7 @@ def test_imagefolder_decode():
|
|||
num_iter += 1
|
||||
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert (num_iter == 44)
|
||||
assert num_iter == 44
|
||||
|
||||
|
||||
def test_sequential_sampler():
|
||||
|
@ -255,7 +255,7 @@ def test_sequential_sampler():
|
|||
num_iter += 1
|
||||
|
||||
logger.info("Result: {}".format(result))
|
||||
assert (result == golden)
|
||||
assert result == golden
|
||||
|
||||
|
||||
def test_random_sampler():
|
||||
|
@ -276,7 +276,7 @@ def test_random_sampler():
|
|||
num_iter += 1
|
||||
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert (num_iter == 44)
|
||||
assert num_iter == 44
|
||||
|
||||
|
||||
def test_distributed_sampler():
|
||||
|
@ -297,7 +297,7 @@ def test_distributed_sampler():
|
|||
num_iter += 1
|
||||
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert (num_iter == 5)
|
||||
assert num_iter == 5
|
||||
|
||||
|
||||
def test_pk_sampler():
|
||||
|
@ -318,7 +318,7 @@ def test_pk_sampler():
|
|||
num_iter += 1
|
||||
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert (num_iter == 12)
|
||||
assert num_iter == 12
|
||||
|
||||
|
||||
def test_subset_random_sampler():
|
||||
|
@ -340,7 +340,7 @@ def test_subset_random_sampler():
|
|||
num_iter += 1
|
||||
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert (num_iter == 12)
|
||||
assert num_iter == 12
|
||||
|
||||
|
||||
def test_weighted_random_sampler():
|
||||
|
@ -362,7 +362,7 @@ def test_weighted_random_sampler():
|
|||
num_iter += 1
|
||||
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert (num_iter == 11)
|
||||
assert num_iter == 11
|
||||
|
||||
|
||||
def test_imagefolder_rename():
|
||||
|
@ -382,7 +382,7 @@ def test_imagefolder_rename():
|
|||
num_iter += 1
|
||||
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert (num_iter == 10)
|
||||
assert num_iter == 10
|
||||
|
||||
data1 = data1.rename(input_columns=["image"], output_columns="image2")
|
||||
|
||||
|
@ -394,7 +394,7 @@ def test_imagefolder_rename():
|
|||
num_iter += 1
|
||||
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert (num_iter == 10)
|
||||
assert num_iter == 10
|
||||
|
||||
|
||||
def test_imagefolder_zip():
|
||||
|
@ -419,7 +419,7 @@ def test_imagefolder_zip():
|
|||
num_iter += 1
|
||||
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert (num_iter == 10)
|
||||
assert num_iter == 10
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -12,8 +12,6 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import pytest
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.c_transforms as data_trans
|
||||
import mindspore.dataset.transforms.vision.c_transforms as vision
|
||||
|
|
|
@ -12,8 +12,6 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import pytest
|
||||
|
||||
import mindspore.dataset as ds
|
||||
from mindspore import log as logger
|
||||
|
||||
|
@ -30,7 +28,7 @@ def test_tf_file_normal():
|
|||
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
|
||||
data1 = data1.repeat(1)
|
||||
num_iter = 0
|
||||
for item in data1.create_dict_iterator(): # each data is a dictionary
|
||||
for _ in data1.create_dict_iterator(): # each data is a dictionary
|
||||
num_iter += 1
|
||||
|
||||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
|
|
|
@ -16,7 +16,6 @@ import numpy as np
|
|||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.c_transforms as data_trans
|
||||
import mindspore.dataset.transforms.vision.c_transforms as vision
|
||||
from mindspore import log as logger
|
||||
|
||||
DATA_FILE = "../data/dataset/testManifestData/test.manifest"
|
||||
|
@ -34,9 +33,9 @@ def test_manifest_dataset_train():
|
|||
cat_count = cat_count + 1
|
||||
elif item["label"].size == 1 and item["label"] == 1:
|
||||
dog_count = dog_count + 1
|
||||
assert (cat_count == 2)
|
||||
assert (dog_count == 1)
|
||||
assert (count == 4)
|
||||
assert cat_count == 2
|
||||
assert dog_count == 1
|
||||
assert count == 4
|
||||
|
||||
|
||||
def test_manifest_dataset_eval():
|
||||
|
@ -46,36 +45,36 @@ def test_manifest_dataset_eval():
|
|||
logger.info("item[image] is {}".format(item["image"]))
|
||||
count = count + 1
|
||||
if item["label"] != 0 and item["label"] != 1:
|
||||
assert (0)
|
||||
assert (count == 2)
|
||||
assert 0
|
||||
assert count == 2
|
||||
|
||||
|
||||
def test_manifest_dataset_class_index():
|
||||
class_indexing = {"dog": 11}
|
||||
data = ds.ManifestDataset(DATA_FILE, decode=True, class_indexing=class_indexing)
|
||||
out_class_indexing = data.get_class_indexing()
|
||||
assert (out_class_indexing == {"dog": 11})
|
||||
assert out_class_indexing == {"dog": 11}
|
||||
count = 0
|
||||
for item in data.create_dict_iterator():
|
||||
logger.info("item[image] is {}".format(item["image"]))
|
||||
count = count + 1
|
||||
if item["label"] != 11:
|
||||
assert (0)
|
||||
assert (count == 1)
|
||||
assert 0
|
||||
assert count == 1
|
||||
|
||||
|
||||
def test_manifest_dataset_get_class_index():
|
||||
data = ds.ManifestDataset(DATA_FILE, decode=True)
|
||||
class_indexing = data.get_class_indexing()
|
||||
assert (class_indexing == {'cat': 0, 'dog': 1, 'flower': 2})
|
||||
assert class_indexing == {'cat': 0, 'dog': 1, 'flower': 2}
|
||||
data = data.shuffle(4)
|
||||
class_indexing = data.get_class_indexing()
|
||||
assert (class_indexing == {'cat': 0, 'dog': 1, 'flower': 2})
|
||||
assert class_indexing == {'cat': 0, 'dog': 1, 'flower': 2}
|
||||
count = 0
|
||||
for item in data.create_dict_iterator():
|
||||
logger.info("item[image] is {}".format(item["image"]))
|
||||
count = count + 1
|
||||
assert (count == 4)
|
||||
assert count == 4
|
||||
|
||||
|
||||
def test_manifest_dataset_multi_label():
|
||||
|
@ -83,10 +82,10 @@ def test_manifest_dataset_multi_label():
|
|||
count = 0
|
||||
expect_label = [1, 0, 0, [0, 2]]
|
||||
for item in data.create_dict_iterator():
|
||||
assert (item["label"].tolist() == expect_label[count])
|
||||
assert item["label"].tolist() == expect_label[count]
|
||||
logger.info("item[image] is {}".format(item["image"]))
|
||||
count = count + 1
|
||||
assert (count == 4)
|
||||
assert count == 4
|
||||
|
||||
|
||||
def multi_label_hot(x):
|
||||
|
@ -109,7 +108,7 @@ def test_manifest_dataset_multi_label_onehot():
|
|||
data = data.batch(2)
|
||||
count = 0
|
||||
for item in data.create_dict_iterator():
|
||||
assert (item["label"].tolist() == expect_label[count])
|
||||
assert item["label"].tolist() == expect_label[count]
|
||||
logger.info("item[image] is {}".format(item["image"]))
|
||||
count = count + 1
|
||||
|
||||
|
|
|
@ -27,7 +27,7 @@ def test_imagefolder_shardings(print_res=False):
|
|||
res = []
|
||||
for item in data1.create_dict_iterator(): # each data is a dictionary
|
||||
res.append(item["label"].item())
|
||||
if (print_res):
|
||||
if print_res:
|
||||
logger.info("labels of dataset: {}".format(res))
|
||||
return res
|
||||
|
||||
|
@ -39,12 +39,12 @@ def test_imagefolder_shardings(print_res=False):
|
|||
assert (sharding_config(2, 0, 55, False, dict()) == [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3]) # 22 rows
|
||||
assert (sharding_config(2, 1, 55, False, dict()) == [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3]) # 22 rows
|
||||
# total 22 in dataset rows because of class indexing which takes only 2 folders
|
||||
assert (len(sharding_config(4, 0, None, True, {"class1": 111, "class2": 999})) == 6)
|
||||
assert (len(sharding_config(4, 2, 3, True, {"class1": 111, "class2": 999})) == 3)
|
||||
assert len(sharding_config(4, 0, None, True, {"class1": 111, "class2": 999})) == 6
|
||||
assert len(sharding_config(4, 2, 3, True, {"class1": 111, "class2": 999})) == 3
|
||||
# test with repeat
|
||||
assert (sharding_config(4, 0, 12, False, dict(), 3) == [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3] * 3)
|
||||
assert (sharding_config(4, 0, 5, False, dict(), 5) == [0, 0, 0, 1, 1] * 5)
|
||||
assert (len(sharding_config(5, 1, None, True, {"class1": 111, "class2": 999}, 4)) == 20)
|
||||
assert len(sharding_config(5, 1, None, True, {"class1": 111, "class2": 999}, 4)) == 20
|
||||
|
||||
|
||||
def test_tfrecord_shardings1(print_res=False):
|
||||
|
@ -176,8 +176,8 @@ def test_voc_shardings(print_res=False):
|
|||
# then takes the first 2 bc num_samples = 2
|
||||
assert (sharding_config(3, 2, 2, False, 4) == [2268, 607] * 4)
|
||||
# test that each epoch, each shard_worker returns a different sample
|
||||
assert (len(sharding_config(2, 0, None, True, 1)) == 5)
|
||||
assert (len(set(sharding_config(11, 0, None, True, 10))) > 1)
|
||||
assert len(sharding_config(2, 0, None, True, 1)) == 5
|
||||
assert len(set(sharding_config(11, 0, None, True, 10))) > 1
|
||||
|
||||
|
||||
def test_cifar10_shardings(print_res=False):
|
||||
|
@ -196,8 +196,8 @@ def test_cifar10_shardings(print_res=False):
|
|||
|
||||
# 60000 rows in total. CIFAR reads everything in memory which would make each test case very slow
|
||||
# therefore, only 2 test cases for now.
|
||||
assert (sharding_config(10000, 9999, 7, False, 1) == [9])
|
||||
assert (sharding_config(10000, 0, 4, False, 3) == [0, 0, 0])
|
||||
assert sharding_config(10000, 9999, 7, False, 1) == [9]
|
||||
assert sharding_config(10000, 0, 4, False, 3) == [0, 0, 0]
|
||||
|
||||
|
||||
def test_cifar100_shardings(print_res=False):
|
||||
|
|
|
@ -27,7 +27,7 @@ def test_textline_dataset_one_file():
|
|||
for i in data.create_dict_iterator():
|
||||
logger.info("{}".format(i["text"]))
|
||||
count += 1
|
||||
assert (count == 3)
|
||||
assert count == 3
|
||||
|
||||
|
||||
def test_textline_dataset_all_file():
|
||||
|
@ -36,7 +36,7 @@ def test_textline_dataset_all_file():
|
|||
for i in data.create_dict_iterator():
|
||||
logger.info("{}".format(i["text"]))
|
||||
count += 1
|
||||
assert (count == 5)
|
||||
assert count == 5
|
||||
|
||||
|
||||
def test_textline_dataset_totext():
|
||||
|
@ -46,8 +46,8 @@ def test_textline_dataset_totext():
|
|||
line = ["This is a text file.", "Another file.",
|
||||
"Be happy every day.", "End of file.", "Good luck to everyone."]
|
||||
for i in data.create_dict_iterator():
|
||||
str = i["text"].item().decode("utf8")
|
||||
assert (str == line[count])
|
||||
strs = i["text"].item().decode("utf8")
|
||||
assert strs == line[count]
|
||||
count += 1
|
||||
assert (count == 5)
|
||||
# Restore configuration num_parallel_workers
|
||||
|
@ -57,17 +57,17 @@ def test_textline_dataset_totext():
|
|||
def test_textline_dataset_num_samples():
|
||||
data = ds.TextFileDataset(DATA_FILE, num_samples=2)
|
||||
count = 0
|
||||
for i in data.create_dict_iterator():
|
||||
for _ in data.create_dict_iterator():
|
||||
count += 1
|
||||
assert (count == 2)
|
||||
assert count == 2
|
||||
|
||||
|
||||
def test_textline_dataset_distribution():
|
||||
data = ds.TextFileDataset(DATA_ALL_FILE, num_shards=2, shard_id=1)
|
||||
count = 0
|
||||
for i in data.create_dict_iterator():
|
||||
for _ in data.create_dict_iterator():
|
||||
count += 1
|
||||
assert (count == 3)
|
||||
assert count == 3
|
||||
|
||||
|
||||
def test_textline_dataset_repeat():
|
||||
|
@ -78,16 +78,16 @@ def test_textline_dataset_repeat():
|
|||
"This is a text file.", "Be happy every day.", "Good luck to everyone.",
|
||||
"This is a text file.", "Be happy every day.", "Good luck to everyone."]
|
||||
for i in data.create_dict_iterator():
|
||||
str = i["text"].item().decode("utf8")
|
||||
assert (str == line[count])
|
||||
strs = i["text"].item().decode("utf8")
|
||||
assert strs == line[count]
|
||||
count += 1
|
||||
assert (count == 9)
|
||||
assert count == 9
|
||||
|
||||
|
||||
def test_textline_dataset_get_datasetsize():
|
||||
data = ds.TextFileDataset(DATA_FILE)
|
||||
size = data.get_dataset_size()
|
||||
assert (size == 3)
|
||||
assert size == 3
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -15,9 +15,8 @@
|
|||
"""
|
||||
Testing Decode op in DE
|
||||
"""
|
||||
import cv2
|
||||
import numpy as np
|
||||
from util import diff_mse
|
||||
import cv2
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.vision.c_transforms as vision
|
||||
|
|
|
@ -88,7 +88,7 @@ def test_filter_by_generator_with_repeat():
|
|||
ret_data.append(item["data"])
|
||||
assert num_iter == 44
|
||||
for i in range(4):
|
||||
for ii in range(len(expected_rs)):
|
||||
for ii, _ in enumerate(expected_rs):
|
||||
index = i * len(expected_rs) + ii
|
||||
assert ret_data[index] == expected_rs[ii]
|
||||
|
||||
|
@ -106,7 +106,7 @@ def test_filter_by_generator_with_repeat_after():
|
|||
ret_data.append(item["data"])
|
||||
assert num_iter == 44
|
||||
for i in range(4):
|
||||
for ii in range(len(expected_rs)):
|
||||
for ii, _ in enumerate(expected_rs):
|
||||
index = i * len(expected_rs) + ii
|
||||
assert ret_data[index] == expected_rs[ii]
|
||||
|
||||
|
@ -167,7 +167,7 @@ def test_filter_by_generator_with_shuffle():
|
|||
dataset_s = dataset.shuffle(4)
|
||||
dataset_f = dataset_s.filter(predicate=filter_func_shuffle, num_parallel_workers=4)
|
||||
num_iter = 0
|
||||
for item in dataset_f.create_dict_iterator():
|
||||
for _ in dataset_f.create_dict_iterator():
|
||||
num_iter += 1
|
||||
assert num_iter == 21
|
||||
|
||||
|
@ -184,7 +184,7 @@ def test_filter_by_generator_with_shuffle_after():
|
|||
dataset_f = dataset.filter(predicate=filter_func_shuffle_after, num_parallel_workers=4)
|
||||
dataset_s = dataset_f.shuffle(4)
|
||||
num_iter = 0
|
||||
for item in dataset_s.create_dict_iterator():
|
||||
for _ in dataset_s.create_dict_iterator():
|
||||
num_iter += 1
|
||||
assert num_iter == 21
|
||||
|
||||
|
@ -258,8 +258,7 @@ def filter_func_map(col1, col2):
|
|||
def filter_func_map_part(col1):
|
||||
if col1 < 3:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
return False
|
||||
|
||||
|
||||
def filter_func_map_all(col1, col2):
|
||||
|
@ -276,7 +275,7 @@ def func_map(data_col1, data_col2):
|
|||
|
||||
|
||||
def func_map_part(data_col1):
|
||||
return (data_col1)
|
||||
return data_col1
|
||||
|
||||
|
||||
# test with map
|
||||
|
@ -473,7 +472,6 @@ def test_filte_case_dataset_cifar10():
|
|||
ds.config.load('../data/dataset/declient_filter.cfg')
|
||||
dataset_c = ds.Cifar10Dataset(dataset_dir=DATA_DIR_10, num_samples=100000, shuffle=False)
|
||||
dataset_f1 = dataset_c.filter(input_columns=["image", "label"], predicate=filter_func_cifar, num_parallel_workers=1)
|
||||
num_iter = 0
|
||||
for item in dataset_f1.create_dict_iterator():
|
||||
# in this example, each dictionary has keys "image" and "label"
|
||||
assert item["label"] % 3 == 0
|
||||
|
|
|
@ -184,7 +184,7 @@ def test_case_6():
|
|||
de_types = [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.uint8, mstype.uint16, mstype.uint32,
|
||||
mstype.uint64, mstype.float32, mstype.float64]
|
||||
|
||||
for i in range(len(np_types)):
|
||||
for i, _ in enumerate(np_types):
|
||||
type_tester_with_type_check(np_types[i], de_types[i])
|
||||
|
||||
|
||||
|
@ -219,7 +219,7 @@ def test_case_7():
|
|||
de_types = [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.uint8, mstype.uint16, mstype.uint32,
|
||||
mstype.uint64, mstype.float32, mstype.float64]
|
||||
|
||||
for i in range(len(np_types)):
|
||||
for i, _ in enumerate(np_types):
|
||||
type_tester_with_type_check_2c(np_types[i], [None, de_types[i]])
|
||||
|
||||
|
||||
|
@ -526,7 +526,7 @@ def test_sequential_sampler():
|
|||
def test_random_sampler():
|
||||
source = [(np.array([x]),) for x in range(64)]
|
||||
ds1 = ds.GeneratorDataset(source, ["data"], shuffle=True)
|
||||
for data in ds1.create_dict_iterator(): # each data is a dictionary
|
||||
for _ in ds1.create_dict_iterator(): # each data is a dictionary
|
||||
pass
|
||||
|
||||
|
||||
|
@ -611,7 +611,7 @@ def test_schema():
|
|||
de_types = [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.uint8, mstype.uint16, mstype.uint32,
|
||||
mstype.uint64, mstype.float32, mstype.float64]
|
||||
|
||||
for i in range(len(np_types)):
|
||||
for i, _ in enumerate(np_types):
|
||||
type_tester_with_type_check_2c_schema(np_types[i], [de_types[i], de_types[i]])
|
||||
|
||||
|
||||
|
@ -630,8 +630,7 @@ def manual_test_keyborad_interrupt():
|
|||
return 1024
|
||||
|
||||
ds1 = ds.GeneratorDataset(MyDS(), ["data"], num_parallel_workers=4).repeat(2)
|
||||
i = 0
|
||||
for data in ds1.create_dict_iterator(): # each data is a dictionary
|
||||
for _ in ds1.create_dict_iterator(): # each data is a dictionary
|
||||
pass
|
||||
|
||||
|
||||
|
|
|
@ -12,7 +12,6 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import copy
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
|
|
@ -320,7 +320,7 @@ def test_cv_minddataset_issue_888(add_and_remove_cv_file):
|
|||
data = data.shuffle(2)
|
||||
data = data.repeat(9)
|
||||
num_iter = 0
|
||||
for item in data.create_dict_iterator():
|
||||
for _ in data.create_dict_iterator():
|
||||
num_iter += 1
|
||||
assert num_iter == 18
|
||||
|
||||
|
@ -572,7 +572,7 @@ def test_cv_minddataset_reader_basic_tutorial_5_epoch(add_and_remove_cv_file):
|
|||
num_readers = 4
|
||||
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
|
||||
assert data_set.get_dataset_size() == 10
|
||||
for epoch in range(5):
|
||||
for _ in range(5):
|
||||
num_iter = 0
|
||||
for data in data_set:
|
||||
logger.info("data is {}".format(data))
|
||||
|
@ -603,7 +603,7 @@ def test_cv_minddataset_reader_basic_tutorial_5_epoch_with_batch(add_and_remove_
|
|||
|
||||
data_set = data_set.batch(2)
|
||||
assert data_set.get_dataset_size() == 5
|
||||
for epoch in range(5):
|
||||
for _ in range(5):
|
||||
num_iter = 0
|
||||
for data in data_set:
|
||||
logger.info("data is {}".format(data))
|
||||
|
|
|
@ -91,7 +91,7 @@ def test_invalid_mindrecord():
|
|||
with pytest.raises(Exception, match="MindRecordOp init failed"):
|
||||
data_set = ds.MindDataset('dummy.mindrecord', columns_list, num_readers)
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
for _ in data_set.create_dict_iterator():
|
||||
num_iter += 1
|
||||
assert num_iter == 0
|
||||
os.remove('dummy.mindrecord')
|
||||
|
@ -105,7 +105,7 @@ def test_minddataset_lack_db():
|
|||
with pytest.raises(Exception, match="MindRecordOp init failed"):
|
||||
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers)
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
for _ in data_set.create_dict_iterator():
|
||||
num_iter += 1
|
||||
assert num_iter == 0
|
||||
os.remove(CV_FILE_NAME)
|
||||
|
@ -119,7 +119,7 @@ def test_cv_minddataset_pk_sample_error_class_column():
|
|||
with pytest.raises(Exception, match="MindRecordOp launch failed"):
|
||||
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, sampler=sampler)
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
for _ in data_set.create_dict_iterator():
|
||||
num_iter += 1
|
||||
os.remove(CV_FILE_NAME)
|
||||
os.remove("{}.db".format(CV_FILE_NAME))
|
||||
|
|
|
@ -15,8 +15,8 @@
|
|||
"""
|
||||
This is the test module for mindrecord
|
||||
"""
|
||||
import numpy as np
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
import mindspore.dataset as ds
|
||||
from mindspore import log as logger
|
||||
|
|
|
@ -15,16 +15,10 @@
|
|||
"""
|
||||
This is the test module for mindrecord
|
||||
"""
|
||||
import collections
|
||||
import json
|
||||
import numpy as np
|
||||
import os
|
||||
import pytest
|
||||
import re
|
||||
import string
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.vision.c_transforms as vision
|
||||
from mindspore import log as logger
|
||||
from mindspore.dataset.transforms.vision import Inter
|
||||
from mindspore.dataset.text import to_str
|
||||
|
|
|
@ -49,7 +49,7 @@ def test_one_hot_op():
|
|||
label = data["label"]
|
||||
logger.info("label is {}".format(label))
|
||||
logger.info("golden_label is {}".format(golden_label))
|
||||
assert (label.all() == golden_label.all())
|
||||
assert label.all() == golden_label.all()
|
||||
logger.info("====test one hot op ok====")
|
||||
|
||||
|
||||
|
|
|
@ -13,7 +13,6 @@
|
|||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
import mindspore.dataset as ds
|
||||
|
@ -50,6 +49,7 @@ def get_normalized(image_id):
|
|||
if num_iter == image_id:
|
||||
return normalize_np(image)
|
||||
num_iter += 1
|
||||
return None
|
||||
|
||||
|
||||
def test_normalize_op():
|
||||
|
|
|
@ -19,7 +19,6 @@ import numpy as np
|
|||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.c_transforms as data_trans
|
||||
import mindspore.dataset.transforms.vision.c_transforms as vision
|
||||
from mindspore import log as logger
|
||||
|
||||
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
|
||||
|
|
|
@ -15,7 +15,6 @@
|
|||
"""
|
||||
Testing Pad op in DE
|
||||
"""
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from util import diff_mse
|
||||
|
||||
|
@ -118,7 +117,7 @@ def test_pad_grayscale():
|
|||
for shape1, shape2 in zip(dataset_shape_1, dataset_shape_2):
|
||||
# validate that the first two dimensions are the same
|
||||
# we have a little inconsistency here because the third dimension is 1 after py_vision.Grayscale
|
||||
assert (shape1[0:1] == shape2[0:1])
|
||||
assert shape1[0:1] == shape2[0:1]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -13,8 +13,8 @@
|
|||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import numpy as np
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
import mindspore.dataset as ds
|
||||
|
||||
|
@ -117,8 +117,7 @@ def batch_padding_performance_3d():
|
|||
data1 = data1.batch(batch_size=24, drop_remainder=True, pad_info=pad_info)
|
||||
start_time = time.time()
|
||||
num_batches = 0
|
||||
ret = []
|
||||
for data in data1.create_dict_iterator():
|
||||
for _ in data1.create_dict_iterator():
|
||||
num_batches += 1
|
||||
res = "total number of batch:" + str(num_batches) + " time elapsed:" + str(time.time() - start_time)
|
||||
# print(res)
|
||||
|
@ -134,7 +133,7 @@ def batch_padding_performance_1d():
|
|||
data1 = data1.batch(batch_size=24, drop_remainder=True, pad_info=pad_info)
|
||||
start_time = time.time()
|
||||
num_batches = 0
|
||||
for data in data1.create_dict_iterator():
|
||||
for _ in data1.create_dict_iterator():
|
||||
num_batches += 1
|
||||
res = "total number of batch:" + str(num_batches) + " time elapsed:" + str(time.time() - start_time)
|
||||
# print(res)
|
||||
|
@ -150,7 +149,7 @@ def batch_pyfunc_padding_3d():
|
|||
data1 = data1.batch(batch_size=24, drop_remainder=True)
|
||||
start_time = time.time()
|
||||
num_batches = 0
|
||||
for data in data1.create_dict_iterator():
|
||||
for _ in data1.create_dict_iterator():
|
||||
num_batches += 1
|
||||
res = "total number of batch:" + str(num_batches) + " time elapsed:" + str(time.time() - start_time)
|
||||
# print(res)
|
||||
|
@ -165,7 +164,7 @@ def batch_pyfunc_padding_1d():
|
|||
data1 = data1.batch(batch_size=24, drop_remainder=True)
|
||||
start_time = time.time()
|
||||
num_batches = 0
|
||||
for data in data1.create_dict_iterator():
|
||||
for _ in data1.create_dict_iterator():
|
||||
num_batches += 1
|
||||
res = "total number of batch:" + str(num_batches) + " time elapsed:" + str(time.time() - start_time)
|
||||
# print(res)
|
||||
|
@ -197,7 +196,7 @@ def test_pad_via_map():
|
|||
res_from_map = pad_map_config()
|
||||
res_from_batch = pad_batch_config()
|
||||
assert len(res_from_batch) == len(res_from_batch)
|
||||
for i in range(len(res_from_map)):
|
||||
for i, _ in enumerate(res_from_map):
|
||||
assert np.array_equal(res_from_map[i], res_from_batch[i])
|
||||
|
||||
|
||||
|
|
|
@ -15,8 +15,9 @@
|
|||
"""
|
||||
Testing RandomCropAndResize op in DE
|
||||
"""
|
||||
import cv2
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
import mindspore.dataset.transforms.vision.c_transforms as c_vision
|
||||
import mindspore.dataset.transforms.vision.py_transforms as py_vision
|
||||
import mindspore.dataset.transforms.vision.utils as mode
|
||||
|
|
|
@ -15,9 +15,9 @@
|
|||
"""
|
||||
Testing RandomCropDecodeResize op in DE
|
||||
"""
|
||||
import cv2
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.vision.c_transforms as vision
|
||||
|
|
|
@ -12,8 +12,6 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
from pathlib import Path
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset as ds
|
||||
from mindspore import log as logger
|
||||
|
@ -39,7 +37,7 @@ def test_randomdataset_basic1():
|
|||
num_iter += 1
|
||||
|
||||
logger.info("Number of data in ds1: ", num_iter)
|
||||
assert (num_iter == 200)
|
||||
assert num_iter == 200
|
||||
|
||||
|
||||
# Another simple test
|
||||
|
@ -65,7 +63,7 @@ def test_randomdataset_basic2():
|
|||
num_iter += 1
|
||||
|
||||
logger.info("Number of data in ds1: ", num_iter)
|
||||
assert (num_iter == 40)
|
||||
assert num_iter == 40
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -15,9 +15,9 @@
|
|||
"""
|
||||
Testing RandomRotation op in DE
|
||||
"""
|
||||
import cv2
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.vision.c_transforms as c_vision
|
||||
|
|
|
@ -34,7 +34,7 @@ def test_rename():
|
|||
|
||||
num_iter = 0
|
||||
|
||||
for i, item in enumerate(data.create_dict_iterator()):
|
||||
for _, item in enumerate(data.create_dict_iterator()):
|
||||
logger.info("item[mask] is {}".format(item["masks"]))
|
||||
np.testing.assert_equal(item["masks"], item["input_ids"])
|
||||
logger.info("item[seg_ids] is {}".format(item["seg_ids"]))
|
||||
|
|
|
@ -159,7 +159,7 @@ def test_rgb_hsv_pipeline():
|
|||
ori_img = data1["image"]
|
||||
cvt_img = data2["image"]
|
||||
assert_allclose(ori_img.flatten(), cvt_img.flatten(), rtol=1e-5, atol=0)
|
||||
assert (ori_img.shape == cvt_img.shape)
|
||||
assert ori_img.shape == cvt_img.shape
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -57,7 +57,7 @@ def test_imagefolder(remove_json_files=True):
|
|||
# data1 should still work after saving.
|
||||
ds.serialize(data1, "imagenet_dataset_pipeline.json")
|
||||
ds1_dict = ds.serialize(data1)
|
||||
assert (validate_jsonfile("imagenet_dataset_pipeline.json") is True)
|
||||
assert validate_jsonfile("imagenet_dataset_pipeline.json") is True
|
||||
|
||||
# Print the serialized pipeline to stdout
|
||||
ds.show(data1)
|
||||
|
@ -68,8 +68,8 @@ def test_imagefolder(remove_json_files=True):
|
|||
# Serialize the pipeline we just deserialized.
|
||||
# The content of the json file should be the same to the previous serialize.
|
||||
ds.serialize(data2, "imagenet_dataset_pipeline_1.json")
|
||||
assert (validate_jsonfile("imagenet_dataset_pipeline_1.json") is True)
|
||||
assert (filecmp.cmp('imagenet_dataset_pipeline.json', 'imagenet_dataset_pipeline_1.json'))
|
||||
assert validate_jsonfile("imagenet_dataset_pipeline_1.json") is True
|
||||
assert filecmp.cmp('imagenet_dataset_pipeline.json', 'imagenet_dataset_pipeline_1.json')
|
||||
|
||||
# Deserialize the latest json file again
|
||||
data3 = ds.deserialize(json_filepath="imagenet_dataset_pipeline_1.json")
|
||||
|
@ -78,16 +78,16 @@ def test_imagefolder(remove_json_files=True):
|
|||
# Iterate and compare the data in the original pipeline (data1) against the deserialized pipeline (data2)
|
||||
for item1, item2, item3, item4 in zip(data1.create_dict_iterator(), data2.create_dict_iterator(),
|
||||
data3.create_dict_iterator(), data4.create_dict_iterator()):
|
||||
assert (np.array_equal(item1['image'], item2['image']))
|
||||
assert (np.array_equal(item1['image'], item3['image']))
|
||||
assert (np.array_equal(item1['label'], item2['label']))
|
||||
assert (np.array_equal(item1['label'], item3['label']))
|
||||
assert (np.array_equal(item3['image'], item4['image']))
|
||||
assert (np.array_equal(item3['label'], item4['label']))
|
||||
assert np.array_equal(item1['image'], item2['image'])
|
||||
assert np.array_equal(item1['image'], item3['image'])
|
||||
assert np.array_equal(item1['label'], item2['label'])
|
||||
assert np.array_equal(item1['label'], item3['label'])
|
||||
assert np.array_equal(item3['image'], item4['image'])
|
||||
assert np.array_equal(item3['label'], item4['label'])
|
||||
num_samples += 1
|
||||
|
||||
logger.info("Number of data in data1: {}".format(num_samples))
|
||||
assert (num_samples == 6)
|
||||
assert num_samples == 6
|
||||
|
||||
# Remove the generated json file
|
||||
if remove_json_files:
|
||||
|
@ -106,26 +106,26 @@ def test_mnist_dataset(remove_json_files=True):
|
|||
data1 = data1.batch(batch_size=10, drop_remainder=True)
|
||||
|
||||
ds.serialize(data1, "mnist_dataset_pipeline.json")
|
||||
assert (validate_jsonfile("mnist_dataset_pipeline.json") is True)
|
||||
assert validate_jsonfile("mnist_dataset_pipeline.json") is True
|
||||
|
||||
data2 = ds.deserialize(json_filepath="mnist_dataset_pipeline.json")
|
||||
ds.serialize(data2, "mnist_dataset_pipeline_1.json")
|
||||
assert (validate_jsonfile("mnist_dataset_pipeline_1.json") is True)
|
||||
assert (filecmp.cmp('mnist_dataset_pipeline.json', 'mnist_dataset_pipeline_1.json'))
|
||||
assert validate_jsonfile("mnist_dataset_pipeline_1.json") is True
|
||||
assert filecmp.cmp('mnist_dataset_pipeline.json', 'mnist_dataset_pipeline_1.json')
|
||||
|
||||
data3 = ds.deserialize(json_filepath="mnist_dataset_pipeline_1.json")
|
||||
|
||||
num = 0
|
||||
for data1, data2, data3 in zip(data1.create_dict_iterator(), data2.create_dict_iterator(),
|
||||
data3.create_dict_iterator()):
|
||||
assert (np.array_equal(data1['image'], data2['image']))
|
||||
assert (np.array_equal(data1['image'], data3['image']))
|
||||
assert (np.array_equal(data1['label'], data2['label']))
|
||||
assert (np.array_equal(data1['label'], data3['label']))
|
||||
assert np.array_equal(data1['image'], data2['image'])
|
||||
assert np.array_equal(data1['image'], data3['image'])
|
||||
assert np.array_equal(data1['label'], data2['label'])
|
||||
assert np.array_equal(data1['label'], data3['label'])
|
||||
num += 1
|
||||
|
||||
logger.info("mnist total num samples is {}".format(str(num)))
|
||||
assert (num == 10)
|
||||
assert num == 10
|
||||
|
||||
if remove_json_files:
|
||||
delete_json_files()
|
||||
|
@ -146,13 +146,13 @@ def test_zip_dataset(remove_json_files=True):
|
|||
"column_1d", "column_2d", "column_3d", "column_binary"])
|
||||
data3 = ds.zip((data1, data2))
|
||||
ds.serialize(data3, "zip_dataset_pipeline.json")
|
||||
assert (validate_jsonfile("zip_dataset_pipeline.json") is True)
|
||||
assert (validate_jsonfile("zip_dataset_pipeline_typo.json") is False)
|
||||
assert validate_jsonfile("zip_dataset_pipeline.json") is True
|
||||
assert validate_jsonfile("zip_dataset_pipeline_typo.json") is False
|
||||
|
||||
data4 = ds.deserialize(json_filepath="zip_dataset_pipeline.json")
|
||||
ds.serialize(data4, "zip_dataset_pipeline_1.json")
|
||||
assert (validate_jsonfile("zip_dataset_pipeline_1.json") is True)
|
||||
assert (filecmp.cmp('zip_dataset_pipeline.json', 'zip_dataset_pipeline_1.json'))
|
||||
assert validate_jsonfile("zip_dataset_pipeline_1.json") is True
|
||||
assert filecmp.cmp('zip_dataset_pipeline.json', 'zip_dataset_pipeline_1.json')
|
||||
|
||||
rows = 0
|
||||
for d0, d3, d4 in zip(ds0, data3, data4):
|
||||
|
@ -165,7 +165,7 @@ def test_zip_dataset(remove_json_files=True):
|
|||
assert np.array_equal(t1, d4[offset + num_cols])
|
||||
offset += 1
|
||||
rows += 1
|
||||
assert (rows == 12)
|
||||
assert rows == 12
|
||||
|
||||
if remove_json_files:
|
||||
delete_json_files()
|
||||
|
@ -197,7 +197,7 @@ def test_random_crop():
|
|||
|
||||
for item1, item1_1, item2 in zip(data1.create_dict_iterator(), data1_1.create_dict_iterator(),
|
||||
data2.create_dict_iterator()):
|
||||
assert (np.array_equal(item1['image'], item1_1['image']))
|
||||
assert np.array_equal(item1['image'], item1_1['image'])
|
||||
image2 = item2["image"]
|
||||
|
||||
|
||||
|
@ -250,7 +250,7 @@ def test_minddataset(add_and_remove_cv_file):
|
|||
data = get_data(CV_DIR_NAME)
|
||||
assert data_set.get_dataset_size() == 5
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
for _ in data_set.create_dict_iterator():
|
||||
num_iter += 1
|
||||
assert num_iter == 5
|
||||
|
||||
|
|
|
@ -120,7 +120,7 @@ def test_shuffle_05():
|
|||
|
||||
def test_shuffle_06():
|
||||
"""
|
||||
Test shuffle: with set seed, both datasets
|
||||
Test shuffle: with set seed, both datasets
|
||||
"""
|
||||
logger.info("test_shuffle_06")
|
||||
# define parameters
|
||||
|
|
|
@ -16,7 +16,6 @@ import numpy as np
|
|||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.vision.c_transforms as vision
|
||||
from mindspore import log as logger
|
||||
|
||||
DATA_DIR_TF2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
|
||||
SCHEMA_DIR_TF2 = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
|
||||
|
@ -36,7 +35,7 @@ def test_tf_skip():
|
|||
data1 = data1.skip(2)
|
||||
|
||||
num_iter = 0
|
||||
for item in data1.create_dict_iterator():
|
||||
for _ in data1.create_dict_iterator():
|
||||
num_iter += 1
|
||||
assert num_iter == 1
|
||||
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
# ==============================================================================
|
||||
|
||||
import numpy as np
|
||||
import time
|
||||
|
||||
import mindspore.dataset as ds
|
||||
from mindspore import log as logger
|
||||
|
@ -22,7 +21,7 @@ from mindspore import log as logger
|
|||
|
||||
def gen():
|
||||
for i in range(100):
|
||||
yield np.array(i),
|
||||
yield (np.array(i),)
|
||||
|
||||
|
||||
class Augment:
|
||||
|
@ -38,7 +37,7 @@ class Augment:
|
|||
|
||||
def test_simple_sync_wait():
|
||||
"""
|
||||
Test simple sync wait: test sync in dataset pipeline
|
||||
Test simple sync wait: test sync in dataset pipeline
|
||||
"""
|
||||
logger.info("test_simple_sync_wait")
|
||||
batch_size = 4
|
||||
|
@ -51,7 +50,7 @@ def test_simple_sync_wait():
|
|||
|
||||
count = 0
|
||||
for data in dataset.create_dict_iterator():
|
||||
assert (data["input"][0] == count)
|
||||
assert data["input"][0] == count
|
||||
count += batch_size
|
||||
data = {"loss": count}
|
||||
dataset.sync_update(condition_name="policy", data=data)
|
||||
|
@ -59,7 +58,7 @@ def test_simple_sync_wait():
|
|||
|
||||
def test_simple_shuffle_sync():
|
||||
"""
|
||||
Test simple shuffle sync: test shuffle before sync
|
||||
Test simple shuffle sync: test shuffle before sync
|
||||
"""
|
||||
logger.info("test_simple_shuffle_sync")
|
||||
shuffle_size = 4
|
||||
|
@ -83,7 +82,7 @@ def test_simple_shuffle_sync():
|
|||
|
||||
def test_two_sync():
|
||||
"""
|
||||
Test two sync: dataset pipeline with with two sync_operators
|
||||
Test two sync: dataset pipeline with with two sync_operators
|
||||
"""
|
||||
logger.info("test_two_sync")
|
||||
batch_size = 6
|
||||
|
@ -111,7 +110,7 @@ def test_two_sync():
|
|||
|
||||
def test_sync_epoch():
|
||||
"""
|
||||
Test sync wait with epochs: test sync with epochs in dataset pipeline
|
||||
Test sync wait with epochs: test sync with epochs in dataset pipeline
|
||||
"""
|
||||
logger.info("test_sync_epoch")
|
||||
batch_size = 30
|
||||
|
@ -122,11 +121,11 @@ def test_sync_epoch():
|
|||
dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
|
||||
dataset = dataset.batch(batch_size, drop_remainder=True)
|
||||
|
||||
for epochs in range(3):
|
||||
for _ in range(3):
|
||||
aug.update({"loss": 0})
|
||||
count = 0
|
||||
for data in dataset.create_dict_iterator():
|
||||
assert (data["input"][0] == count)
|
||||
assert data["input"][0] == count
|
||||
count += batch_size
|
||||
data = {"loss": count}
|
||||
dataset.sync_update(condition_name="policy", data=data)
|
||||
|
@ -134,7 +133,7 @@ def test_sync_epoch():
|
|||
|
||||
def test_multiple_iterators():
|
||||
"""
|
||||
Test sync wait with multiple iterators: will start multiple
|
||||
Test sync wait with multiple iterators: will start multiple
|
||||
"""
|
||||
logger.info("test_sync_epoch")
|
||||
batch_size = 30
|
||||
|
@ -153,7 +152,7 @@ def test_multiple_iterators():
|
|||
dataset2 = dataset2.batch(batch_size, drop_remainder=True)
|
||||
|
||||
for item1, item2 in zip(dataset.create_dict_iterator(), dataset2.create_dict_iterator()):
|
||||
assert (item1["input"][0] == item2["input"][0])
|
||||
assert item1["input"][0] == item2["input"][0]
|
||||
data1 = {"loss": item1["input"][0]}
|
||||
data2 = {"loss": item2["input"][0]}
|
||||
dataset.sync_update(condition_name="policy", data=data1)
|
||||
|
@ -162,7 +161,7 @@ def test_multiple_iterators():
|
|||
|
||||
def test_sync_exception_01():
|
||||
"""
|
||||
Test sync: with shuffle in sync mode
|
||||
Test sync: with shuffle in sync mode
|
||||
"""
|
||||
logger.info("test_sync_exception_01")
|
||||
shuffle_size = 4
|
||||
|
@ -183,7 +182,7 @@ def test_sync_exception_01():
|
|||
|
||||
def test_sync_exception_02():
|
||||
"""
|
||||
Test sync: with duplicated condition name
|
||||
Test sync: with duplicated condition name
|
||||
"""
|
||||
logger.info("test_sync_exception_02")
|
||||
batch_size = 6
|
||||
|
|
|
@ -21,13 +21,13 @@ from mindspore import log as logger
|
|||
# In generator dataset: Number of rows is 3, its value is 0, 1, 2
|
||||
def generator():
|
||||
for i in range(3):
|
||||
yield np.array([i]),
|
||||
yield (np.array([i]),)
|
||||
|
||||
|
||||
# In generator dataset: Number of rows is 10, its value is 0, 1, 2 ... 10
|
||||
def generator_10():
|
||||
for i in range(10):
|
||||
yield np.array([i]),
|
||||
yield (np.array([i]),)
|
||||
|
||||
|
||||
def filter_func_ge(data):
|
||||
|
@ -47,8 +47,8 @@ def test_take_01():
|
|||
data1 = data1.repeat(2)
|
||||
|
||||
# Here i refers to index, d refers to data element
|
||||
for i, d in enumerate(data1):
|
||||
assert 0 == d[0][0]
|
||||
for _, d in enumerate(data1):
|
||||
assert d[0][0] == 0
|
||||
|
||||
assert sum([1 for _ in data1]) == 2
|
||||
|
||||
|
@ -97,7 +97,7 @@ def test_take_04():
|
|||
data1 = data1.take(4)
|
||||
data1 = data1.repeat(2)
|
||||
|
||||
# Here i refers to index, d refers to data element
|
||||
# Here i refers to index, d refers to data element
|
||||
for i, d in enumerate(data1):
|
||||
assert i % 3 == d[0][0]
|
||||
|
||||
|
@ -113,7 +113,7 @@ def test_take_05():
|
|||
|
||||
data1 = data1.take(2)
|
||||
|
||||
# Here i refers to index, d refers to data element
|
||||
# Here i refers to index, d refers to data element
|
||||
for i, d in enumerate(data1):
|
||||
assert i == d[0][0]
|
||||
|
||||
|
@ -130,7 +130,7 @@ def test_take_06():
|
|||
data1 = data1.repeat(2)
|
||||
data1 = data1.take(4)
|
||||
|
||||
# Here i refers to index, d refers to data element
|
||||
# Here i refers to index, d refers to data element
|
||||
for i, d in enumerate(data1):
|
||||
assert i % 3 == d[0][0]
|
||||
|
||||
|
@ -171,7 +171,7 @@ def test_take_09():
|
|||
data1 = data1.repeat(2)
|
||||
data1 = data1.take(-1)
|
||||
|
||||
# Here i refers to index, d refers to data element
|
||||
# Here i refers to index, d refers to data element
|
||||
for i, d in enumerate(data1):
|
||||
assert i % 3 == d[0][0]
|
||||
|
||||
|
@ -188,7 +188,7 @@ def test_take_10():
|
|||
data1 = data1.take(-1)
|
||||
data1 = data1.repeat(2)
|
||||
|
||||
# Here i refers to index, d refers to data element
|
||||
# Here i refers to index, d refers to data element
|
||||
for i, d in enumerate(data1):
|
||||
assert i % 3 == d[0][0]
|
||||
|
||||
|
@ -206,7 +206,7 @@ def test_take_11():
|
|||
data1 = data1.repeat(2)
|
||||
data1 = data1.take(-1)
|
||||
|
||||
# Here i refers to index, d refers to data element
|
||||
# Here i refers to index, d refers to data element
|
||||
for i, d in enumerate(data1):
|
||||
assert 2 * (i % 2) == d[0][0]
|
||||
|
||||
|
@ -224,9 +224,9 @@ def test_take_12():
|
|||
data1 = data1.batch(2)
|
||||
data1 = data1.repeat(2)
|
||||
|
||||
# Here i refers to index, d refers to data element
|
||||
for i, d in enumerate(data1):
|
||||
assert 0 == d[0][0]
|
||||
# Here i refers to index, d refers to data element
|
||||
for _, d in enumerate(data1):
|
||||
assert d[0][0] == 0
|
||||
|
||||
assert sum([1 for _ in data1]) == 2
|
||||
|
||||
|
@ -243,9 +243,9 @@ def test_take_13():
|
|||
data1 = data1.batch(2)
|
||||
data1 = data1.repeat(2)
|
||||
|
||||
# Here i refers to index, d refers to data element
|
||||
for i, d in enumerate(data1):
|
||||
assert 2 == d[0][0]
|
||||
# Here i refers to index, d refers to data element
|
||||
for _, d in enumerate(data1):
|
||||
assert d[0][0] == 2
|
||||
|
||||
assert sum([1 for _ in data1]) == 2
|
||||
|
||||
|
@ -262,9 +262,9 @@ def test_take_14():
|
|||
data1 = data1.skip(1)
|
||||
data1 = data1.repeat(2)
|
||||
|
||||
# Here i refers to index, d refers to data element
|
||||
for i, d in enumerate(data1):
|
||||
assert 2 == d[0][0]
|
||||
# Here i refers to index, d refers to data element
|
||||
for _, d in enumerate(data1):
|
||||
assert d[0][0] == 2
|
||||
|
||||
assert sum([1 for _ in data1]) == 2
|
||||
|
||||
|
@ -279,7 +279,7 @@ def test_take_15():
|
|||
data1 = data1.take(6)
|
||||
data1 = data1.skip(2)
|
||||
|
||||
# Here i refers to index, d refers to data element
|
||||
# Here i refers to index, d refers to data element
|
||||
for i, d in enumerate(data1):
|
||||
assert (i + 2) == d[0][0]
|
||||
|
||||
|
@ -296,7 +296,7 @@ def test_take_16():
|
|||
data1 = data1.skip(3)
|
||||
data1 = data1.take(5)
|
||||
|
||||
# Here i refers to index, d refers to data element
|
||||
# Here i refers to index, d refers to data element
|
||||
for i, d in enumerate(data1):
|
||||
assert (i + 3) == d[0][0]
|
||||
|
||||
|
@ -313,7 +313,7 @@ def test_take_17():
|
|||
data1 = data1.take(8)
|
||||
data1 = data1.filter(predicate=filter_func_ge, num_parallel_workers=4)
|
||||
|
||||
# Here i refers to index, d refers to data element
|
||||
# Here i refers to index, d refers to data element
|
||||
for i, d in enumerate(data1):
|
||||
assert i == d[0][0]
|
||||
|
||||
|
@ -334,9 +334,9 @@ def test_take_18():
|
|||
data1 = data1.batch(2)
|
||||
data1 = data1.repeat(2)
|
||||
|
||||
# Here i refers to index, d refers to data element
|
||||
for i, d in enumerate(data1):
|
||||
assert 2 == d[0][0]
|
||||
# Here i refers to index, d refers to data element
|
||||
for _, d in enumerate(data1):
|
||||
assert d[0][0] == 2
|
||||
|
||||
assert sum([1 for _ in data1]) == 2
|
||||
|
||||
|
|
|
@ -33,7 +33,7 @@ def test_case_tf_shape():
|
|||
for data in ds1.create_dict_iterator():
|
||||
logger.info(data)
|
||||
output_shape = ds1.output_shapes()
|
||||
assert (len(output_shape[-1]) == 1)
|
||||
assert len(output_shape[-1]) == 1
|
||||
|
||||
|
||||
def test_case_tf_read_all_dataset():
|
||||
|
@ -41,7 +41,7 @@ def test_case_tf_read_all_dataset():
|
|||
ds1 = ds.TFRecordDataset(FILES, schema_file)
|
||||
assert ds1.get_dataset_size() == 12
|
||||
count = 0
|
||||
for data in ds1.create_tuple_iterator():
|
||||
for _ in ds1.create_tuple_iterator():
|
||||
count += 1
|
||||
assert count == 12
|
||||
|
||||
|
@ -51,7 +51,7 @@ def test_case_num_samples():
|
|||
ds1 = ds.TFRecordDataset(FILES, schema_file, num_samples=8)
|
||||
assert ds1.get_dataset_size() == 8
|
||||
count = 0
|
||||
for data in ds1.create_dict_iterator():
|
||||
for _ in ds1.create_dict_iterator():
|
||||
count += 1
|
||||
assert count == 8
|
||||
|
||||
|
@ -61,7 +61,7 @@ def test_case_num_samples2():
|
|||
ds1 = ds.TFRecordDataset(FILES, schema_file)
|
||||
assert ds1.get_dataset_size() == 7
|
||||
count = 0
|
||||
for data in ds1.create_dict_iterator():
|
||||
for _ in ds1.create_dict_iterator():
|
||||
count += 1
|
||||
assert count == 7
|
||||
|
||||
|
@ -70,7 +70,7 @@ def test_case_tf_shape_2():
|
|||
ds1 = ds.TFRecordDataset(FILES, SCHEMA_FILE)
|
||||
ds1 = ds1.batch(2)
|
||||
output_shape = ds1.output_shapes()
|
||||
assert (len(output_shape[-1]) == 2)
|
||||
assert len(output_shape[-1]) == 2
|
||||
|
||||
|
||||
def test_case_tf_file():
|
||||
|
@ -175,10 +175,10 @@ def test_tf_record_shard():
|
|||
assert len(worker1_res) == 48
|
||||
assert len(worker1_res) == len(worker2_res)
|
||||
# check criteria 1
|
||||
for i in range(len(worker1_res)):
|
||||
assert (worker1_res[i] != worker2_res[i])
|
||||
for i, _ in enumerate(worker1_res):
|
||||
assert worker1_res[i] != worker2_res[i]
|
||||
# check criteria 2
|
||||
assert (set(worker2_res) == set(worker1_res))
|
||||
assert set(worker2_res) == set(worker1_res)
|
||||
|
||||
|
||||
def test_tf_shard_equal_rows():
|
||||
|
@ -197,16 +197,16 @@ def test_tf_shard_equal_rows():
|
|||
worker2_res = get_res(3, 1, 2)
|
||||
worker3_res = get_res(3, 2, 2)
|
||||
# check criteria 1
|
||||
for i in range(len(worker1_res)):
|
||||
assert (worker1_res[i] != worker2_res[i])
|
||||
assert (worker2_res[i] != worker3_res[i])
|
||||
for i, _ in enumerate(worker1_res):
|
||||
assert worker1_res[i] != worker2_res[i]
|
||||
assert worker2_res[i] != worker3_res[i]
|
||||
# Confirm each worker gets same number of rows
|
||||
assert len(worker1_res) == 28
|
||||
assert len(worker1_res) == len(worker2_res)
|
||||
assert len(worker2_res) == len(worker3_res)
|
||||
|
||||
worker4_res = get_res(1, 0, 1)
|
||||
assert (len(worker4_res) == 40)
|
||||
assert len(worker4_res) == 40
|
||||
|
||||
|
||||
def test_case_tf_file_no_schema_columns_list():
|
||||
|
|
|
@ -59,7 +59,7 @@ def test_batch_corner_cases():
|
|||
# to a pyfunc which makes a deep copy of the row
|
||||
def test_variable_size_batch():
|
||||
def check_res(arr1, arr2):
|
||||
for ind in range(len(arr1)):
|
||||
for ind, _ in enumerate(arr1):
|
||||
if not np.array_equal(arr1[ind], np.array(arr2[ind])):
|
||||
return False
|
||||
return len(arr1) == len(arr2)
|
||||
|
@ -143,7 +143,7 @@ def test_variable_size_batch():
|
|||
|
||||
def test_basic_batch_map():
|
||||
def check_res(arr1, arr2):
|
||||
for ind in range(len(arr1)):
|
||||
for ind, _ in enumerate(arr1):
|
||||
if not np.array_equal(arr1[ind], np.array(arr2[ind])):
|
||||
return False
|
||||
return len(arr1) == len(arr2)
|
||||
|
@ -176,7 +176,7 @@ def test_basic_batch_map():
|
|||
|
||||
def test_batch_multi_col_map():
|
||||
def check_res(arr1, arr2):
|
||||
for ind in range(len(arr1)):
|
||||
for ind, _ in enumerate(arr1):
|
||||
if not np.array_equal(arr1[ind], np.array(arr2[ind])):
|
||||
return False
|
||||
return len(arr1) == len(arr2)
|
||||
|
@ -224,7 +224,7 @@ def test_batch_multi_col_map():
|
|||
|
||||
def test_var_batch_multi_col_map():
|
||||
def check_res(arr1, arr2):
|
||||
for ind in range(len(arr1)):
|
||||
for ind, _ in enumerate(arr1):
|
||||
if not np.array_equal(arr1[ind], np.array(arr2[ind])):
|
||||
return False
|
||||
return len(arr1) == len(arr2)
|
||||
|
@ -269,7 +269,7 @@ def test_var_batch_var_resize():
|
|||
return ([np.copy(c[0:s, 0:s, :]) for c in col],)
|
||||
|
||||
def add_one(batchInfo):
|
||||
return (batchInfo.get_batch_num() + 1)
|
||||
return batchInfo.get_batch_num() + 1
|
||||
|
||||
data1 = ds.ImageFolderDatasetV2("../data/dataset/testPK/data/", num_parallel_workers=4, decode=True)
|
||||
data1 = data1.batch(batch_size=add_one, drop_remainder=True, input_columns=["image"], per_batch_map=np_psedo_resize)
|
||||
|
@ -303,7 +303,7 @@ def test_exception():
|
|||
|
||||
data2 = ds.GeneratorDataset((lambda: gen(100)), ["num"]).batch(4, input_columns=["num"], per_batch_map=bad_map_func)
|
||||
try:
|
||||
for item in data2.create_dict_iterator():
|
||||
for _ in data2.create_dict_iterator():
|
||||
pass
|
||||
assert False
|
||||
except RuntimeError:
|
||||
|
|
|
@ -13,9 +13,9 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test mindrecord base"""
|
||||
import numpy as np
|
||||
import os
|
||||
import uuid
|
||||
import numpy as np
|
||||
from utils import get_data, get_nlp_data
|
||||
|
||||
from mindspore import log as logger
|
||||
|
|
|
@ -12,8 +12,8 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""test write multiple images"""
|
||||
import numpy as np
|
||||
import os
|
||||
import numpy as np
|
||||
from utils import get_two_bytes_data, get_multi_bytes_data
|
||||
|
||||
from mindspore import log as logger
|
||||
|
|
|
@ -14,9 +14,9 @@
|
|||
"""test mnist to mindrecord tool"""
|
||||
import gzip
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mindspore import log as logger
|
||||
|
|
Loading…
Reference in New Issue