!1478 [Dataset] clean pylint.

Yang 2020-05-26 16:17:53 +08:00
parent c086d91aaf
commit 9b2a778d94
43 changed files with 304 additions and 289 deletions

View File

@ -13,8 +13,8 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""test dataset performance about mindspore.MindDataset, mindspore.TFRecordDataset, tf.data.TFRecordDataset""" """test dataset performance about mindspore.MindDataset, mindspore.TFRecordDataset, tf.data.TFRecordDataset"""
import tensorflow as tf
import time import time
import tensorflow as tf
import mindspore.dataset as ds import mindspore.dataset as ds
from mindspore.mindrecord import FileReader from mindspore.mindrecord import FileReader
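Note: the swap of the two imports above appears to address pylint's wrong-import-order check (C0411): standard-library imports go before third-party ones. A minimal sketch of the resulting order, using only the module names shown in this hunk:

import time                      # standard library first

import tensorflow as tf          # third-party packages after
import mindspore.dataset as ds
from mindspore.mindrecord import FileReader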

View File

@ -32,9 +32,9 @@ def test_apply_generator_case():
data1 = ds.GeneratorDataset(generator_1d, ["data"]) data1 = ds.GeneratorDataset(generator_1d, ["data"])
data2 = ds.GeneratorDataset(generator_1d, ["data"]) data2 = ds.GeneratorDataset(generator_1d, ["data"])
def dataset_fn(ds): def dataset_fn(ds_):
ds = ds.repeat(2) ds_ = ds_.repeat(2)
return ds.batch(4) return ds_.batch(4)
data1 = data1.apply(dataset_fn) data1 = data1.apply(dataset_fn)
data2 = data2.repeat(2) data2 = data2.repeat(2)
@ -52,11 +52,11 @@ def test_apply_imagefolder_case():
decode_op = vision.Decode() decode_op = vision.Decode()
normalize_op = vision.Normalize([121.0, 115.0, 100.0], [70.0, 68.0, 71.0]) normalize_op = vision.Normalize([121.0, 115.0, 100.0], [70.0, 68.0, 71.0])
def dataset_fn(ds): def dataset_fn(ds_):
ds = ds.map(operations=decode_op) ds_ = ds_.map(operations=decode_op)
ds = ds.map(operations=normalize_op) ds_ = ds_.map(operations=normalize_op)
ds = ds.repeat(2) ds_ = ds_.repeat(2)
return ds return ds_
data1 = data1.apply(dataset_fn) data1 = data1.apply(dataset_fn)
data2 = data2.map(operations=decode_op) data2 = data2.map(operations=decode_op)
@ -67,125 +67,125 @@ def test_apply_imagefolder_case():
assert np.array_equal(item1["image"], item2["image"]) assert np.array_equal(item1["image"], item2["image"])
def test_apply_flow_case_0(id=0): def test_apply_flow_case_0(id_=0):
# apply control flow operations # apply control flow operations
data1 = ds.GeneratorDataset(generator_1d, ["data"]) data1 = ds.GeneratorDataset(generator_1d, ["data"])
def dataset_fn(ds): def dataset_fn(ds_):
if id == 0: if id_ == 0:
ds = ds.batch(4) ds_ = ds_.batch(4)
elif id == 1: elif id_ == 1:
ds = ds.repeat(2) ds_ = ds_.repeat(2)
elif id == 2: elif id_ == 2:
ds = ds.batch(4) ds_ = ds_.batch(4)
ds = ds.repeat(2) ds_ = ds_.repeat(2)
else: else:
ds = ds.shuffle(buffer_size=4) ds_ = ds_.shuffle(buffer_size=4)
return ds return ds_
data1 = data1.apply(dataset_fn) data1 = data1.apply(dataset_fn)
num_iter = 0 num_iter = 0
for _ in data1.create_dict_iterator(): for _ in data1.create_dict_iterator():
num_iter = num_iter + 1 num_iter = num_iter + 1
if id == 0: if id_ == 0:
assert num_iter == 16 assert num_iter == 16
elif id == 1: elif id_ == 1:
assert num_iter == 128 assert num_iter == 128
elif id == 2: elif id_ == 2:
assert num_iter == 32 assert num_iter == 32
else: else:
assert num_iter == 64 assert num_iter == 64
def test_apply_flow_case_1(id=1): def test_apply_flow_case_1(id_=1):
# apply control flow operations # apply control flow operations
data1 = ds.GeneratorDataset(generator_1d, ["data"]) data1 = ds.GeneratorDataset(generator_1d, ["data"])
def dataset_fn(ds): def dataset_fn(ds_):
if id == 0: if id_ == 0:
ds = ds.batch(4) ds_ = ds_.batch(4)
elif id == 1: elif id_ == 1:
ds = ds.repeat(2) ds_ = ds_.repeat(2)
elif id == 2: elif id_ == 2:
ds = ds.batch(4) ds_ = ds_.batch(4)
ds = ds.repeat(2) ds_ = ds_.repeat(2)
else: else:
ds = ds.shuffle(buffer_size=4) ds_ = ds_.shuffle(buffer_size=4)
return ds return ds_
data1 = data1.apply(dataset_fn) data1 = data1.apply(dataset_fn)
num_iter = 0 num_iter = 0
for _ in data1.create_dict_iterator(): for _ in data1.create_dict_iterator():
num_iter = num_iter + 1 num_iter = num_iter + 1
if id == 0: if id_ == 0:
assert num_iter == 16 assert num_iter == 16
elif id == 1: elif id_ == 1:
assert num_iter == 128 assert num_iter == 128
elif id == 2: elif id_ == 2:
assert num_iter == 32 assert num_iter == 32
else: else:
assert num_iter == 64 assert num_iter == 64
def test_apply_flow_case_2(id=2): def test_apply_flow_case_2(id_=2):
# apply control flow operations # apply control flow operations
data1 = ds.GeneratorDataset(generator_1d, ["data"]) data1 = ds.GeneratorDataset(generator_1d, ["data"])
def dataset_fn(ds): def dataset_fn(ds_):
if id == 0: if id_ == 0:
ds = ds.batch(4) ds_ = ds_.batch(4)
elif id == 1: elif id_ == 1:
ds = ds.repeat(2) ds_ = ds_.repeat(2)
elif id == 2: elif id_ == 2:
ds = ds.batch(4) ds_ = ds_.batch(4)
ds = ds.repeat(2) ds_ = ds_.repeat(2)
else: else:
ds = ds.shuffle(buffer_size=4) ds_ = ds_.shuffle(buffer_size=4)
return ds return ds_
data1 = data1.apply(dataset_fn) data1 = data1.apply(dataset_fn)
num_iter = 0 num_iter = 0
for _ in data1.create_dict_iterator(): for _ in data1.create_dict_iterator():
num_iter = num_iter + 1 num_iter = num_iter + 1
if id == 0: if id_ == 0:
assert num_iter == 16 assert num_iter == 16
elif id == 1: elif id_ == 1:
assert num_iter == 128 assert num_iter == 128
elif id == 2: elif id_ == 2:
assert num_iter == 32 assert num_iter == 32
else: else:
assert num_iter == 64 assert num_iter == 64
def test_apply_flow_case_3(id=3): def test_apply_flow_case_3(id_=3):
# apply control flow operations # apply control flow operations
data1 = ds.GeneratorDataset(generator_1d, ["data"]) data1 = ds.GeneratorDataset(generator_1d, ["data"])
def dataset_fn(ds): def dataset_fn(ds_):
if id == 0: if id_ == 0:
ds = ds.batch(4) ds_ = ds_.batch(4)
elif id == 1: elif id_ == 1:
ds = ds.repeat(2) ds_ = ds_.repeat(2)
elif id == 2: elif id_ == 2:
ds = ds.batch(4) ds_ = ds_.batch(4)
ds = ds.repeat(2) ds_ = ds_.repeat(2)
else: else:
ds = ds.shuffle(buffer_size=4) ds_ = ds_.shuffle(buffer_size=4)
return ds return ds_
data1 = data1.apply(dataset_fn) data1 = data1.apply(dataset_fn)
num_iter = 0 num_iter = 0
for _ in data1.create_dict_iterator(): for _ in data1.create_dict_iterator():
num_iter = num_iter + 1 num_iter = num_iter + 1
if id == 0: if id_ == 0:
assert num_iter == 16 assert num_iter == 16
elif id == 1: elif id_ == 1:
assert num_iter == 128 assert num_iter == 128
elif id == 2: elif id_ == 2:
assert num_iter == 32 assert num_iter == 32
else: else:
assert num_iter == 64 assert num_iter == 64
@ -195,11 +195,11 @@ def test_apply_exception_case():
# apply exception operations # apply exception operations
data1 = ds.GeneratorDataset(generator_1d, ["data"]) data1 = ds.GeneratorDataset(generator_1d, ["data"])
def dataset_fn(ds): def dataset_fn(ds_):
ds = ds.repeat(2) ds_ = ds_.repeat(2)
return ds.batch(4) return ds_.batch(4)
def exception_fn(ds): def exception_fn():
return np.array([[0], [1], [3], [4], [5]]) return np.array([[0], [1], [3], [4], [5]])
try: try:
@ -220,12 +220,12 @@ def test_apply_exception_case():
try: try:
data2 = data1.apply(dataset_fn) data2 = data1.apply(dataset_fn)
data3 = data1.apply(dataset_fn) _ = data1.apply(dataset_fn)
for _, _ in zip(data1.create_dict_iterator(), data2.create_dict_iterator()): for _, _ in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
pass pass
assert False assert False
except ValueError: except ValueError as e:
pass logger.info("Got an exception in DE: {}".format(str(e)))
if __name__ == '__main__': if __name__ == '__main__':
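Note: the renames in this file (ds -> ds_, id -> id_) appear to target pylint's redefined-outer-name (W0621) and redefined-builtin (W0622) warnings: the old parameter ds shadowed the module alias from "import mindspore.dataset as ds", and the old default argument id shadowed the builtin. A minimal sketch of the cleaned-up pattern, with a hypothetical 64-row generator:

import numpy as np
import mindspore.dataset as ds

def generator_1d():
    for i in range(64):
        yield (np.array([i]),)

def dataset_fn(ds_):                 # was `ds`, which shadowed the module alias
    ds_ = ds_.repeat(2)
    return ds_.batch(4)

def run_flow_case(id_=0):            # was `id`, which shadowed the builtin
    data1 = ds.GeneratorDataset(generator_1d, ["data"]).apply(dataset_fn)
    count = sum(1 for _ in data1.create_dict_iterator())
    return count, id_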

View File

@ -58,7 +58,7 @@ def test_auto_contrast(plot=False):
ds_original = ds_original.batch(512) ds_original = ds_original.batch(512)
for idx, (image, label) in enumerate(ds_original): for idx, (image, _) in enumerate(ds_original):
if idx == 0: if idx == 0:
images_original = np.transpose(image, (0, 2, 3, 1)) images_original = np.transpose(image, (0, 2, 3, 1))
else: else:
@ -79,7 +79,7 @@ def test_auto_contrast(plot=False):
ds_auto_contrast = ds_auto_contrast.batch(512) ds_auto_contrast = ds_auto_contrast.batch(512)
for idx, (image, label) in enumerate(ds_auto_contrast): for idx, (image, _) in enumerate(ds_auto_contrast):
if idx == 0: if idx == 0:
images_auto_contrast = np.transpose(image, (0, 2, 3, 1)) images_auto_contrast = np.transpose(image, (0, 2, 3, 1))
else: else:

View File

@ -273,7 +273,7 @@ def test_batch_exception_01():
data1 = data1.batch(batch_size=2, drop_remainder=True, num_parallel_workers=0) data1 = data1.batch(batch_size=2, drop_remainder=True, num_parallel_workers=0)
sum([1 for _ in data1]) sum([1 for _ in data1])
except BaseException as e: except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e))) logger.info("Got an exception in DE: {}".format(str(e)))
assert "num_parallel_workers" in str(e) assert "num_parallel_workers" in str(e)
@ -290,7 +290,7 @@ def test_batch_exception_02():
data1 = data1.batch(3, drop_remainder=True, num_parallel_workers=-1) data1 = data1.batch(3, drop_remainder=True, num_parallel_workers=-1)
sum([1 for _ in data1]) sum([1 for _ in data1])
except BaseException as e: except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e))) logger.info("Got an exception in DE: {}".format(str(e)))
assert "num_parallel_workers" in str(e) assert "num_parallel_workers" in str(e)
@ -307,7 +307,7 @@ def test_batch_exception_03():
data1 = data1.batch(batch_size=0) data1 = data1.batch(batch_size=0)
sum([1 for _ in data1]) sum([1 for _ in data1])
except BaseException as e: except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e))) logger.info("Got an exception in DE: {}".format(str(e)))
assert "batch_size" in str(e) assert "batch_size" in str(e)
@ -324,7 +324,7 @@ def test_batch_exception_04():
data1 = data1.batch(batch_size=-1) data1 = data1.batch(batch_size=-1)
sum([1 for _ in data1]) sum([1 for _ in data1])
except BaseException as e: except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e))) logger.info("Got an exception in DE: {}".format(str(e)))
assert "batch_size" in str(e) assert "batch_size" in str(e)
@ -341,7 +341,7 @@ def test_batch_exception_05():
data1 = data1.batch(batch_size=False) data1 = data1.batch(batch_size=False)
sum([1 for _ in data1]) sum([1 for _ in data1])
except BaseException as e: except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e))) logger.info("Got an exception in DE: {}".format(str(e)))
assert "batch_size" in str(e) assert "batch_size" in str(e)
@ -358,7 +358,7 @@ def test_batch_exception_07():
data1 = data1.batch(3, drop_remainder=0) data1 = data1.batch(3, drop_remainder=0)
sum([1 for _ in data1]) sum([1 for _ in data1])
except BaseException as e: except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e))) logger.info("Got an exception in DE: {}".format(str(e)))
assert "drop_remainder" in str(e) assert "drop_remainder" in str(e)
@ -375,7 +375,7 @@ def test_batch_exception_08():
data1 = data1.batch(3, drop_remainder=True, num_parallel_workers=False) data1 = data1.batch(3, drop_remainder=True, num_parallel_workers=False)
sum([1 for _ in data1]) sum([1 for _ in data1])
except BaseException as e: except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e))) logger.info("Got an exception in DE: {}".format(str(e)))
assert "num_parallel_workers" in str(e) assert "num_parallel_workers" in str(e)
@ -392,7 +392,7 @@ def test_batch_exception_09():
data1 = data1.batch(drop_remainder=True, num_parallel_workers=4) data1 = data1.batch(drop_remainder=True, num_parallel_workers=4)
sum([1 for _ in data1]) sum([1 for _ in data1])
except BaseException as e: except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e))) logger.info("Got an exception in DE: {}".format(str(e)))
assert "batch_size" in str(e) assert "batch_size" in str(e)
@ -409,7 +409,7 @@ def test_batch_exception_10():
data1 = data1.batch(batch_size=4, num_parallel_workers=8192) data1 = data1.batch(batch_size=4, num_parallel_workers=8192)
sum([1 for _ in data1]) sum([1 for _ in data1])
except BaseException as e: except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e))) logger.info("Got an exception in DE: {}".format(str(e)))
assert "num_parallel_workers" in str(e) assert "num_parallel_workers" in str(e)
@ -429,7 +429,7 @@ def test_batch_exception_11():
data1 = data1.batch(batch_size, num_parallel_workers) data1 = data1.batch(batch_size, num_parallel_workers)
sum([1 for _ in data1]) sum([1 for _ in data1])
except BaseException as e: except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e))) logger.info("Got an exception in DE: {}".format(str(e)))
assert "drop_remainder" in str(e) assert "drop_remainder" in str(e)
@ -450,7 +450,7 @@ def test_batch_exception_12():
data1 = data1.batch(drop_remainder, batch_size=batch_size) data1 = data1.batch(drop_remainder, batch_size=batch_size)
sum([1 for _ in data1]) sum([1 for _ in data1])
except BaseException as e: except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e))) logger.info("Got an exception in DE: {}".format(str(e)))
assert "batch_size" in str(e) assert "batch_size" in str(e)
@ -469,7 +469,7 @@ def test_batch_exception_13():
data1 = data1.batch(batch_size, shard_id=1) data1 = data1.batch(batch_size, shard_id=1)
sum([1 for _ in data1]) sum([1 for _ in data1])
except BaseException as e: except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e))) logger.info("Got an exception in DE: {}".format(str(e)))
assert "shard_id" in str(e) assert "shard_id" in str(e)

View File

@ -24,18 +24,18 @@ from mindspore import log as logger
# In generator dataset: Number of rows is 3; its values are 0, 1, 2 # In generator dataset: Number of rows is 3; its values are 0, 1, 2
def generator(): def generator():
for i in range(3): for i in range(3):
yield np.array([i]), yield (np.array([i]),)
# In generator_10 dataset: Number of rows is 7; its values are 3, 4, 5 ... 9 # In generator_10 dataset: Number of rows is 7; its values are 3, 4, 5 ... 9
def generator_10(): def generator_10():
for i in range(3, 10): for i in range(3, 10):
yield np.array([i]), yield (np.array([i]),)
# In generator_20 dataset: Number of rows is 10; its values are 10, 11, 12 ... 19 # In generator_20 dataset: Number of rows is 10; its values are 10, 11, 12 ... 19
def generator_20(): def generator_20():
for i in range(10, 20): for i in range(10, 20):
yield np.array([i]), yield (np.array([i]),)
def test_concat_01(): def test_concat_01():
@ -85,7 +85,7 @@ def test_concat_03():
data3 = data1 + data2 data3 = data1 + data2
try: try:
for i, d in enumerate(data3): for _, _ in enumerate(data3):
pass pass
assert False assert False
except RuntimeError: except RuntimeError:
@ -104,7 +104,7 @@ def test_concat_04():
data3 = data1 + data2 data3 = data1 + data2
try: try:
for i, d in enumerate(data3): for _, _ in enumerate(data3):
pass pass
assert False assert False
except RuntimeError: except RuntimeError:
@ -125,7 +125,7 @@ def test_concat_05():
data3 = data1 + data2 data3 = data1 + data2
try: try:
for i, d in enumerate(data3): for _, _ in enumerate(data3):
pass pass
assert False assert False
except RuntimeError: except RuntimeError:

View File

@ -31,7 +31,7 @@ SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
def test_basic(): def test_basic():
""" """
Test basic configuration functions Test basic configuration functions
""" """
# Save original configuration values # Save original configuration values
num_parallel_workers_original = ds.config.get_num_parallel_workers() num_parallel_workers_original = ds.config.get_num_parallel_workers()
@ -138,7 +138,7 @@ def test_deterministic_run_fail():
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()): for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
np.testing.assert_equal(item1["image"], item2["image"]) np.testing.assert_equal(item1["image"], item2["image"])
except BaseException as e: except Exception as e:
# two datasets split the number out of the sequence a # two datasets split the number out of the sequence a
logger.info("Got an exception in DE: {}".format(str(e))) logger.info("Got an exception in DE: {}".format(str(e)))
assert "Array" in str(e) assert "Array" in str(e)
@ -157,7 +157,7 @@ def test_deterministic_run_pass():
# Save original configuration values # Save original configuration values
num_parallel_workers_original = ds.config.get_num_parallel_workers() num_parallel_workers_original = ds.config.get_num_parallel_workers()
seed_original = ds.config.get_seed() seed_original = ds.config.get_seed()
ds.config.set_seed(0) ds.config.set_seed(0)
ds.config.set_num_parallel_workers(1) ds.config.set_num_parallel_workers(1)
@ -179,7 +179,7 @@ def test_deterministic_run_pass():
try: try:
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()): for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
np.testing.assert_equal(item1["image"], item2["image"]) np.testing.assert_equal(item1["image"], item2["image"])
except BaseException as e: except Exception as e:
# two datasets both use numbers from the generated sequence "a" # two datasets both use numbers from the generated sequence "a"
logger.info("Got an exception in DE: {}".format(str(e))) logger.info("Got an exception in DE: {}".format(str(e)))
assert "Array" in str(e) assert "Array" in str(e)
@ -344,7 +344,7 @@ def test_deterministic_python_seed_multi_thread():
try: try:
np.testing.assert_equal(data1_output, data2_output) np.testing.assert_equal(data1_output, data2_output)
except BaseException as e: except Exception as e:
# expect output to not match during multi-threaded execution # expect output to not match during multi-threaded execution
logger.info("Got an exception in DE: {}".format(str(e))) logger.info("Got an exception in DE: {}".format(str(e)))
assert "Array" in str(e) assert "Array" in str(e)

View File

@ -107,14 +107,20 @@ def test_tfrecord_shardings4(print_res=False):
assert len(result_list) == expect_length assert len(result_list) == expect_length
assert set(result_list) == expect_set assert set(result_list) == expect_set
check_result(sharding_config(2, 0, None, 1), 20, {11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30})
check_result(sharding_config(2, 0, None, 1), 20,
             {11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30})
check_result(sharding_config(2, 1, None, 1), 20, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40})
check_result(sharding_config(2, 1, None, 1), 20,
             {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40})
check_result(sharding_config(2, 0, 3, 1), 3, {11, 12, 21}) check_result(sharding_config(2, 0, 3, 1), 3, {11, 12, 21})
check_result(sharding_config(2, 1, 3, 1), 3, {1, 2, 31}) check_result(sharding_config(2, 1, 3, 1), 3, {1, 2, 31})
check_result(sharding_config(2, 0, 40, 1), 20, {11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30})
check_result(sharding_config(2, 0, 40, 1), 20,
             {11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30})
check_result(sharding_config(2, 1, 40, 1), 20, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40})
check_result(sharding_config(2, 1, 40, 1), 20,
             {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40})
check_result(sharding_config(2, 0, 55, 1), 20, {11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30})
check_result(sharding_config(2, 0, 55, 1), 20,
             {11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30})
check_result(sharding_config(2, 1, 55, 1), 20, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40})
check_result(sharding_config(2, 1, 55, 1), 20,
             {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40})
check_result(sharding_config(3, 0, 8, 1), 8, {32, 33, 34, 11, 12, 13, 14, 31}) check_result(sharding_config(3, 0, 8, 1), 8, {32, 33, 34, 11, 12, 13, 14, 31})
check_result(sharding_config(3, 1, 8, 1), 8, {1, 2, 3, 4, 5, 6, 7, 8}) check_result(sharding_config(3, 1, 8, 1), 8, {1, 2, 3, 4, 5, 6, 7, 8})
check_result(sharding_config(3, 2, 8, 1), 8, {21, 22, 23, 24, 25, 26, 27, 28}) check_result(sharding_config(3, 2, 8, 1), 8, {21, 22, 23, 24, 25, 26, 27, 28})

View File

@ -49,7 +49,7 @@ def test_textline_dataset_totext():
strs = i["text"].item().decode("utf8") strs = i["text"].item().decode("utf8")
assert strs == line[count] assert strs == line[count]
count += 1 count += 1
assert (count == 5) assert count == 5
# Restore configuration num_parallel_workers # Restore configuration num_parallel_workers
ds.config.set_num_parallel_workers(original_num_parallel_workers) ds.config.set_num_parallel_workers(original_num_parallel_workers)

View File

@ -24,10 +24,10 @@ def test_voc_segmentation():
data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", mode="train", decode=True, shuffle=False) data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", mode="train", decode=True, shuffle=False)
num = 0 num = 0
for item in data1.create_dict_iterator(): for item in data1.create_dict_iterator():
assert (item["image"].shape[0] == IMAGE_SHAPE[num]) assert item["image"].shape[0] == IMAGE_SHAPE[num]
assert (item["target"].shape[0] == TARGET_SHAPE[num]) assert item["target"].shape[0] == TARGET_SHAPE[num]
num += 1 num += 1
assert (num == 10) assert num == 10
def test_voc_detection(): def test_voc_detection():
@ -35,12 +35,12 @@ def test_voc_detection():
num = 0 num = 0
count = [0, 0, 0, 0, 0, 0] count = [0, 0, 0, 0, 0, 0]
for item in data1.create_dict_iterator(): for item in data1.create_dict_iterator():
assert (item["image"].shape[0] == IMAGE_SHAPE[num]) assert item["image"].shape[0] == IMAGE_SHAPE[num]
for bbox in item["annotation"]: for bbox in item["annotation"]:
count[bbox[0]] += 1 count[bbox[0]] += 1
num += 1 num += 1
assert (num == 9) assert num == 9
assert (count == [3, 2, 1, 2, 4, 3]) assert count == [3, 2, 1, 2, 4, 3]
def test_voc_class_index(): def test_voc_class_index():
@ -58,8 +58,8 @@ def test_voc_class_index():
assert (bbox[0] == 0 or bbox[0] == 1 or bbox[0] == 5) assert (bbox[0] == 0 or bbox[0] == 1 or bbox[0] == 5)
count[bbox[0]] += 1 count[bbox[0]] += 1
num += 1 num += 1
assert (num == 6) assert num == 6
assert (count == [3, 2, 0, 0, 0, 3]) assert count == [3, 2, 0, 0, 0, 3]
def test_voc_get_class_indexing(): def test_voc_get_class_indexing():
@ -76,8 +76,8 @@ def test_voc_get_class_indexing():
assert (bbox[0] == 0 or bbox[0] == 1 or bbox[0] == 2 or bbox[0] == 3 or bbox[0] == 4 or bbox[0] == 5) assert (bbox[0] == 0 or bbox[0] == 1 or bbox[0] == 2 or bbox[0] == 3 or bbox[0] == 4 or bbox[0] == 5)
count[bbox[0]] += 1 count[bbox[0]] += 1
num += 1 num += 1
assert (num == 9) assert num == 9
assert (count == [3, 2, 1, 2, 4, 3]) assert count == [3, 2, 1, 2, 4, 3]
def test_case_0(): def test_case_0():
@ -93,9 +93,9 @@ def test_case_0():
data1 = data1.batch(batch_size, drop_remainder=True) data1 = data1.batch(batch_size, drop_remainder=True)
num = 0 num = 0
for item in data1.create_dict_iterator(): for _ in data1.create_dict_iterator():
num += 1 num += 1
assert (num == 20) assert num == 20
def test_case_1(): def test_case_1():
@ -110,9 +110,9 @@ def test_case_1():
data1 = data1.batch(batch_size, drop_remainder=True, pad_info={}) data1 = data1.batch(batch_size, drop_remainder=True, pad_info={})
num = 0 num = 0
for item in data1.create_dict_iterator(): for _ in data1.create_dict_iterator():
num += 1 num += 1
assert (num == 18) assert num == 18
def test_voc_exception(): def test_voc_exception():
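Note: the assert changes in this file drop redundant parentheses, which appears to correspond to pylint's superfluous-parens check (C0325). The bare form is also safer style, since wrapping a condition and a message together in parentheses creates a tuple that is always truthy:

num = 10                          # hypothetical value for illustration
assert num == 10                  # preferred
assert (num == 10)                # flagged: unnecessary parens after the assert keyword
assert (num == 10, "msg")         # never do this: a non-empty tuple is always true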

View File

@ -58,7 +58,7 @@ def test_equalize(plot=False):
ds_original = ds_original.batch(512) ds_original = ds_original.batch(512)
for idx, (image, label) in enumerate(ds_original): for idx, (image, _) in enumerate(ds_original):
if idx == 0: if idx == 0:
images_original = np.transpose(image, (0, 2, 3, 1)) images_original = np.transpose(image, (0, 2, 3, 1))
else: else:
@ -79,7 +79,7 @@ def test_equalize(plot=False):
ds_equalize = ds_equalize.batch(512) ds_equalize = ds_equalize.batch(512)
for idx, (image, label) in enumerate(ds_equalize): for idx, (image, _) in enumerate(ds_equalize):
if idx == 0: if idx == 0:
images_equalize = np.transpose(image, (0, 2, 3, 1)) images_equalize = np.transpose(image, (0, 2, 3, 1))
else: else:

View File

@ -15,9 +15,7 @@
import numpy as np import numpy as np
import mindspore.common.dtype as mstype
import mindspore.dataset as ds import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.transforms.vision.c_transforms as cde import mindspore.dataset.transforms.vision.c_transforms as cde
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
@ -31,7 +29,6 @@ def test_diff_predicate_func():
cde.Decode(), cde.Decode(),
cde.Resize([64, 64]) cde.Resize([64, 64])
] ]
type_cast_op = C.TypeCast(mstype.int32)
dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image", "label"], shuffle=False) dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image", "label"], shuffle=False)
dataset = dataset.map(input_columns=["image"], operations=transforms, num_parallel_workers=1) dataset = dataset.map(input_columns=["image"], operations=transforms, num_parallel_workers=1)
dataset = dataset.filter(input_columns=["image", "label"], predicate=predicate_func, num_parallel_workers=4) dataset = dataset.filter(input_columns=["image", "label"], predicate=predicate_func, num_parallel_workers=4)
@ -40,7 +37,6 @@ def test_diff_predicate_func():
label_list = [] label_list = []
for data in dataset.create_dict_iterator(): for data in dataset.create_dict_iterator():
num_iter += 1 num_iter += 1
ori_img = data["image"]
label = data["label"] label = data["label"]
label_list.append(label) label_list.append(label)
assert num_iter == 1 assert num_iter == 1
@ -200,6 +196,7 @@ def generator_1d_zip2():
def filter_func_zip(data1, data2): def filter_func_zip(data1, data2):
_ = data2
if data1 > 20: if data1 > 20:
return False return False
return True return True
@ -249,6 +246,7 @@ def test_filter_by_generator_with_zip_after():
def filter_func_map(col1, col2): def filter_func_map(col1, col2):
_ = col2
if col1[0] > 8: if col1[0] > 8:
return True return True
return False return False
@ -262,6 +260,7 @@ def filter_func_map_part(col1):
def filter_func_map_all(col1, col2): def filter_func_map_all(col1, col2):
_, _ = col1, col2
return True return True
@ -334,6 +333,7 @@ def test_filter_by_generator_with_rename():
# test input_column # test input_column
def filter_func_input_column1(col1, col2): def filter_func_input_column1(col1, col2):
_ = col2
if col1[0] < 8: if col1[0] < 8:
return True return True
return False return False
@ -346,6 +346,7 @@ def filter_func_input_column2(col1):
def filter_func_input_column3(col1): def filter_func_input_column3(col1):
_ = col1
return True return True
@ -380,6 +381,7 @@ def generator_mc_p1(maxid=20):
def filter_func_Partial_0(col1, col2, col3, col4): def filter_func_Partial_0(col1, col2, col3, col4):
_, _, _ = col2, col3, col4
filter_data = [0, 1, 2, 3, 4, 11] filter_data = [0, 1, 2, 3, 4, 11]
if col1[0] in filter_data: if col1[0] in filter_data:
return False return False
@ -439,6 +441,7 @@ def test_filter_by_generator_Partial2():
def filter_func_Partial(col1, col2): def filter_func_Partial(col1, col2):
_ = col2
if col1[0] % 3 == 0: if col1[0] % 3 == 0:
return True return True
return False return False
@ -461,6 +464,7 @@ def test_filter_by_generator_Partial():
def filter_func_cifar(col1, col2): def filter_func_cifar(col1, col2):
_ = col1
if col2 % 3 == 0: if col2 % 3 == 0:
return True return True
return False return False
@ -490,6 +494,7 @@ def generator_sort2(maxid=20):
def filter_func_part_sort(col1, col2, col3, col4, col5, col6): def filter_func_part_sort(col1, col2, col3, col4, col5, col6):
_, _, _, _, _, _ = col1, col2, col3, col4, col5, col6
return True return True
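Note: the added "_ = col2"-style assignments in this file silence pylint's unused-argument warning (W0613): the filter callbacks must accept every input column named in the signature, but some predicates only look at one of them. A minimal sketch of the pattern, using the same threshold as the surrounding test:

def filter_func_zip(data1, data2):
    _ = data2                     # column required by the signature but unused
    if data1 > 20:
        return False
    return True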

View File

@ -58,7 +58,7 @@ def test_invert(plot=False):
ds_original = ds_original.batch(512) ds_original = ds_original.batch(512)
for idx, (image, label) in enumerate(ds_original): for idx, (image, _) in enumerate(ds_original):
if idx == 0: if idx == 0:
images_original = np.transpose(image, (0, 2, 3, 1)) images_original = np.transpose(image, (0, 2, 3, 1))
else: else:
@ -79,7 +79,7 @@ def test_invert(plot=False):
ds_invert = ds_invert.batch(512) ds_invert = ds_invert.batch(512)
for idx, (image, label) in enumerate(ds_invert): for idx, (image, _) in enumerate(ds_invert):
if idx == 0: if idx == 0:
images_invert = np.transpose(image, (0, 2, 3, 1)) images_invert = np.transpose(image, (0, 2, 3, 1))
else: else:

View File

@ -17,11 +17,11 @@ This is the test module for mindrecord
""" """
import collections import collections
import json import json
import numpy as np
import os import os
import pytest
import re import re
import string import string
import pytest
import numpy as np
import mindspore.dataset as ds import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as vision import mindspore.dataset.transforms.vision.c_transforms as vision
@ -46,9 +46,10 @@ def add_and_remove_cv_file():
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
for x in range(FILES_NUM)] for x in range(FILES_NUM)]
for x in paths: for x in paths:
os.remove("{}".format(x)) if os.path.exists("{}".format(x)) else None if os.path.exists("{}".format(x)):
os.remove("{}.db".format(x)) if os.path.exists( os.remove("{}".format(x))
"{}.db".format(x)) else None if os.path.exists("{}.db".format(x)):
os.remove("{}.db".format(x))
writer = FileWriter(CV_FILE_NAME, FILES_NUM) writer = FileWriter(CV_FILE_NAME, FILES_NUM)
data = get_data(CV_DIR_NAME) data = get_data(CV_DIR_NAME)
cv_schema_json = {"id": {"type": "int32"}, cv_schema_json = {"id": {"type": "int32"},
@ -117,7 +118,9 @@ def add_and_remove_nlp_compress_file():
255, 256, -32768, 32767, -32769, 32768, -2147483648, 255, 256, -32768, 32767, -32769, 32768, -2147483648,
2147483647], dtype=np.int32), [-1]), 2147483647], dtype=np.int32), [-1]),
"array_b": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129, 255, "array_b": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129, 255,
256, -32768, 32767, -32769, 32768, -2147483648, 2147483647, -2147483649, 2147483649, -922337036854775808, 9223372036854775807]), [1, -1]), 256, -32768, 32767, -32769, 32768,
-2147483648, 2147483647, -2147483649, 2147483649,
-922337036854775808, 9223372036854775807]), [1, -1]),
"array_c": str.encode("nlp data"), "array_c": str.encode("nlp data"),
"array_d": np.reshape(np.array([[-10, -127], [10, 127]]), [2, -1]) "array_d": np.reshape(np.array([[-10, -127], [10, 127]]), [2, -1])
}) })
@ -151,7 +154,9 @@ def test_nlp_compress_data(add_and_remove_nlp_compress_file):
255, 256, -32768, 32767, -32769, 32768, -2147483648, 255, 256, -32768, 32767, -32769, 32768, -2147483648,
2147483647], dtype=np.int32), [-1]), 2147483647], dtype=np.int32), [-1]),
"array_b": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129, 255, "array_b": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129, 255,
256, -32768, 32767, -32769, 32768, -2147483648, 2147483647, -2147483649, 2147483649, -922337036854775808, 9223372036854775807]), [1, -1]), 256, -32768, 32767, -32769, 32768,
-2147483648, 2147483647, -2147483649, 2147483649,
-922337036854775808, 9223372036854775807]), [1, -1]),
"array_c": str.encode("nlp data"), "array_c": str.encode("nlp data"),
"array_d": np.reshape(np.array([[-10, -127], [10, 127]]), [2, -1]) "array_d": np.reshape(np.array([[-10, -127], [10, 127]]), [2, -1])
}) })
@ -194,9 +199,10 @@ def test_cv_minddataset_writer_tutorial():
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
for x in range(FILES_NUM)] for x in range(FILES_NUM)]
for x in paths: for x in paths:
os.remove("{}".format(x)) if os.path.exists("{}".format(x)) else None if os.path.exists("{}".format(x)):
os.remove("{}.db".format(x)) if os.path.exists( os.remove("{}".format(x))
"{}.db".format(x)) else None if os.path.exists("{}.db".format(x)):
os.remove("{}.db".format(x))
writer = FileWriter(CV_FILE_NAME, FILES_NUM) writer = FileWriter(CV_FILE_NAME, FILES_NUM)
data = get_data(CV_DIR_NAME) data = get_data(CV_DIR_NAME)
cv_schema_json = {"file_name": {"type": "string"}, "label": {"type": "int32"}, cv_schema_json = {"file_name": {"type": "string"}, "label": {"type": "int32"},
@ -478,9 +484,10 @@ def test_cv_minddataset_reader_two_dataset_partition(add_and_remove_cv_file):
paths = ["{}{}".format(CV1_FILE_NAME, str(x).rjust(1, '0')) paths = ["{}{}".format(CV1_FILE_NAME, str(x).rjust(1, '0'))
for x in range(FILES_NUM)] for x in range(FILES_NUM)]
for x in paths: for x in paths:
os.remove("{}".format(x)) if os.path.exists("{}".format(x)) else None if os.path.exists("{}".format(x)):
os.remove("{}.db".format(x)) if os.path.exists( os.remove("{}".format(x))
"{}.db".format(x)) else None if os.path.exists("{}.db".format(x)):
os.remove("{}.db".format(x))
writer = FileWriter(CV1_FILE_NAME, FILES_NUM) writer = FileWriter(CV1_FILE_NAME, FILES_NUM)
data = get_data(CV_DIR_NAME) data = get_data(CV_DIR_NAME)
cv_schema_json = {"id": {"type": "int32"}, cv_schema_json = {"id": {"type": "int32"},
@ -779,7 +786,7 @@ def get_nlp_data(dir_name, vocab_file, num):
""" """
if not os.path.isdir(dir_name): if not os.path.isdir(dir_name):
raise IOError("Directory {} not exists".format(dir_name)) raise IOError("Directory {} not exists".format(dir_name))
for root, dirs, files in os.walk(dir_name): for root, _, files in os.walk(dir_name):
for index, file_name_extension in enumerate(files): for index, file_name_extension in enumerate(files):
if index < num: if index < num:
file_path = os.path.join(root, file_name_extension) file_path = os.path.join(root, file_name_extension)
@ -851,7 +858,7 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset():
if os.path.exists("{}".format(mindrecord_file_name)): if os.path.exists("{}".format(mindrecord_file_name)):
os.remove("{}".format(mindrecord_file_name)) os.remove("{}".format(mindrecord_file_name))
if os.path.exists("{}.db".format(mindrecord_file_name)): if os.path.exists("{}.db".format(mindrecord_file_name)):
os.remove("{}.db".format(x)) os.remove("{}.db".format(mindrecord_file_name))
data = [{"file_name": "001.jpg", "label": 4, data = [{"file_name": "001.jpg", "label": 4,
"image1": bytes("image1 bytes abc", encoding='UTF-8'), "image1": bytes("image1 bytes abc", encoding='UTF-8'),
"image2": bytes("image1 bytes def", encoding='UTF-8'), "image2": bytes("image1 bytes def", encoding='UTF-8'),

View File

@ -26,8 +26,10 @@ CV1_FILE_NAME = "./imagenet1.mindrecord"
def create_cv_mindrecord(files_num): def create_cv_mindrecord(files_num):
"""tutorial for cv dataset writer.""" """tutorial for cv dataset writer."""
os.remove(CV_FILE_NAME) if os.path.exists(CV_FILE_NAME) else None
os.remove("{}.db".format(CV_FILE_NAME)) if os.path.exists("{}.db".format(CV_FILE_NAME)) else None
if os.path.exists(CV_FILE_NAME):
    os.remove(CV_FILE_NAME)
if os.path.exists("{}.db".format(CV_FILE_NAME)):
    os.remove("{}.db".format(CV_FILE_NAME))
writer = FileWriter(CV_FILE_NAME, files_num) writer = FileWriter(CV_FILE_NAME, files_num)
cv_schema_json = {"file_name": {"type": "string"}, "label": {"type": "int32"}, "data": {"type": "bytes"}} cv_schema_json = {"file_name": {"type": "string"}, "label": {"type": "int32"}, "data": {"type": "bytes"}}
data = [{"file_name": "001.jpg", "label": 43, "data": bytes('0xffsafdafda', encoding='utf-8')}] data = [{"file_name": "001.jpg", "label": 43, "data": bytes('0xffsafdafda', encoding='utf-8')}]
@ -39,8 +41,10 @@ def create_cv_mindrecord(files_num):
def create_diff_schema_cv_mindrecord(files_num): def create_diff_schema_cv_mindrecord(files_num):
"""tutorial for cv dataset writer.""" """tutorial for cv dataset writer."""
os.remove(CV1_FILE_NAME) if os.path.exists(CV1_FILE_NAME) else None
os.remove("{}.db".format(CV1_FILE_NAME)) if os.path.exists("{}.db".format(CV1_FILE_NAME)) else None
if os.path.exists(CV1_FILE_NAME):
    os.remove(CV1_FILE_NAME)
if os.path.exists("{}.db".format(CV1_FILE_NAME)):
    os.remove("{}.db".format(CV1_FILE_NAME))
writer = FileWriter(CV1_FILE_NAME, files_num) writer = FileWriter(CV1_FILE_NAME, files_num)
cv_schema_json = {"file_name_1": {"type": "string"}, "label": {"type": "int32"}, "data": {"type": "bytes"}} cv_schema_json = {"file_name_1": {"type": "string"}, "label": {"type": "int32"}, "data": {"type": "bytes"}}
data = [{"file_name_1": "001.jpg", "label": 43, "data": bytes('0xffsafdafda', encoding='utf-8')}] data = [{"file_name_1": "001.jpg", "label": 43, "data": bytes('0xffsafdafda', encoding='utf-8')}]
@ -52,8 +56,10 @@ def create_diff_schema_cv_mindrecord(files_num):
def create_diff_page_size_cv_mindrecord(files_num): def create_diff_page_size_cv_mindrecord(files_num):
"""tutorial for cv dataset writer.""" """tutorial for cv dataset writer."""
os.remove(CV1_FILE_NAME) if os.path.exists(CV1_FILE_NAME) else None
os.remove("{}.db".format(CV1_FILE_NAME)) if os.path.exists("{}.db".format(CV1_FILE_NAME)) else None
if os.path.exists(CV1_FILE_NAME):
    os.remove(CV1_FILE_NAME)
if os.path.exists("{}.db".format(CV1_FILE_NAME)):
    os.remove("{}.db".format(CV1_FILE_NAME))
writer = FileWriter(CV1_FILE_NAME, files_num) writer = FileWriter(CV1_FILE_NAME, files_num)
writer.set_page_size(1 << 26) # 64MB writer.set_page_size(1 << 26) # 64MB
cv_schema_json = {"file_name": {"type": "string"}, "label": {"type": "int32"}, "data": {"type": "bytes"}} cv_schema_json = {"file_name": {"type": "string"}, "label": {"type": "int32"}, "data": {"type": "bytes"}}
@ -69,8 +75,8 @@ def test_cv_lack_json():
create_cv_mindrecord(1) create_cv_mindrecord(1)
columns_list = ["data", "file_name", "label"] columns_list = ["data", "file_name", "label"]
num_readers = 4 num_readers = 4
with pytest.raises(Exception) as err: with pytest.raises(Exception):
data_set = ds.MindDataset(CV_FILE_NAME, "no_exist.json", columns_list, num_readers) ds.MindDataset(CV_FILE_NAME, "no_exist.json", columns_list, num_readers)
os.remove(CV_FILE_NAME) os.remove(CV_FILE_NAME)
os.remove("{}.db".format(CV_FILE_NAME)) os.remove("{}.db".format(CV_FILE_NAME))
@ -80,7 +86,7 @@ def test_cv_lack_mindrecord():
columns_list = ["data", "file_name", "label"] columns_list = ["data", "file_name", "label"]
num_readers = 4 num_readers = 4
with pytest.raises(Exception, match="does not exist or permission denied"): with pytest.raises(Exception, match="does not exist or permission denied"):
data_set = ds.MindDataset("no_exist.mindrecord", columns_list, num_readers) _ = ds.MindDataset("no_exist.mindrecord", columns_list, num_readers)
def test_invalid_mindrecord(): def test_invalid_mindrecord():
@ -134,7 +140,7 @@ def test_cv_minddataset_pk_sample_exclusive_shuffle():
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers,
sampler=sampler, shuffle=False) sampler=sampler, shuffle=False)
num_iter = 0 num_iter = 0
for item in data_set.create_dict_iterator(): for _ in data_set.create_dict_iterator():
num_iter += 1 num_iter += 1
os.remove(CV_FILE_NAME) os.remove(CV_FILE_NAME)
os.remove("{}.db".format(CV_FILE_NAME)) os.remove("{}.db".format(CV_FILE_NAME))
@ -149,7 +155,7 @@ def test_cv_minddataset_reader_different_schema():
data_set = ds.MindDataset([CV_FILE_NAME, CV1_FILE_NAME], columns_list, data_set = ds.MindDataset([CV_FILE_NAME, CV1_FILE_NAME], columns_list,
num_readers) num_readers)
num_iter = 0 num_iter = 0
for item in data_set.create_dict_iterator(): for _ in data_set.create_dict_iterator():
num_iter += 1 num_iter += 1
os.remove(CV_FILE_NAME) os.remove(CV_FILE_NAME)
os.remove("{}.db".format(CV_FILE_NAME)) os.remove("{}.db".format(CV_FILE_NAME))
@ -166,7 +172,7 @@ def test_cv_minddataset_reader_different_page_size():
data_set = ds.MindDataset([CV_FILE_NAME, CV1_FILE_NAME], columns_list, data_set = ds.MindDataset([CV_FILE_NAME, CV1_FILE_NAME], columns_list,
num_readers) num_readers)
num_iter = 0 num_iter = 0
for item in data_set.create_dict_iterator(): for _ in data_set.create_dict_iterator():
num_iter += 1 num_iter += 1
os.remove(CV_FILE_NAME) os.remove(CV_FILE_NAME)
os.remove("{}.db".format(CV_FILE_NAME)) os.remove("{}.db".format(CV_FILE_NAME))
@ -181,7 +187,7 @@ def test_minddataset_invalidate_num_shards():
with pytest.raises(Exception, match="shard_id is invalid, "): with pytest.raises(Exception, match="shard_id is invalid, "):
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 0, 1) data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 0, 1)
num_iter = 0 num_iter = 0
for item in data_set.create_dict_iterator(): for _ in data_set.create_dict_iterator():
num_iter += 1 num_iter += 1
os.remove(CV_FILE_NAME) os.remove(CV_FILE_NAME)
os.remove("{}.db".format(CV_FILE_NAME)) os.remove("{}.db".format(CV_FILE_NAME))
@ -194,7 +200,7 @@ def test_minddataset_invalidate_shard_id():
with pytest.raises(Exception, match="shard_id is invalid, "): with pytest.raises(Exception, match="shard_id is invalid, "):
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 1, -1) data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 1, -1)
num_iter = 0 num_iter = 0
for item in data_set.create_dict_iterator(): for _ in data_set.create_dict_iterator():
num_iter += 1 num_iter += 1
os.remove(CV_FILE_NAME) os.remove(CV_FILE_NAME)
os.remove("{}.db".format(CV_FILE_NAME)) os.remove("{}.db".format(CV_FILE_NAME))
@ -207,13 +213,13 @@ def test_minddataset_shard_id_bigger_than_num_shard():
with pytest.raises(Exception, match="shard_id is invalid, "): with pytest.raises(Exception, match="shard_id is invalid, "):
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, 2) data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, 2)
num_iter = 0 num_iter = 0
for item in data_set.create_dict_iterator(): for _ in data_set.create_dict_iterator():
num_iter += 1 num_iter += 1
with pytest.raises(Exception, match="shard_id is invalid, "): with pytest.raises(Exception, match="shard_id is invalid, "):
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, 5) data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, 5)
num_iter = 0 num_iter = 0
for item in data_set.create_dict_iterator(): for _ in data_set.create_dict_iterator():
num_iter += 1 num_iter += 1
os.remove(CV_FILE_NAME) os.remove(CV_FILE_NAME)
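Note: the cleanup helpers in these mindrecord tests are rewritten from conditional expressions used only for their side effect ("os.remove(x) if os.path.exists(x) else None", which pylint flags as expression-not-assigned, W0106) into plain if statements; the same pass also fixed a genuine bug earlier in this diff where the wrong variable was passed to os.remove for the .db index file. A minimal sketch of the resulting pattern as a hypothetical helper:

import os

def remove_with_db(path):
    """Delete a mindrecord file and its .db index if they exist."""
    if os.path.exists(path):
        os.remove(path)
    if os.path.exists("{}.db".format(path)):
        os.remove("{}.db".format(path))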

View File

@ -50,7 +50,7 @@ def test_cv_minddataset_reader_multi_image_and_ndarray_tutorial():
assert os.path.exists(CV_FILE_NAME) assert os.path.exists(CV_FILE_NAME)
assert os.path.exists(CV_FILE_NAME + ".db") assert os.path.exists(CV_FILE_NAME + ".db")
"""tutorial for minderdataset.""" # tutorial for minderdataset.
columns_list = ["id", "image_0", "image_2", "image_3", "image_4", "input_mask", "segments"] columns_list = ["id", "image_0", "image_2", "image_3", "image_4", "input_mask", "segments"]
num_readers = 1 num_readers = 1
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers) data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers)

View File

@ -20,7 +20,6 @@ import pytest
import mindspore.dataset as ds import mindspore.dataset as ds
from mindspore import log as logger from mindspore import log as logger
from mindspore.dataset.transforms.vision import Inter
from mindspore.dataset.text import to_str from mindspore.dataset.text import to_str
from mindspore.mindrecord import FileWriter from mindspore.mindrecord import FileWriter

View File

@ -39,7 +39,7 @@ def test_on_tokenized_line():
res = np.array([[10, 1, 11, 1, 12, 1, 15, 1, 13, 1, 14], res = np.array([[10, 1, 11, 1, 12, 1, 15, 1, 13, 1, 14],
[11, 1, 12, 1, 10, 1, 14, 1, 13, 1, 15]], dtype=np.int32) [11, 1, 12, 1, 10, 1, 14, 1, 13, 1, 15]], dtype=np.int32)
for i, d in enumerate(data.create_dict_iterator()): for i, d in enumerate(data.create_dict_iterator()):
np.testing.assert_array_equal(d["text"], res[i]), i _ = (np.testing.assert_array_equal(d["text"], res[i]), i)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -199,7 +199,7 @@ def test_jieba_5():
def gen(): def gen():
text = np.array("今天天气太好了我们一起去外面玩吧".encode("UTF8"), dtype='S') text = np.array("今天天气太好了我们一起去外面玩吧".encode("UTF8"), dtype='S')
yield text, yield (text,)
def pytoken_op(input_data): def pytoken_op(input_data):

View File

@ -109,10 +109,9 @@ def test_decode_op():
data1 = data1.map(input_columns=["image"], operations=decode_op) data1 = data1.map(input_columns=["image"], operations=decode_op)
num_iter = 0 num_iter = 0
image = None
for item in data1.create_dict_iterator(): for item in data1.create_dict_iterator():
logger.info("Looping inside iterator {}".format(num_iter)) logger.info("Looping inside iterator {}".format(num_iter))
image = item["image"] _ = item["image"]
# plt.subplot(131) # plt.subplot(131)
# plt.imshow(image) # plt.imshow(image)
# plt.title("DE image") # plt.title("DE image")
@ -134,10 +133,9 @@ def test_decode_normalize_op():
data1 = data1.map(input_columns=["image"], operations=[decode_op, normalize_op]) data1 = data1.map(input_columns=["image"], operations=[decode_op, normalize_op])
num_iter = 0 num_iter = 0
image = None
for item in data1.create_dict_iterator(): for item in data1.create_dict_iterator():
logger.info("Looping inside iterator {}".format(num_iter)) logger.info("Looping inside iterator {}".format(num_iter))
image = item["image"] _ = item["image"]
# plt.subplot(131) # plt.subplot(131)
# plt.imshow(image) # plt.imshow(image)
# plt.title("DE image") # plt.title("DE image")

View File

@ -37,8 +37,7 @@ def test_case_0():
data1 = data1.batch(2) data1 = data1.batch(2)
i = 0
for item in data1.create_dict_iterator():  # each data is a dictionary
    pass
for _ in data1.create_dict_iterator():  # each data is a dictionary
    pass

View File

@ -72,7 +72,7 @@ def test_pad_op():
# pylint: disable=unnecessary-lambda # pylint: disable=unnecessary-lambda
def test_pad_grayscale(): def test_pad_grayscale():
""" """
Tests that the pad works for grayscale images Tests that the pad works for grayscale images
""" """
def channel_swap(image): def channel_swap(image):
@ -92,7 +92,7 @@ def test_pad_grayscale():
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
data1 = data1.map(input_columns=["image"], operations=transform()) data1 = data1.map(input_columns=["image"], operations=transform())
# if input is grayscale, the output dimensions should be single channel # if input is grayscale, the output dimensions should be single channel
pad_gray = c_vision.Pad(100, fill_value=(20, 20, 20)) pad_gray = c_vision.Pad(100, fill_value=(20, 20, 20))
data1 = data1.map(input_columns=["image"], operations=pad_gray) data1 = data1.map(input_columns=["image"], operations=pad_gray)
dataset_shape_1 = [] dataset_shape_1 = []
@ -100,11 +100,11 @@ def test_pad_grayscale():
c_image = item1["image"] c_image = item1["image"]
dataset_shape_1.append(c_image.shape) dataset_shape_1.append(c_image.shape)
# Dataset for comparison # Dataset for comparison
data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
decode_op = c_vision.Decode() decode_op = c_vision.Decode()
# we use the same padding logic # we use the same padding logic
ctrans = [decode_op, pad_gray] ctrans = [decode_op, pad_gray]
dataset_shape_2 = [] dataset_shape_2 = []

View File

@ -119,7 +119,7 @@ def batch_padding_performance_3d():
num_batches = 0 num_batches = 0
for _ in data1.create_dict_iterator(): for _ in data1.create_dict_iterator():
num_batches += 1 num_batches += 1
res = "total number of batch:" + str(num_batches) + " time elapsed:" + str(time.time() - start_time) _ = "total number of batch:" + str(num_batches) + " time elapsed:" + str(time.time() - start_time)
# print(res) # print(res)
@ -135,7 +135,7 @@ def batch_padding_performance_1d():
num_batches = 0 num_batches = 0
for _ in data1.create_dict_iterator(): for _ in data1.create_dict_iterator():
num_batches += 1 num_batches += 1
res = "total number of batch:" + str(num_batches) + " time elapsed:" + str(time.time() - start_time) _ = "total number of batch:" + str(num_batches) + " time elapsed:" + str(time.time() - start_time)
# print(res) # print(res)
@ -151,7 +151,7 @@ def batch_pyfunc_padding_3d():
num_batches = 0 num_batches = 0
for _ in data1.create_dict_iterator(): for _ in data1.create_dict_iterator():
num_batches += 1 num_batches += 1
res = "total number of batch:" + str(num_batches) + " time elapsed:" + str(time.time() - start_time) _ = "total number of batch:" + str(num_batches) + " time elapsed:" + str(time.time() - start_time)
# print(res) # print(res)
@ -166,7 +166,7 @@ def batch_pyfunc_padding_1d():
num_batches = 0 num_batches = 0
for _ in data1.create_dict_iterator(): for _ in data1.create_dict_iterator():
num_batches += 1 num_batches += 1
res = "total number of batch:" + str(num_batches) + " time elapsed:" + str(time.time() - start_time) _ = "total number of batch:" + str(num_batches) + " time elapsed:" + str(time.time() - start_time)
# print(res) # print(res)

View File

@ -58,7 +58,7 @@ def test_random_color(degrees=(0.1, 1.9), plot=False):
ds_original = ds_original.batch(512) ds_original = ds_original.batch(512)
for idx, (image, label) in enumerate(ds_original): for idx, (image, _) in enumerate(ds_original):
if idx == 0: if idx == 0:
images_original = np.transpose(image, (0, 2, 3, 1)) images_original = np.transpose(image, (0, 2, 3, 1))
else: else:
@ -79,7 +79,7 @@ def test_random_color(degrees=(0.1, 1.9), plot=False):
ds_random_color = ds_random_color.batch(512) ds_random_color = ds_random_color.batch(512)
for idx, (image, label) in enumerate(ds_random_color): for idx, (image, _) in enumerate(ds_random_color):
if idx == 0: if idx == 0:
images_random_color = np.transpose(image, (0, 2, 3, 1)) images_random_color = np.transpose(image, (0, 2, 3, 1))
else: else:

View File

@ -256,7 +256,7 @@ def test_random_color_adjust_op_hue(plot=False):
# pylint: disable=unnecessary-lambda # pylint: disable=unnecessary-lambda
def test_random_color_adjust_grayscale(): def test_random_color_adjust_grayscale():
""" """
Tests that the random color adjust works for grayscale images Tests that the random color adjust works for grayscale images
""" """
def channel_swap(image): def channel_swap(image):
@ -284,7 +284,7 @@ def test_random_color_adjust_grayscale():
for item1 in data1.create_dict_iterator(): for item1 in data1.create_dict_iterator():
c_image = item1["image"] c_image = item1["image"]
dataset_shape_1.append(c_image.shape) dataset_shape_1.append(c_image.shape)
except BaseException as e: except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e))) logger.info("Got an exception in DE: {}".format(str(e)))

View File

@ -200,7 +200,7 @@ def test_random_crop_04_c():
for item in data.create_dict_iterator(): for item in data.create_dict_iterator():
image = item["image"] image = item["image"]
image_list.append(image.shape) image_list.append(image.shape)
except BaseException as e: except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e))) logger.info("Got an exception in DE: {}".format(str(e)))
def test_random_crop_04_py(): def test_random_crop_04_py():
@ -227,7 +227,7 @@ def test_random_crop_04_py():
for item in data.create_dict_iterator(): for item in data.create_dict_iterator():
image = (item["image"].transpose(1, 2, 0) * 255).astype(np.uint8) image = (item["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
image_list.append(image.shape) image_list.append(image.shape)
except BaseException as e: except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e))) logger.info("Got an exception in DE: {}".format(str(e)))
def test_random_crop_05_c(): def test_random_crop_05_c():
@ -439,7 +439,7 @@ def test_random_crop_09():
for item in data.create_dict_iterator(): for item in data.create_dict_iterator():
image = item["image"] image = item["image"]
image_list.append(image.shape) image_list.append(image.shape)
except BaseException as e: except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e))) logger.info("Got an exception in DE: {}".format(str(e)))
assert "should be PIL Image" in str(e) assert "should be PIL Image" in str(e)

View File

@ -60,7 +60,7 @@ def test_random_resize_op():
num_iter = 0 num_iter = 0
for item in data1.create_dict_iterator(): for item in data1.create_dict_iterator():
image_de_resized = item["image"] _ = item["image"]
# Uncomment below line if you want to visualize images # Uncomment below line if you want to visualize images
# visualize(image_de_resized, image_np_resized, mse) # visualize(image_de_resized, image_np_resized, mse)
num_iter += 1 num_iter += 1

View File

@ -58,7 +58,7 @@ def test_random_sharpness(degrees=(0.1, 1.9), plot=False):
ds_original = ds_original.batch(512) ds_original = ds_original.batch(512)
for idx, (image, label) in enumerate(ds_original): for idx, (image, _) in enumerate(ds_original):
if idx == 0: if idx == 0:
images_original = np.transpose(image, (0, 2, 3, 1)) images_original = np.transpose(image, (0, 2, 3, 1))
else: else:
@ -79,7 +79,7 @@ def test_random_sharpness(degrees=(0.1, 1.9), plot=False):
ds_random_sharpness = ds_random_sharpness.batch(512) ds_random_sharpness = ds_random_sharpness.batch(512)
for idx, (image, label) in enumerate(ds_random_sharpness): for idx, (image, _) in enumerate(ds_random_sharpness):
if idx == 0: if idx == 0:
images_random_sharpness = np.transpose(image, (0, 2, 3, 1)) images_random_sharpness = np.transpose(image, (0, 2, 3, 1))
else: else:

View File

@ -25,7 +25,7 @@ from mindspore import log as logger
def test_sequential_sampler(print_res=False): def test_sequential_sampler(print_res=False):
manifest_file = "../data/dataset/testManifestData/test5trainimgs.json" manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4} map_ = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
def test_config(num_samples, num_repeats=None): def test_config(num_samples, num_repeats=None):
sampler = ds.SequentialSampler() sampler = ds.SequentialSampler()
@ -36,7 +36,7 @@ def test_sequential_sampler(print_res=False):
for item in data1.create_dict_iterator(): for item in data1.create_dict_iterator():
logger.info("item[image].shape[0]: {}, item[label].item(): {}" logger.info("item[image].shape[0]: {}, item[label].item(): {}"
.format(item["image"].shape[0], item["label"].item())) .format(item["image"].shape[0], item["label"].item()))
res.append(map[(item["image"].shape[0], item["label"].item())]) res.append(map_[(item["image"].shape[0], item["label"].item())])
if print_res: if print_res:
logger.info("image.shapes and labels: {}".format(res)) logger.info("image.shapes and labels: {}".format(res))
return res return res
@ -48,7 +48,7 @@ def test_sequential_sampler(print_res=False):
def test_random_sampler(print_res=False): def test_random_sampler(print_res=False):
manifest_file = "../data/dataset/testManifestData/test5trainimgs.json" manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4} map_ = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
def test_config(replacement, num_samples, num_repeats): def test_config(replacement, num_samples, num_repeats):
sampler = ds.RandomSampler(replacement=replacement, num_samples=num_samples) sampler = ds.RandomSampler(replacement=replacement, num_samples=num_samples)
@ -56,7 +56,7 @@ def test_random_sampler(print_res=False):
data1 = data1.repeat(num_repeats) data1 = data1.repeat(num_repeats)
res = [] res = []
for item in data1.create_dict_iterator(): for item in data1.create_dict_iterator():
res.append(map[(item["image"].shape[0], item["label"].item())]) res.append(map_[(item["image"].shape[0], item["label"].item())])
if print_res: if print_res:
logger.info("image.shapes and labels: {}".format(res)) logger.info("image.shapes and labels: {}".format(res))
return res return res
@ -71,7 +71,7 @@ def test_random_sampler(print_res=False):
def test_random_sampler_multi_iter(print_res=False): def test_random_sampler_multi_iter(print_res=False):
manifest_file = "../data/dataset/testManifestData/test5trainimgs.json" manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4} map_ = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
def test_config(replacement, num_samples, num_repeats, validate): def test_config(replacement, num_samples, num_repeats, validate):
sampler = ds.RandomSampler(replacement=replacement, num_samples=num_samples) sampler = ds.RandomSampler(replacement=replacement, num_samples=num_samples)
@ -79,7 +79,7 @@ def test_random_sampler_multi_iter(print_res=False):
while num_repeats > 0: while num_repeats > 0:
res = [] res = []
for item in data1.create_dict_iterator(): for item in data1.create_dict_iterator():
res.append(map[(item["image"].shape[0], item["label"].item())]) res.append(map_[(item["image"].shape[0], item["label"].item())])
if print_res: if print_res:
logger.info("image.shapes and labels: {}".format(res)) logger.info("image.shapes and labels: {}".format(res))
if validate != sorted(res): if validate != sorted(res):
@ -112,7 +112,7 @@ def test_sampler_py_api():
def test_python_sampler(): def test_python_sampler():
manifest_file = "../data/dataset/testManifestData/test5trainimgs.json" manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4} map_ = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
class Sp1(ds.Sampler): class Sp1(ds.Sampler):
def __iter__(self): def __iter__(self):
@ -138,7 +138,7 @@ def test_python_sampler():
for item in data1.create_dict_iterator(): for item in data1.create_dict_iterator():
logger.info("item[image].shape[0]: {}, item[label].item(): {}" logger.info("item[image].shape[0]: {}, item[label].item(): {}"
.format(item["image"].shape[0], item["label"].item())) .format(item["image"].shape[0], item["label"].item()))
res.append(map[(item["image"].shape[0], item["label"].item())]) res.append(map_[(item["image"].shape[0], item["label"].item())])
# print(res) # print(res)
return res return res
@ -167,7 +167,7 @@ def test_python_sampler():
def test_subset_sampler(): def test_subset_sampler():
manifest_file = "../data/dataset/testManifestData/test5trainimgs.json" manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4} map_ = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
def test_config(num_samples, start_index, subset_size): def test_config(num_samples, start_index, subset_size):
sampler = ds.SubsetSampler(start_index, subset_size) sampler = ds.SubsetSampler(start_index, subset_size)
@ -175,7 +175,7 @@ def test_subset_sampler():
res = [] res = []
for item in d.create_dict_iterator(): for item in d.create_dict_iterator():
res.append(map[(item["image"].shape[0], item["label"].item())]) res.append(map_[(item["image"].shape[0], item["label"].item())])
return res return res
@ -196,7 +196,7 @@ def test_subset_sampler():
def test_sampler_chain(): def test_sampler_chain():
manifest_file = "../data/dataset/testManifestData/test5trainimgs.json" manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4} map_ = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
def test_config(num_shards, shard_id): def test_config(num_shards, shard_id):
sampler = ds.DistributedSampler(num_shards, shard_id, False) sampler = ds.DistributedSampler(num_shards, shard_id, False)
@ -209,7 +209,7 @@ def test_sampler_chain():
for item in data1.create_dict_iterator(): for item in data1.create_dict_iterator():
logger.info("item[image].shape[0]: {}, item[label].item(): {}" logger.info("item[image].shape[0]: {}, item[label].item(): {}"
.format(item["image"].shape[0], item["label"].item())) .format(item["image"].shape[0], item["label"].item()))
res.append(map[(item["image"].shape[0], item["label"].item())]) res.append(map_[(item["image"].shape[0], item["label"].item())])
return res return res
assert test_config(2, 0) == [0, 2, 4] assert test_config(2, 0) == [0, 2, 4]
@ -222,7 +222,7 @@ def test_sampler_chain():
def test_add_sampler_invalid_input(): def test_add_sampler_invalid_input():
manifest_file = "../data/dataset/testManifestData/test5trainimgs.json" manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4} _ = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
data1 = ds.ManifestDataset(manifest_file) data1 = ds.ManifestDataset(manifest_file)
with pytest.raises(TypeError) as info: with pytest.raises(TypeError) as info:
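
Every sampler test in this file renames its lookup dict from map to map_ (or drops it entirely as _ when unused), because map shadows the Python built-in, presumably triggering pylint's redefined-builtin warning. A small sketch of why the trailing-underscore rename matters; the dict keys simply mirror the ones in the tests above:

    def shadowed():
        map = {(172876, 0): 0}            # shadows the built-in map (redefined-builtin)
        return list(map(str, [1, 2]))     # would raise TypeError: 'dict' object is not callable

    def renamed():
        map_ = {(172876, 0): 0}           # trailing underscore keeps the built-in usable
        doubled = list(map(lambda x: x * 2, [1, 2]))
        return map_[(172876, 0)], doubled

    print(renamed())                      # (0, [2, 4])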

View File

@ -18,9 +18,8 @@ Testing dataset serialize and deserialize in DE
import filecmp import filecmp
import glob import glob
import json import json
import numpy as np
import os import os
import pytest import numpy as np
import mindspore.dataset as ds import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as c import mindspore.dataset.transforms.c_transforms as c
@ -28,6 +27,8 @@ import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore import log as logger from mindspore import log as logger
from mindspore.dataset.transforms.vision import Inter from mindspore.dataset.transforms.vision import Inter
from test_minddataset_sampler import add_and_remove_cv_file, get_data, CV_DIR_NAME, CV_FILE_NAME
def test_imagefolder(remove_json_files=True): def test_imagefolder(remove_json_files=True):
""" """
@ -186,7 +187,7 @@ def test_random_crop():
# Serializing into python dictionary # Serializing into python dictionary
ds1_dict = ds.serialize(data1) ds1_dict = ds.serialize(data1)
# Serializing into json object # Serializing into json object
ds1_json = json.dumps(ds1_dict, indent=2) _ = json.dumps(ds1_dict, indent=2)
# Reconstruct dataset pipeline from its serialized form # Reconstruct dataset pipeline from its serialized form
data1_1 = ds.deserialize(input_dict=ds1_dict) data1_1 = ds.deserialize(input_dict=ds1_dict)
@ -198,7 +199,7 @@ def test_random_crop():
for item1, item1_1, item2 in zip(data1.create_dict_iterator(), data1_1.create_dict_iterator(), for item1, item1_1, item2 in zip(data1.create_dict_iterator(), data1_1.create_dict_iterator(),
data2.create_dict_iterator()): data2.create_dict_iterator()):
assert np.array_equal(item1['image'], item1_1['image']) assert np.array_equal(item1['image'], item1_1['image'])
image2 = item2["image"] _ = item2["image"]
def validate_jsonfile(filepath): def validate_jsonfile(filepath):
@ -221,10 +222,6 @@ def delete_json_files():
# Test save load minddataset # Test save load minddataset
from test_minddataset_sampler import add_and_remove_cv_file, get_data, CV_DIR_NAME, CV_FILE_NAME, FILES_NUM, \
FileWriter, Inter
def test_minddataset(add_and_remove_cv_file): def test_minddataset(add_and_remove_cv_file):
"""tutorial for cv minderdataset.""" """tutorial for cv minderdataset."""
columns_list = ["data", "file_name", "label"] columns_list = ["data", "file_name", "label"]
@ -247,7 +244,7 @@ def test_minddataset(add_and_remove_cv_file):
assert ds1_json == ds2_json assert ds1_json == ds2_json
data = get_data(CV_DIR_NAME) _ = get_data(CV_DIR_NAME)
assert data_set.get_dataset_size() == 5 assert data_set.get_dataset_size() == 5
num_iter = 0 num_iter = 0
for _ in data_set.create_dict_iterator(): for _ in data_set.create_dict_iterator():
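
The serialization test reorders its imports so standard-library modules precede third-party ones, and moves the helper import from test_minddataset_sampler out of a function body to the top of the module, matching pylint's wrong-import-order and wrong-import-position conventions. A sketch of the expected grouping only; it assumes mindspore and the sibling test module are importable and is purely illustrative:

    # Standard library first ...
    import json
    import os

    # ... then third-party packages ...
    import numpy as np
    import pytest

    # ... then first-party modules and local test helpers, at module top rather than inside functions.
    import mindspore.dataset as ds
    from test_minddataset_sampler import add_and_remove_cv_file, get_data, CV_DIR_NAME, CV_FILE_NAME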

View File

@ -152,7 +152,7 @@ def test_shuffle_exception_01():
data1 = data1.shuffle(buffer_size=-1) data1 = data1.shuffle(buffer_size=-1)
sum([1 for _ in data1]) sum([1 for _ in data1])
except BaseException as e: except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e))) logger.info("Got an exception in DE: {}".format(str(e)))
assert "buffer_size" in str(e) assert "buffer_size" in str(e)
@ -170,7 +170,7 @@ def test_shuffle_exception_02():
data1 = data1.shuffle(buffer_size=0) data1 = data1.shuffle(buffer_size=0)
sum([1 for _ in data1]) sum([1 for _ in data1])
except BaseException as e: except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e))) logger.info("Got an exception in DE: {}".format(str(e)))
assert "buffer_size" in str(e) assert "buffer_size" in str(e)
@ -188,7 +188,7 @@ def test_shuffle_exception_03():
data1 = data1.shuffle(buffer_size=1) data1 = data1.shuffle(buffer_size=1)
sum([1 for _ in data1]) sum([1 for _ in data1])
except BaseException as e: except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e))) logger.info("Got an exception in DE: {}".format(str(e)))
assert "buffer_size" in str(e) assert "buffer_size" in str(e)
@ -206,7 +206,7 @@ def test_shuffle_exception_05():
data1 = data1.shuffle() data1 = data1.shuffle()
sum([1 for _ in data1]) sum([1 for _ in data1])
except BaseException as e: except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e))) logger.info("Got an exception in DE: {}".format(str(e)))
assert "buffer_size" in str(e) assert "buffer_size" in str(e)
@ -224,7 +224,7 @@ def test_shuffle_exception_06():
data1 = data1.shuffle(buffer_size=False) data1 = data1.shuffle(buffer_size=False)
sum([1 for _ in data1]) sum([1 for _ in data1])
except BaseException as e: except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e))) logger.info("Got an exception in DE: {}".format(str(e)))
assert "buffer_size" in str(e) assert "buffer_size" in str(e)
@ -242,7 +242,7 @@ def test_shuffle_exception_07():
data1 = data1.shuffle(buffer_size=True) data1 = data1.shuffle(buffer_size=True)
sum([1 for _ in data1]) sum([1 for _ in data1])
except BaseException as e: except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e))) logger.info("Got an exception in DE: {}".format(str(e)))
assert "buffer_size" in str(e) assert "buffer_size" in str(e)

View File

@ -70,7 +70,6 @@ def test_skip_1():
buf = [] buf = []
for data in ds1: for data in ds1:
buf.append(data[0][0]) buf.append(data[0][0])
assert len(buf) == 0
assert buf == [] assert buf == []

View File

@ -29,47 +29,47 @@ text_file_data = ["This is a text file.", "Another file.", "Be happy every day."
def split_with_invalid_inputs(d): def split_with_invalid_inputs(d):
with pytest.raises(ValueError) as info: with pytest.raises(ValueError) as info:
s1, s2 = d.split([]) _, _ = d.split([])
assert "sizes cannot be empty" in str(info.value) assert "sizes cannot be empty" in str(info.value)
with pytest.raises(ValueError) as info: with pytest.raises(ValueError) as info:
s1, s2 = d.split([5, 0.6]) _, _ = d.split([5, 0.6])
assert "sizes should be list of int or list of float" in str(info.value) assert "sizes should be list of int or list of float" in str(info.value)
with pytest.raises(ValueError) as info: with pytest.raises(ValueError) as info:
s1, s2 = d.split([-1, 6]) _, _ = d.split([-1, 6])
assert "there should be no negative numbers" in str(info.value) assert "there should be no negative numbers" in str(info.value)
with pytest.raises(RuntimeError) as info: with pytest.raises(RuntimeError) as info:
s1, s2 = d.split([3, 1]) _, _ = d.split([3, 1])
assert "sum of split sizes 4 is not equal to dataset size 5" in str(info.value) assert "sum of split sizes 4 is not equal to dataset size 5" in str(info.value)
with pytest.raises(RuntimeError) as info: with pytest.raises(RuntimeError) as info:
s1, s2 = d.split([5, 1]) _, _ = d.split([5, 1])
assert "sum of split sizes 6 is not equal to dataset size 5" in str(info.value) assert "sum of split sizes 6 is not equal to dataset size 5" in str(info.value)
with pytest.raises(RuntimeError) as info: with pytest.raises(RuntimeError) as info:
s1, s2 = d.split([0.15, 0.15, 0.15, 0.15, 0.15, 0.25]) _, _ = d.split([0.15, 0.15, 0.15, 0.15, 0.15, 0.25])
assert "sum of calculated split sizes 6 is not equal to dataset size 5" in str(info.value) assert "sum of calculated split sizes 6 is not equal to dataset size 5" in str(info.value)
with pytest.raises(ValueError) as info: with pytest.raises(ValueError) as info:
s1, s2 = d.split([-0.5, 0.5]) _, _ = d.split([-0.5, 0.5])
assert "there should be no numbers outside the range [0, 1]" in str(info.value) assert "there should be no numbers outside the range [0, 1]" in str(info.value)
with pytest.raises(ValueError) as info: with pytest.raises(ValueError) as info:
s1, s2 = d.split([1.5, 0.5]) _, _ = d.split([1.5, 0.5])
assert "there should be no numbers outside the range [0, 1]" in str(info.value) assert "there should be no numbers outside the range [0, 1]" in str(info.value)
with pytest.raises(ValueError) as info: with pytest.raises(ValueError) as info:
s1, s2 = d.split([0.5, 0.6]) _, _ = d.split([0.5, 0.6])
assert "percentages do not sum up to 1" in str(info.value) assert "percentages do not sum up to 1" in str(info.value)
with pytest.raises(ValueError) as info: with pytest.raises(ValueError) as info:
s1, s2 = d.split([0.3, 0.6]) _, _ = d.split([0.3, 0.6])
assert "percentages do not sum up to 1" in str(info.value) assert "percentages do not sum up to 1" in str(info.value)
with pytest.raises(RuntimeError) as info: with pytest.raises(RuntimeError) as info:
s1, s2 = d.split([0.05, 0.95]) _, _ = d.split([0.05, 0.95])
assert "percentage 0.05 is too small" in str(info.value) assert "percentage 0.05 is too small" in str(info.value)
@ -79,7 +79,7 @@ def test_unmappable_invalid_input():
d = ds.TextFileDataset(text_file_dataset_path, num_shards=2, shard_id=0) d = ds.TextFileDataset(text_file_dataset_path, num_shards=2, shard_id=0)
with pytest.raises(RuntimeError) as info: with pytest.raises(RuntimeError) as info:
s1, s2 = d.split([4, 1]) _, _ = d.split([4, 1])
assert "dataset should not be sharded before split" in str(info.value) assert "dataset should not be sharded before split" in str(info.value)
@ -273,7 +273,7 @@ def test_mappable_invalid_input():
d = ds.ManifestDataset(manifest_file, num_shards=2, shard_id=0) d = ds.ManifestDataset(manifest_file, num_shards=2, shard_id=0)
with pytest.raises(RuntimeError) as info: with pytest.raises(RuntimeError) as info:
s1, s2 = d.split([4, 1]) _, _ = d.split([4, 1])
assert "dataset should not be sharded before split" in str(info.value) assert "dataset should not be sharded before split" in str(info.value)

View File

@ -28,8 +28,8 @@ class Augment:
def __init__(self, loss): def __init__(self, loss):
self.loss = loss self.loss = loss
def preprocess(self, input): def preprocess(self, input_):
return input return input_
def update(self, data): def update(self, data):
self.loss = data["loss"] self.loss = data["loss"]
@ -143,7 +143,7 @@ def test_multiple_iterators():
dataset = dataset.sync_wait(condition_name="policy", callback=aug.update) dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
dataset = dataset.batch(batch_size, drop_remainder=True) dataset = dataset.batch(batch_size, drop_remainder=True)
# 2nd dataset # 2nd dataset
dataset2 = ds.GeneratorDataset(gen, column_names=["input"]) dataset2 = ds.GeneratorDataset(gen, column_names=["input"])
aug = Augment(0) aug = Augment(0)
@ -175,7 +175,7 @@ def test_sync_exception_01():
try: try:
dataset = dataset.shuffle(shuffle_size) dataset = dataset.shuffle(shuffle_size)
except BaseException as e: except Exception as e:
assert "shuffle" in str(e) assert "shuffle" in str(e)
dataset = dataset.batch(batch_size) dataset = dataset.batch(batch_size)
@ -197,7 +197,7 @@ def test_sync_exception_02():
try: try:
dataset = dataset.sync_wait(num_batch=2, condition_name="every batch") dataset = dataset.sync_wait(num_batch=2, condition_name="every batch")
except BaseException as e: except Exception as e:
assert "name" in str(e) assert "name" in str(e)
dataset = dataset.batch(batch_size) dataset = dataset.batch(batch_size)
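
The Augment helper renames its preprocess parameter from input to input_, again so the built-in is not shadowed. A compact, MindSpore-free sketch of the renamed callback shape used by the sync_wait pipeline above:

    class Augment:
        """Callback holder: preprocess() transforms a sample, update() absorbs feedback."""
        def __init__(self, loss):
            self.loss = loss

        def preprocess(self, input_):      # input_ instead of input: the built-in stays intact
            return input_

        def update(self, data):
            self.loss = data["loss"]

    aug = Augment(0)
    print(aug.preprocess([1, 2, 3]))       # [1, 2, 3]
    aug.update({"loss": 0.5})
    print(aug.loss)                        # 0.5
    print(callable(input))                 # True -- built-in unaffected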

View File

@ -46,7 +46,7 @@ def test_take_01():
data1 = data1.take(1) data1 = data1.take(1)
data1 = data1.repeat(2) data1 = data1.repeat(2)
# Here i refers to index, d refers to data element # Here i refers to index, d refers to data element
for _, d in enumerate(data1): for _, d in enumerate(data1):
assert d[0][0] == 0 assert d[0][0] == 0
@ -63,7 +63,7 @@ def test_take_02():
data1 = data1.take(2) data1 = data1.take(2)
data1 = data1.repeat(2) data1 = data1.repeat(2)
# Here i refers to index, d refers to data element # Here i refers to index, d refers to data element
for i, d in enumerate(data1): for i, d in enumerate(data1):
assert i % 2 == d[0][0] assert i % 2 == d[0][0]
@ -80,7 +80,7 @@ def test_take_03():
data1 = data1.take(3) data1 = data1.take(3)
data1 = data1.repeat(2) data1 = data1.repeat(2)
# Here i refers to index, d refers to data element # Here i refers to index, d refers to data elements
for i, d in enumerate(data1): for i, d in enumerate(data1):
assert i % 3 == d[0][0] assert i % 3 == d[0][0]

View File

@ -12,15 +12,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
import mindspore._c_dataengine as cde
import numpy as np
import pytest import pytest
import numpy as np
from mindspore.dataset.text import to_str, to_bytes
import mindspore.dataset as ds import mindspore.dataset as ds
import mindspore._c_dataengine as cde
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
from mindspore.dataset.text import to_str
# pylint: disable=comparison-with-itself # pylint: disable=comparison-with-itself
def test_basic(): def test_basic():
@ -34,7 +32,7 @@ def compare(strings):
arr = np.array(strings, dtype='S') arr = np.array(strings, dtype='S')
def gen(): def gen():
yield arr, yield (arr,)
data = ds.GeneratorDataset(gen, column_names=["col"]) data = ds.GeneratorDataset(gen, column_names=["col"])
@ -50,7 +48,7 @@ def test_generator():
def test_batching_strings(): def test_batching_strings():
def gen(): def gen():
yield np.array(["ab", "cde", "121"], dtype='S'), yield (np.array(["ab", "cde", "121"], dtype='S'),)
data = ds.GeneratorDataset(gen, column_names=["col"]).batch(10) data = ds.GeneratorDataset(gen, column_names=["col"]).batch(10)
@ -62,7 +60,7 @@ def test_batching_strings():
def test_map(): def test_map():
def gen(): def gen():
yield np.array(["ab cde 121"], dtype='S'), yield (np.array(["ab cde 121"], dtype='S'),)
data = ds.GeneratorDataset(gen, column_names=["col"]) data = ds.GeneratorDataset(gen, column_names=["col"])
@ -79,7 +77,7 @@ def test_map():
def test_map2(): def test_map2():
def gen(): def gen():
yield np.array(["ab cde 121"], dtype='S'), yield (np.array(["ab cde 121"], dtype='S'),)
data = ds.GeneratorDataset(gen, column_names=["col"]) data = ds.GeneratorDataset(gen, column_names=["col"])
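
The generator fixtures in this file now yield an explicitly parenthesised one-element tuple, yield (arr,) instead of yield arr, with a bare trailing comma: the value is identical, but the intent (one column per row) is unmistakable. A standalone sketch:

    import numpy as np

    def gen():
        # One-column rows: the parentheses make the single-element tuple explicit.
        yield (np.array(["ab", "cde", "121"], dtype='S'),)

    row = next(gen())
    print(type(row).__name__, len(row))    # tuple 1
    print(row[0].dtype.kind)               # S (bytes strings)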

View File

@ -215,7 +215,7 @@ def test_case_tf_file_no_schema_columns_list():
assert row["col_sint16"] == [-32768] assert row["col_sint16"] == [-32768]
with pytest.raises(KeyError) as info: with pytest.raises(KeyError) as info:
a = row["col_sint32"] _ = row["col_sint32"]
assert "col_sint32" in str(info.value) assert "col_sint32" in str(info.value)
@ -234,7 +234,7 @@ def test_tf_record_schema_columns_list():
assert row["col_sint16"] == [-32768] assert row["col_sint16"] == [-32768]
with pytest.raises(KeyError) as info: with pytest.raises(KeyError) as info:
a = row["col_sint32"] _ = row["col_sint32"]
assert "col_sint32" in str(info.value) assert "col_sint32" in str(info.value)
@ -246,7 +246,7 @@ def test_case_invalid_files():
data = ds.TFRecordDataset(files, SCHEMA_FILE, shuffle=ds.Shuffle.FILES) data = ds.TFRecordDataset(files, SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
with pytest.raises(RuntimeError) as info: with pytest.raises(RuntimeError) as info:
row = data.create_dict_iterator().get_next() _ = data.create_dict_iterator().get_next()
assert "cannot be opened" in str(info.value) assert "cannot be opened" in str(info.value)
assert "not valid tfrecord files" in str(info.value) assert "not valid tfrecord files" in str(info.value)
assert valid_file not in str(info.value) assert valid_file not in str(info.value)

View File

@ -123,7 +123,7 @@ def test_to_type_03():
] ]
transform = py_vision.ComposeOp(transforms) transform = py_vision.ComposeOp(transforms)
data = data.map(input_columns=["image"], operations=transform()) data = data.map(input_columns=["image"], operations=transform())
except BaseException as e: except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e))) logger.info("Got an exception in DE: {}".format(str(e)))
assert "Numpy" in str(e) assert "Numpy" in str(e)
@ -145,7 +145,7 @@ def test_to_type_04():
] ]
transform = py_vision.ComposeOp(transforms) transform = py_vision.ComposeOp(transforms)
data = data.map(input_columns=["image"], operations=transform()) data = data.map(input_columns=["image"], operations=transform())
except BaseException as e: except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e))) logger.info("Got an exception in DE: {}".format(str(e)))
assert "missing" in str(e) assert "missing" in str(e)
@ -167,7 +167,7 @@ def test_to_type_05():
] ]
transform = py_vision.ComposeOp(transforms) transform = py_vision.ComposeOp(transforms)
data = data.map(input_columns=["image"], operations=transform()) data = data.map(input_columns=["image"], operations=transform())
except BaseException as e: except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e))) logger.info("Got an exception in DE: {}".format(str(e)))
assert "data type" in str(e) assert "data type" in str(e)

View File

@ -59,7 +59,7 @@ def test_uniform_augment(plot=False, num_ops=2):
ds_original = ds_original.batch(512) ds_original = ds_original.batch(512)
for idx, (image, label) in enumerate(ds_original): for idx, (image, _) in enumerate(ds_original):
if idx == 0: if idx == 0:
images_original = np.transpose(image, (0, 2, 3, 1)) images_original = np.transpose(image, (0, 2, 3, 1))
else: else:
@ -87,7 +87,7 @@ def test_uniform_augment(plot=False, num_ops=2):
ds_ua = ds_ua.batch(512) ds_ua = ds_ua.batch(512)
for idx, (image, label) in enumerate(ds_ua): for idx, (image, _) in enumerate(ds_ua):
if idx == 0: if idx == 0:
images_ua = np.transpose(image, (0, 2, 3, 1)) images_ua = np.transpose(image, (0, 2, 3, 1))
else: else:
@ -122,7 +122,7 @@ def test_cpp_uniform_augment(plot=False, num_ops=2):
ds_original = ds_original.batch(512) ds_original = ds_original.batch(512)
for idx, (image, label) in enumerate(ds_original): for idx, (image, _) in enumerate(ds_original):
if idx == 0: if idx == 0:
images_original = np.transpose(image, (0, 2, 3, 1)) images_original = np.transpose(image, (0, 2, 3, 1))
else: else:
@ -149,7 +149,7 @@ def test_cpp_uniform_augment(plot=False, num_ops=2):
ds_ua = ds_ua.batch(512) ds_ua = ds_ua.batch(512)
for idx, (image, label) in enumerate(ds_ua): for idx, (image, _) in enumerate(ds_ua):
if idx == 0: if idx == 0:
images_ua = np.transpose(image, (0, 2, 3, 1)) images_ua = np.transpose(image, (0, 2, 3, 1))
else: else:
@ -180,9 +180,9 @@ def test_cpp_uniform_augment_exception_pyops(num_ops=2):
F.Invert()] F.Invert()]
try: try:
uni_aug = C.UniformAugment(operations=transforms_ua, num_ops=num_ops) _ = C.UniformAugment(operations=transforms_ua, num_ops=num_ops)
except BaseException as e: except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e))) logger.info("Got an exception in DE: {}".format(str(e)))
assert "operations" in str(e) assert "operations" in str(e)
@ -200,9 +200,9 @@ def test_cpp_uniform_augment_exception_large_numops(num_ops=6):
C.RandomRotation(degrees=45)] C.RandomRotation(degrees=45)]
try: try:
uni_aug = C.UniformAugment(operations=transforms_ua, num_ops=num_ops) _ = C.UniformAugment(operations=transforms_ua, num_ops=num_ops)
except BaseException as e: except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e))) logger.info("Got an exception in DE: {}".format(str(e)))
assert "num_ops" in str(e) assert "num_ops" in str(e)
@ -220,9 +220,9 @@ def test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0):
C.RandomRotation(degrees=45)] C.RandomRotation(degrees=45)]
try: try:
uni_aug = C.UniformAugment(operations=transforms_ua, num_ops=num_ops) _ = C.UniformAugment(operations=transforms_ua, num_ops=num_ops)
except BaseException as e: except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e))) logger.info("Got an exception in DE: {}".format(str(e)))
assert "num_ops" in str(e) assert "num_ops" in str(e)
@ -239,9 +239,9 @@ def test_cpp_uniform_augment_exception_float_numops(num_ops=2.5):
C.RandomRotation(degrees=45)] C.RandomRotation(degrees=45)]
try: try:
uni_aug = C.UniformAugment(operations=transforms_ua, num_ops=num_ops) _ = C.UniformAugment(operations=transforms_ua, num_ops=num_ops)
except BaseException as e: except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e))) logger.info("Got an exception in DE: {}".format(str(e)))
assert "integer" in str(e) assert "integer" in str(e)
@ -250,7 +250,7 @@ def test_cpp_uniform_augment_random_crop_badinput(num_ops=1):
Test UniformAugment with greater crop size Test UniformAugment with greater crop size
""" """
logger.info("Test CPP UniformAugment with random_crop bad input") logger.info("Test CPP UniformAugment with random_crop bad input")
batch_size=2 batch_size = 2
cifar10_dir = "../data/dataset/testCifar10Data" cifar10_dir = "../data/dataset/testCifar10Data"
ds1 = de.Cifar10Dataset(cifar10_dir, shuffle=False) # shape = [32,32,3] ds1 = de.Cifar10Dataset(cifar10_dir, shuffle=False) # shape = [32,32,3]
@ -266,9 +266,9 @@ def test_cpp_uniform_augment_random_crop_badinput(num_ops=1):
ds1 = ds1.batch(batch_size, drop_remainder=True, num_parallel_workers=1) ds1 = ds1.batch(batch_size, drop_remainder=True, num_parallel_workers=1)
num_batches = 0 num_batches = 0
try: try:
for data in ds1.create_dict_iterator(): for _ in ds1.create_dict_iterator():
num_batches += 1 num_batches += 1
except BaseException as e: except Exception as e:
assert "Crop size" in str(e) assert "Crop size" in str(e)

View File

@ -75,6 +75,7 @@ def test_variable_size_batch():
return batchInfo.get_epoch_num() + 1 return batchInfo.get_epoch_num() + 1
def simple_copy(colList, batchInfo): def simple_copy(colList, batchInfo):
_ = batchInfo
return ([np.copy(arr) for arr in colList],) return ([np.copy(arr) for arr in colList],)
def test_repeat_batch(gen_num, r, drop, func, res): def test_repeat_batch(gen_num, r, drop, func, res):
@ -186,6 +187,7 @@ def test_batch_multi_col_map():
yield (np.array([i]), np.array([i ** 2])) yield (np.array([i]), np.array([i ** 2]))
def col1_col2_add_num(col1, col2, batchInfo): def col1_col2_add_num(col1, col2, batchInfo):
_ = batchInfo
return ([[np.copy(arr + 100) for arr in col1], return ([[np.copy(arr + 100) for arr in col1],
[np.copy(arr + 300) for arr in col2]]) [np.copy(arr + 300) for arr in col2]])
@ -287,11 +289,11 @@ def test_exception():
def bad_batch_size(batchInfo): def bad_batch_size(batchInfo):
raise StopIteration raise StopIteration
return batchInfo.get_batch_num() #return batchInfo.get_batch_num()
def bad_map_func(col, batchInfo): def bad_map_func(col, batchInfo):
raise StopIteration raise StopIteration
return (col,) #return (col,)
data1 = ds.GeneratorDataset((lambda: gen(100)), ["num"]).batch(bad_batch_size) data1 = ds.GeneratorDataset((lambda: gen(100)), ["num"]).batch(bad_batch_size)
try: try:
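
The per-batch map functions receive batchInfo because the batch() API passes it, but several of them never use it; assigning it to _ documents that and quiets the unused-argument warning, and the unreachable returns after raise StopIteration are commented out. A minimal sketch of that calling convention, with the batch-map signature assumed from the tests above:

    import numpy as np

    def simple_copy(col_list, batch_info):
        _ = batch_info                       # required by the signature, deliberately unused
        return ([np.copy(arr) for arr in col_list],)

    # Driver standing in for Dataset.batch(..., per_batch_map=simple_copy)
    batch = [np.array([1]), np.array([2])]
    (copied,) = simple_copy(batch, batch_info=None)
    print([a.tolist() for a in copied])      # [[1], [2]]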

View File

@ -143,7 +143,7 @@ def test_zip_exception_01():
num_iter += 1 num_iter += 1
logger.info("Number of data in zipped dataz: {}".format(num_iter)) logger.info("Number of data in zipped dataz: {}".format(num_iter))
except BaseException as e: except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e))) logger.info("Got an exception in DE: {}".format(str(e)))
@ -164,7 +164,7 @@ def test_zip_exception_02():
num_iter += 1 num_iter += 1
logger.info("Number of data in zipped dataz: {}".format(num_iter)) logger.info("Number of data in zipped dataz: {}".format(num_iter))
except BaseException as e: except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e))) logger.info("Got an exception in DE: {}".format(str(e)))
@ -185,7 +185,7 @@ def test_zip_exception_03():
num_iter += 1 num_iter += 1
logger.info("Number of data in zipped dataz: {}".format(num_iter)) logger.info("Number of data in zipped dataz: {}".format(num_iter))
except BaseException as e: except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e))) logger.info("Got an exception in DE: {}".format(str(e)))
@ -205,7 +205,7 @@ def test_zip_exception_04():
num_iter += 1 num_iter += 1
logger.info("Number of data in zipped dataz: {}".format(num_iter)) logger.info("Number of data in zipped dataz: {}".format(num_iter))
except BaseException as e: except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e))) logger.info("Got an exception in DE: {}".format(str(e)))
@ -226,7 +226,7 @@ def test_zip_exception_05():
num_iter += 1 num_iter += 1
logger.info("Number of data in zipped dataz: {}".format(num_iter)) logger.info("Number of data in zipped dataz: {}".format(num_iter))
except BaseException as e: except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e))) logger.info("Got an exception in DE: {}".format(str(e)))
@ -246,7 +246,7 @@ def test_zip_exception_06():
num_iter += 1 num_iter += 1
logger.info("Number of data in zipped dataz: {}".format(num_iter)) logger.info("Number of data in zipped dataz: {}".format(num_iter))
except BaseException as e: except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e))) logger.info("Got an exception in DE: {}".format(str(e)))

View File

@ -300,16 +300,16 @@ def test_mindpage_pageno_pagesize_not_int(fixture_cv_file):
info = reader.read_category_info() info = reader.read_category_info()
logger.info("category info: {}".format(info)) logger.info("category info: {}".format(info))
with pytest.raises(ParamValueError) as err: with pytest.raises(ParamValueError):
reader.read_at_page_by_id(0, "0", 1) reader.read_at_page_by_id(0, "0", 1)
with pytest.raises(ParamValueError) as err: with pytest.raises(ParamValueError):
reader.read_at_page_by_id(0, 0, "b") reader.read_at_page_by_id(0, 0, "b")
with pytest.raises(ParamValueError) as err: with pytest.raises(ParamValueError):
reader.read_at_page_by_name("822", "e", 1) reader.read_at_page_by_name("822", "e", 1)
with pytest.raises(ParamValueError) as err: with pytest.raises(ParamValueError):
reader.read_at_page_by_name("822", 0, "qwer") reader.read_at_page_by_name("822", 0, "qwer")
with pytest.raises(MRMFetchDataError, match="Failed to fetch data by category."): with pytest.raises(MRMFetchDataError, match="Failed to fetch data by category."):
@ -330,14 +330,14 @@ def test_mindpage_filename_not_exist(fixture_cv_file):
info = reader.read_category_info() info = reader.read_category_info()
logger.info("category info: {}".format(info)) logger.info("category info: {}".format(info))
with pytest.raises(MRMFetchDataError) as err: with pytest.raises(MRMFetchDataError):
reader.read_at_page_by_id(9999, 0, 1) reader.read_at_page_by_id(9999, 0, 1)
with pytest.raises(MRMFetchDataError) as err: with pytest.raises(MRMFetchDataError):
reader.read_at_page_by_name("abc.jpg", 0, 1) reader.read_at_page_by_name("abc.jpg", 0, 1)
with pytest.raises(ParamValueError) as err: with pytest.raises(ParamValueError):
reader.read_at_page_by_name(1, 0, 1) reader.read_at_page_by_name(1, 0, 1)
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) _ = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
for x in range(FILES_NUM)] for x in range(FILES_NUM)]
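
When the exception object is never inspected, binding it is just another unused variable, so with pytest.raises(ParamValueError) as err: becomes with pytest.raises(ParamValueError):. A tiny sketch; the error class and reader function here are stand-ins for the mindrecord API used above:

    import pytest

    class ParamValueError(ValueError):
        """Stand-in for the mindrecord parameter error type."""

    def read_at_page_by_id(category_id, page, num_row):
        if not all(isinstance(v, int) for v in (category_id, page, num_row)):
            raise ParamValueError("page number and page size should be int")

    with pytest.raises(ParamValueError):          # no 'as err': the message is not checked
        read_at_page_by_id(0, "0", 1)
    print("raised as expected")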

View File

@ -14,10 +14,9 @@
"""test mnist to mindrecord tool""" """test mnist to mindrecord tool"""
import gzip import gzip
import os import os
import numpy as np
import cv2
import pytest import pytest
import numpy as np
import cv2
from mindspore import log as logger from mindspore import log as logger
from mindspore.mindrecord import FileReader from mindspore.mindrecord import FileReader

View File

@ -14,12 +14,12 @@
# ============================================================================ # ============================================================================
"""utils for test""" """utils for test"""
import collections
import json
import numpy as np
import os import os
import re import re
import string import string
import collections
import json
import numpy as np
from mindspore import log as logger from mindspore import log as logger
@ -185,7 +185,7 @@ def get_nlp_data(dir_name, vocab_file, num):
""" """
if not os.path.isdir(dir_name): if not os.path.isdir(dir_name):
raise IOError("Directory {} not exists".format(dir_name)) raise IOError("Directory {} not exists".format(dir_name))
for root, dirs, files in os.walk(dir_name): for root, _, files in os.walk(dir_name):
for index, file_name_extension in enumerate(files): for index, file_name_extension in enumerate(files):
if index < num: if index < num:
file_path = os.path.join(root, file_name_extension) file_path = os.path.join(root, file_name_extension)
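
get_nlp_data only needs the root directory and the file names from os.walk, so the unused dirs slot of each tuple becomes _. A short runnable sketch of the same traversal pattern over a temporary directory:

    import os
    import tempfile

    with tempfile.TemporaryDirectory() as dir_name:
        open(os.path.join(dir_name, "a.txt"), "w").close()
        open(os.path.join(dir_name, "b.txt"), "w").close()

        found = []
        for root, _, files in os.walk(dir_name):     # dirs is not needed here
            for file_name in files:
                found.append(os.path.join(root, file_name))

        print(sorted(os.path.basename(p) for p in found))   # ['a.txt', 'b.txt']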