!1478 [Dataset] clean pylint.

parent c086d91aaf
commit 9b2a778d94
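The hunks below are mechanical pylint cleanups across the dataset test suite: locals and parameters that shadowed outer names (ds, id, map, input) gain a trailing underscore, broad `except BaseException` handlers become `except Exception`, unused variables become `_`, redundant parentheses around assert expressions are dropped, and side-effecting conditional expressions are rewritten as plain `if` statements. As a rough orientation only, here is an illustrative sketch of those patterns; the names are made up for illustration and are not taken verbatim from any single file below:

import os

def dataset_fn(ds_):                 # was: def dataset_fn(ds) -- shadowed `import mindspore.dataset as ds`
    ds_ = ds_.repeat(2)
    return ds_.batch(4)

def remove_if_exists(path):
    if os.path.exists(path):         # was: os.remove(path) if os.path.exists(path) else None
        os.remove(path)

def count_rows(iterable):
    num = 0
    for _ in iterable:               # was: for item in iterable (item was unused)
        num += 1
    assert num >= 0                  # was: assert (num >= 0)
    return num

try:
    count_rows([1, 2, 3])
except Exception as e:               # was: except BaseException as e
    print(str(e))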
@@ -13,8 +13,8 @@
 # limitations under the License.
 # ============================================================================
 """test dataset performance about mindspore.MindDataset, mindspore.TFRecordDataset, tf.data.TFRecordDataset"""
-import tensorflow as tf
 import time
+import tensorflow as tf

 import mindspore.dataset as ds
 from mindspore.mindrecord import FileReader
@@ -32,9 +32,9 @@ def test_apply_generator_case():
     data1 = ds.GeneratorDataset(generator_1d, ["data"])
     data2 = ds.GeneratorDataset(generator_1d, ["data"])

-    def dataset_fn(ds):
-        ds = ds.repeat(2)
-        return ds.batch(4)
+    def dataset_fn(ds_):
+        ds_ = ds_.repeat(2)
+        return ds_.batch(4)

     data1 = data1.apply(dataset_fn)
     data2 = data2.repeat(2)
@@ -52,11 +52,11 @@ def test_apply_imagefolder_case():
     decode_op = vision.Decode()
     normalize_op = vision.Normalize([121.0, 115.0, 100.0], [70.0, 68.0, 71.0])

-    def dataset_fn(ds):
-        ds = ds.map(operations=decode_op)
-        ds = ds.map(operations=normalize_op)
-        ds = ds.repeat(2)
-        return ds
+    def dataset_fn(ds_):
+        ds_ = ds_.map(operations=decode_op)
+        ds_ = ds_.map(operations=normalize_op)
+        ds_ = ds_.repeat(2)
+        return ds_

     data1 = data1.apply(dataset_fn)
     data2 = data2.map(operations=decode_op)
@@ -67,125 +67,125 @@ def test_apply_imagefolder_case():
         assert np.array_equal(item1["image"], item2["image"])


-def test_apply_flow_case_0(id=0):
+def test_apply_flow_case_0(id_=0):
     # apply control flow operations
     data1 = ds.GeneratorDataset(generator_1d, ["data"])

-    def dataset_fn(ds):
-        if id == 0:
-            ds = ds.batch(4)
-        elif id == 1:
-            ds = ds.repeat(2)
-        elif id == 2:
-            ds = ds.batch(4)
-            ds = ds.repeat(2)
+    def dataset_fn(ds_):
+        if id_ == 0:
+            ds_ = ds_.batch(4)
+        elif id_ == 1:
+            ds_ = ds_.repeat(2)
+        elif id_ == 2:
+            ds_ = ds_.batch(4)
+            ds_ = ds_.repeat(2)
         else:
-            ds = ds.shuffle(buffer_size=4)
-        return ds
+            ds_ = ds_.shuffle(buffer_size=4)
+        return ds_

     data1 = data1.apply(dataset_fn)
     num_iter = 0
     for _ in data1.create_dict_iterator():
         num_iter = num_iter + 1

-    if id == 0:
+    if id_ == 0:
         assert num_iter == 16
-    elif id == 1:
+    elif id_ == 1:
         assert num_iter == 128
-    elif id == 2:
+    elif id_ == 2:
         assert num_iter == 32
     else:
         assert num_iter == 64


-def test_apply_flow_case_1(id=1):
+def test_apply_flow_case_1(id_=1):
     # apply control flow operations
     data1 = ds.GeneratorDataset(generator_1d, ["data"])

-    def dataset_fn(ds):
-        if id == 0:
-            ds = ds.batch(4)
-        elif id == 1:
-            ds = ds.repeat(2)
-        elif id == 2:
-            ds = ds.batch(4)
-            ds = ds.repeat(2)
+    def dataset_fn(ds_):
+        if id_ == 0:
+            ds_ = ds_.batch(4)
+        elif id_ == 1:
+            ds_ = ds_.repeat(2)
+        elif id_ == 2:
+            ds_ = ds_.batch(4)
+            ds_ = ds_.repeat(2)
         else:
-            ds = ds.shuffle(buffer_size=4)
-        return ds
+            ds_ = ds_.shuffle(buffer_size=4)
+        return ds_

     data1 = data1.apply(dataset_fn)
     num_iter = 0
     for _ in data1.create_dict_iterator():
         num_iter = num_iter + 1

-    if id == 0:
+    if id_ == 0:
         assert num_iter == 16
-    elif id == 1:
+    elif id_ == 1:
         assert num_iter == 128
-    elif id == 2:
+    elif id_ == 2:
         assert num_iter == 32
     else:
         assert num_iter == 64


-def test_apply_flow_case_2(id=2):
+def test_apply_flow_case_2(id_=2):
     # apply control flow operations
     data1 = ds.GeneratorDataset(generator_1d, ["data"])

-    def dataset_fn(ds):
-        if id == 0:
-            ds = ds.batch(4)
-        elif id == 1:
-            ds = ds.repeat(2)
-        elif id == 2:
-            ds = ds.batch(4)
-            ds = ds.repeat(2)
+    def dataset_fn(ds_):
+        if id_ == 0:
+            ds_ = ds_.batch(4)
+        elif id_ == 1:
+            ds_ = ds_.repeat(2)
+        elif id_ == 2:
+            ds_ = ds_.batch(4)
+            ds_ = ds_.repeat(2)
         else:
-            ds = ds.shuffle(buffer_size=4)
-        return ds
+            ds_ = ds_.shuffle(buffer_size=4)
+        return ds_

     data1 = data1.apply(dataset_fn)
     num_iter = 0
     for _ in data1.create_dict_iterator():
         num_iter = num_iter + 1

-    if id == 0:
+    if id_ == 0:
         assert num_iter == 16
-    elif id == 1:
+    elif id_ == 1:
         assert num_iter == 128
-    elif id == 2:
+    elif id_ == 2:
         assert num_iter == 32
     else:
         assert num_iter == 64


-def test_apply_flow_case_3(id=3):
+def test_apply_flow_case_3(id_=3):
     # apply control flow operations
     data1 = ds.GeneratorDataset(generator_1d, ["data"])

-    def dataset_fn(ds):
-        if id == 0:
-            ds = ds.batch(4)
-        elif id == 1:
-            ds = ds.repeat(2)
-        elif id == 2:
-            ds = ds.batch(4)
-            ds = ds.repeat(2)
+    def dataset_fn(ds_):
+        if id_ == 0:
+            ds_ = ds_.batch(4)
+        elif id_ == 1:
+            ds_ = ds_.repeat(2)
+        elif id_ == 2:
+            ds_ = ds_.batch(4)
+            ds_ = ds_.repeat(2)
         else:
-            ds = ds.shuffle(buffer_size=4)
-        return ds
+            ds_ = ds_.shuffle(buffer_size=4)
+        return ds_

     data1 = data1.apply(dataset_fn)
     num_iter = 0
     for _ in data1.create_dict_iterator():
         num_iter = num_iter + 1

-    if id == 0:
+    if id_ == 0:
         assert num_iter == 16
-    elif id == 1:
+    elif id_ == 1:
         assert num_iter == 128
-    elif id == 2:
+    elif id_ == 2:
         assert num_iter == 32
     else:
         assert num_iter == 64
@@ -195,11 +195,11 @@ def test_apply_exception_case():
     # apply exception operations
     data1 = ds.GeneratorDataset(generator_1d, ["data"])

-    def dataset_fn(ds):
-        ds = ds.repeat(2)
-        return ds.batch(4)
+    def dataset_fn(ds_):
+        ds_ = ds_.repeat(2)
+        return ds_.batch(4)

-    def exception_fn(ds):
+    def exception_fn():
         return np.array([[0], [1], [3], [4], [5]])

     try:
@@ -220,12 +220,12 @@ def test_apply_exception_case():

     try:
         data2 = data1.apply(dataset_fn)
-        data3 = data1.apply(dataset_fn)
+        _ = data1.apply(dataset_fn)
         for _, _ in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
             pass
         assert False
-    except ValueError:
-        pass
+    except ValueError as e:
+        logger.info("Got an exception in DE: {}".format(str(e)))


 if __name__ == '__main__':
@@ -58,7 +58,7 @@ def test_auto_contrast(plot=False):

     ds_original = ds_original.batch(512)

-    for idx, (image, label) in enumerate(ds_original):
+    for idx, (image, _) in enumerate(ds_original):
         if idx == 0:
             images_original = np.transpose(image, (0, 2, 3, 1))
         else:
@@ -79,7 +79,7 @@ def test_auto_contrast(plot=False):

     ds_auto_contrast = ds_auto_contrast.batch(512)

-    for idx, (image, label) in enumerate(ds_auto_contrast):
+    for idx, (image, _) in enumerate(ds_auto_contrast):
         if idx == 0:
             images_auto_contrast = np.transpose(image, (0, 2, 3, 1))
         else:
@@ -273,7 +273,7 @@ def test_batch_exception_01():
         data1 = data1.batch(batch_size=2, drop_remainder=True, num_parallel_workers=0)
         sum([1 for _ in data1])

-    except BaseException as e:
+    except Exception as e:
         logger.info("Got an exception in DE: {}".format(str(e)))
         assert "num_parallel_workers" in str(e)

@@ -290,7 +290,7 @@ def test_batch_exception_02():
         data1 = data1.batch(3, drop_remainder=True, num_parallel_workers=-1)
         sum([1 for _ in data1])

-    except BaseException as e:
+    except Exception as e:
         logger.info("Got an exception in DE: {}".format(str(e)))
         assert "num_parallel_workers" in str(e)

@@ -307,7 +307,7 @@ def test_batch_exception_03():
         data1 = data1.batch(batch_size=0)
         sum([1 for _ in data1])

-    except BaseException as e:
+    except Exception as e:
         logger.info("Got an exception in DE: {}".format(str(e)))
         assert "batch_size" in str(e)

@@ -324,7 +324,7 @@ def test_batch_exception_04():
         data1 = data1.batch(batch_size=-1)
         sum([1 for _ in data1])

-    except BaseException as e:
+    except Exception as e:
         logger.info("Got an exception in DE: {}".format(str(e)))
         assert "batch_size" in str(e)

@@ -341,7 +341,7 @@ def test_batch_exception_05():
         data1 = data1.batch(batch_size=False)
         sum([1 for _ in data1])

-    except BaseException as e:
+    except Exception as e:
         logger.info("Got an exception in DE: {}".format(str(e)))
         assert "batch_size" in str(e)

@@ -358,7 +358,7 @@ def test_batch_exception_07():
         data1 = data1.batch(3, drop_remainder=0)
         sum([1 for _ in data1])

-    except BaseException as e:
+    except Exception as e:
         logger.info("Got an exception in DE: {}".format(str(e)))
         assert "drop_remainder" in str(e)

@@ -375,7 +375,7 @@ def test_batch_exception_08():
         data1 = data1.batch(3, drop_remainder=True, num_parallel_workers=False)
         sum([1 for _ in data1])

-    except BaseException as e:
+    except Exception as e:
         logger.info("Got an exception in DE: {}".format(str(e)))
         assert "num_parallel_workers" in str(e)

@@ -392,7 +392,7 @@ def test_batch_exception_09():
         data1 = data1.batch(drop_remainder=True, num_parallel_workers=4)
         sum([1 for _ in data1])

-    except BaseException as e:
+    except Exception as e:
         logger.info("Got an exception in DE: {}".format(str(e)))
         assert "batch_size" in str(e)

@@ -409,7 +409,7 @@ def test_batch_exception_10():
         data1 = data1.batch(batch_size=4, num_parallel_workers=8192)
         sum([1 for _ in data1])

-    except BaseException as e:
+    except Exception as e:
         logger.info("Got an exception in DE: {}".format(str(e)))
         assert "num_parallel_workers" in str(e)

@@ -429,7 +429,7 @@ def test_batch_exception_11():
         data1 = data1.batch(batch_size, num_parallel_workers)
         sum([1 for _ in data1])

-    except BaseException as e:
+    except Exception as e:
         logger.info("Got an exception in DE: {}".format(str(e)))
         assert "drop_remainder" in str(e)

@@ -450,7 +450,7 @@ def test_batch_exception_12():
         data1 = data1.batch(drop_remainder, batch_size=batch_size)
         sum([1 for _ in data1])

-    except BaseException as e:
+    except Exception as e:
         logger.info("Got an exception in DE: {}".format(str(e)))
         assert "batch_size" in str(e)

@@ -469,7 +469,7 @@ def test_batch_exception_13():
         data1 = data1.batch(batch_size, shard_id=1)
         sum([1 for _ in data1])

-    except BaseException as e:
+    except Exception as e:
         logger.info("Got an exception in DE: {}".format(str(e)))
         assert "shard_id" in str(e)

@@ -24,18 +24,18 @@ from mindspore import log as logger
 # In generator dataset: Number of rows is 3; its values are 0, 1, 2
 def generator():
     for i in range(3):
-        yield np.array([i]),
+        yield (np.array([i]),)


 # In generator_10 dataset: Number of rows is 7; its values are 3, 4, 5 ... 9
 def generator_10():
     for i in range(3, 10):
-        yield np.array([i]),
+        yield (np.array([i]),)

 # In generator_20 dataset: Number of rows is 10; its values are 10, 11, 12 ... 19
 def generator_20():
     for i in range(10, 20):
-        yield np.array([i]),
+        yield (np.array([i]),)


 def test_concat_01():
@@ -85,7 +85,7 @@ def test_concat_03():
     data3 = data1 + data2

     try:
-        for i, d in enumerate(data3):
+        for _, _ in enumerate(data3):
             pass
         assert False
     except RuntimeError:
@@ -104,7 +104,7 @@ def test_concat_04():
     data3 = data1 + data2

     try:
-        for i, d in enumerate(data3):
+        for _, _ in enumerate(data3):
             pass
         assert False
     except RuntimeError:
@@ -125,7 +125,7 @@ def test_concat_05():
     data3 = data1 + data2

     try:
-        for i, d in enumerate(data3):
+        for _, _ in enumerate(data3):
             pass
         assert False
     except RuntimeError:
@@ -31,7 +31,7 @@ SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"

 def test_basic():
     """
-    Test basic configuration functions
+    Test basic configuration functions
     """
     # Save original configuration values
     num_parallel_workers_original = ds.config.get_num_parallel_workers()
@@ -138,7 +138,7 @@ def test_deterministic_run_fail():
         for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
             np.testing.assert_equal(item1["image"], item2["image"])

-    except BaseException as e:
+    except Exception as e:
         # two datasets split the number out of the sequence a
         logger.info("Got an exception in DE: {}".format(str(e)))
         assert "Array" in str(e)
@@ -157,7 +157,7 @@ def test_deterministic_run_pass():
     # Save original configuration values
     num_parallel_workers_original = ds.config.get_num_parallel_workers()
     seed_original = ds.config.get_seed()
-
+
     ds.config.set_seed(0)
     ds.config.set_num_parallel_workers(1)

@@ -179,7 +179,7 @@ def test_deterministic_run_pass():
     try:
         for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
             np.testing.assert_equal(item1["image"], item2["image"])
-    except BaseException as e:
+    except Exception as e:
         # two datasets both use numbers from the generated sequence "a"
         logger.info("Got an exception in DE: {}".format(str(e)))
         assert "Array" in str(e)
@@ -344,7 +344,7 @@ def test_deterministic_python_seed_multi_thread():

     try:
         np.testing.assert_equal(data1_output, data2_output)
-    except BaseException as e:
+    except Exception as e:
         # expect output to not match during multi-threaded excution
         logger.info("Got an exception in DE: {}".format(str(e)))
         assert "Array" in str(e)
@@ -107,14 +107,20 @@ def test_tfrecord_shardings4(print_res=False):
         assert len(result_list) == expect_length
         assert set(result_list) == expect_set

-    check_result(sharding_config(2, 0, None, 1), 20, {11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30})
-    check_result(sharding_config(2, 1, None, 1), 20, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40})
+    check_result(sharding_config(2, 0, None, 1), 20,
+                 {11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30})
+    check_result(sharding_config(2, 1, None, 1), 20,
+                 {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40})
     check_result(sharding_config(2, 0, 3, 1), 3, {11, 12, 21})
     check_result(sharding_config(2, 1, 3, 1), 3, {1, 2, 31})
-    check_result(sharding_config(2, 0, 40, 1), 20, {11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30})
-    check_result(sharding_config(2, 1, 40, 1), 20, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40})
-    check_result(sharding_config(2, 0, 55, 1), 20, {11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30})
-    check_result(sharding_config(2, 1, 55, 1), 20, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40})
+    check_result(sharding_config(2, 0, 40, 1), 20,
+                 {11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30})
+    check_result(sharding_config(2, 1, 40, 1), 20,
+                 {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40})
+    check_result(sharding_config(2, 0, 55, 1), 20,
+                 {11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30})
+    check_result(sharding_config(2, 1, 55, 1), 20,
+                 {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40})
     check_result(sharding_config(3, 0, 8, 1), 8, {32, 33, 34, 11, 12, 13, 14, 31})
     check_result(sharding_config(3, 1, 8, 1), 8, {1, 2, 3, 4, 5, 6, 7, 8})
     check_result(sharding_config(3, 2, 8, 1), 8, {21, 22, 23, 24, 25, 26, 27, 28})
@@ -49,7 +49,7 @@ def test_textline_dataset_totext():
         strs = i["text"].item().decode("utf8")
         assert strs == line[count]
         count += 1
-    assert (count == 5)
+    assert count == 5
     # Restore configuration num_parallel_workers
     ds.config.set_num_parallel_workers(original_num_parallel_workers)

@@ -24,10 +24,10 @@ def test_voc_segmentation():
     data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", mode="train", decode=True, shuffle=False)
     num = 0
     for item in data1.create_dict_iterator():
-        assert (item["image"].shape[0] == IMAGE_SHAPE[num])
-        assert (item["target"].shape[0] == TARGET_SHAPE[num])
+        assert item["image"].shape[0] == IMAGE_SHAPE[num]
+        assert item["target"].shape[0] == TARGET_SHAPE[num]
         num += 1
-    assert (num == 10)
+    assert num == 10


 def test_voc_detection():
@@ -35,12 +35,12 @@ def test_voc_detection():
     num = 0
     count = [0, 0, 0, 0, 0, 0]
     for item in data1.create_dict_iterator():
-        assert (item["image"].shape[0] == IMAGE_SHAPE[num])
+        assert item["image"].shape[0] == IMAGE_SHAPE[num]
         for bbox in item["annotation"]:
             count[bbox[0]] += 1
         num += 1
-    assert (num == 9)
-    assert (count == [3, 2, 1, 2, 4, 3])
+    assert num == 9
+    assert count == [3, 2, 1, 2, 4, 3]


 def test_voc_class_index():
@@ -58,8 +58,8 @@ def test_voc_class_index():
             assert (bbox[0] == 0 or bbox[0] == 1 or bbox[0] == 5)
             count[bbox[0]] += 1
         num += 1
-    assert (num == 6)
-    assert (count == [3, 2, 0, 0, 0, 3])
+    assert num == 6
+    assert count == [3, 2, 0, 0, 0, 3]


 def test_voc_get_class_indexing():
@@ -76,8 +76,8 @@ def test_voc_get_class_indexing():
             assert (bbox[0] == 0 or bbox[0] == 1 or bbox[0] == 2 or bbox[0] == 3 or bbox[0] == 4 or bbox[0] == 5)
             count[bbox[0]] += 1
         num += 1
-    assert (num == 9)
-    assert (count == [3, 2, 1, 2, 4, 3])
+    assert num == 9
+    assert count == [3, 2, 1, 2, 4, 3]


 def test_case_0():
@@ -93,9 +93,9 @@ def test_case_0():
     data1 = data1.batch(batch_size, drop_remainder=True)

     num = 0
-    for item in data1.create_dict_iterator():
+    for _ in data1.create_dict_iterator():
         num += 1
-    assert (num == 20)
+    assert num == 20


 def test_case_1():
@@ -110,9 +110,9 @@ def test_case_1():
     data1 = data1.batch(batch_size, drop_remainder=True, pad_info={})

     num = 0
-    for item in data1.create_dict_iterator():
+    for _ in data1.create_dict_iterator():
         num += 1
-    assert (num == 18)
+    assert num == 18


 def test_voc_exception():
@@ -58,7 +58,7 @@ def test_equalize(plot=False):

     ds_original = ds_original.batch(512)

-    for idx, (image, label) in enumerate(ds_original):
+    for idx, (image, _) in enumerate(ds_original):
         if idx == 0:
             images_original = np.transpose(image, (0, 2, 3, 1))
         else:
@@ -79,7 +79,7 @@ def test_equalize(plot=False):

     ds_equalize = ds_equalize.batch(512)

-    for idx, (image, label) in enumerate(ds_equalize):
+    for idx, (image, _) in enumerate(ds_equalize):
         if idx == 0:
             images_equalize = np.transpose(image, (0, 2, 3, 1))
         else:
@@ -15,9 +15,7 @@

 import numpy as np

-import mindspore.common.dtype as mstype
 import mindspore.dataset as ds
-import mindspore.dataset.transforms.c_transforms as C
 import mindspore.dataset.transforms.vision.c_transforms as cde

 DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
@@ -31,7 +29,6 @@ def test_diff_predicate_func():
         cde.Decode(),
         cde.Resize([64, 64])
     ]
-    type_cast_op = C.TypeCast(mstype.int32)
     dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image", "label"], shuffle=False)
     dataset = dataset.map(input_columns=["image"], operations=transforms, num_parallel_workers=1)
     dataset = dataset.filter(input_columns=["image", "label"], predicate=predicate_func, num_parallel_workers=4)
@@ -40,7 +37,6 @@ def test_diff_predicate_func():
     label_list = []
     for data in dataset.create_dict_iterator():
         num_iter += 1
-        ori_img = data["image"]
         label = data["label"]
         label_list.append(label)
     assert num_iter == 1
@@ -200,6 +196,7 @@ def generator_1d_zip2():


 def filter_func_zip(data1, data2):
+    _ = data2
     if data1 > 20:
         return False
     return True
@@ -249,6 +246,7 @@ def test_filter_by_generator_with_zip_after():


 def filter_func_map(col1, col2):
+    _ = col2
     if col1[0] > 8:
         return True
     return False
@@ -262,6 +260,7 @@ def filter_func_map_part(col1):


 def filter_func_map_all(col1, col2):
+    _, _ = col1, col2
     return True


@@ -334,6 +333,7 @@ def test_filter_by_generator_with_rename():

 # test input_column
 def filter_func_input_column1(col1, col2):
+    _ = col2
     if col1[0] < 8:
         return True
     return False
@@ -346,6 +346,7 @@ def filter_func_input_column2(col1):


 def filter_func_input_column3(col1):
+    _ = col1
     return True


@@ -380,6 +381,7 @@ def generator_mc_p1(maxid=20):


 def filter_func_Partial_0(col1, col2, col3, col4):
+    _, _, _ = col2, col3, col4
     filter_data = [0, 1, 2, 3, 4, 11]
     if col1[0] in filter_data:
         return False
@@ -439,6 +441,7 @@ def test_filter_by_generator_Partial2():


 def filter_func_Partial(col1, col2):
+    _ = col2
     if col1[0] % 3 == 0:
         return True
     return False
@@ -461,6 +464,7 @@ def test_filter_by_generator_Partial():


 def filter_func_cifar(col1, col2):
+    _ = col1
     if col2 % 3 == 0:
         return True
     return False
@@ -490,6 +494,7 @@ def generator_sort2(maxid=20):


 def filter_func_part_sort(col1, col2, col3, col4, col5, col6):
+    _, _, _, _, _, _ = col1, col2, col3, col4, col5, col6
     return True


@@ -58,7 +58,7 @@ def test_invert(plot=False):

     ds_original = ds_original.batch(512)

-    for idx, (image, label) in enumerate(ds_original):
+    for idx, (image, _) in enumerate(ds_original):
         if idx == 0:
             images_original = np.transpose(image, (0, 2, 3, 1))
         else:
@@ -79,7 +79,7 @@ def test_invert(plot=False):

     ds_invert = ds_invert.batch(512)

-    for idx, (image, label) in enumerate(ds_invert):
+    for idx, (image, _) in enumerate(ds_invert):
         if idx == 0:
             images_invert = np.transpose(image, (0, 2, 3, 1))
         else:
@@ -17,11 +17,11 @@ This is the test module for mindrecord
 """
 import collections
 import json
-import numpy as np
 import os
-import pytest
 import re
 import string
+import pytest
+import numpy as np

 import mindspore.dataset as ds
 import mindspore.dataset.transforms.vision.c_transforms as vision
@@ -46,9 +46,10 @@ def add_and_remove_cv_file():
     paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
              for x in range(FILES_NUM)]
     for x in paths:
-        os.remove("{}".format(x)) if os.path.exists("{}".format(x)) else None
-        os.remove("{}.db".format(x)) if os.path.exists(
-            "{}.db".format(x)) else None
+        if os.path.exists("{}".format(x)):
+            os.remove("{}".format(x))
+        if os.path.exists("{}.db".format(x)):
+            os.remove("{}.db".format(x))
     writer = FileWriter(CV_FILE_NAME, FILES_NUM)
     data = get_data(CV_DIR_NAME)
     cv_schema_json = {"id": {"type": "int32"},
@@ -117,7 +118,9 @@ def add_and_remove_nlp_compress_file():
                                             255, 256, -32768, 32767, -32769, 32768, -2147483648,
                                             2147483647], dtype=np.int32), [-1]),
             "array_b": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129, 255,
-                                            256, -32768, 32767, -32769, 32768, -2147483648, 2147483647, -2147483649, 2147483649, -922337036854775808, 9223372036854775807]), [1, -1]),
+                                            256, -32768, 32767, -32769, 32768,
+                                            -2147483648, 2147483647, -2147483649, 2147483649,
+                                            -922337036854775808, 9223372036854775807]), [1, -1]),
             "array_c": str.encode("nlp data"),
             "array_d": np.reshape(np.array([[-10, -127], [10, 127]]), [2, -1])
             })
@@ -151,7 +154,9 @@ def test_nlp_compress_data(add_and_remove_nlp_compress_file):
                                             255, 256, -32768, 32767, -32769, 32768, -2147483648,
                                             2147483647], dtype=np.int32), [-1]),
             "array_b": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129, 255,
-                                            256, -32768, 32767, -32769, 32768, -2147483648, 2147483647, -2147483649, 2147483649, -922337036854775808, 9223372036854775807]), [1, -1]),
+                                            256, -32768, 32767, -32769, 32768,
+                                            -2147483648, 2147483647, -2147483649, 2147483649,
+                                            -922337036854775808, 9223372036854775807]), [1, -1]),
             "array_c": str.encode("nlp data"),
             "array_d": np.reshape(np.array([[-10, -127], [10, 127]]), [2, -1])
             })
@@ -194,9 +199,10 @@ def test_cv_minddataset_writer_tutorial():
     paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
              for x in range(FILES_NUM)]
     for x in paths:
-        os.remove("{}".format(x)) if os.path.exists("{}".format(x)) else None
-        os.remove("{}.db".format(x)) if os.path.exists(
-            "{}.db".format(x)) else None
+        if os.path.exists("{}".format(x)):
+            os.remove("{}".format(x))
+        if os.path.exists("{}.db".format(x)):
+            os.remove("{}.db".format(x))
     writer = FileWriter(CV_FILE_NAME, FILES_NUM)
     data = get_data(CV_DIR_NAME)
     cv_schema_json = {"file_name": {"type": "string"}, "label": {"type": "int32"},
@@ -478,9 +484,10 @@ def test_cv_minddataset_reader_two_dataset_partition(add_and_remove_cv_file):
     paths = ["{}{}".format(CV1_FILE_NAME, str(x).rjust(1, '0'))
              for x in range(FILES_NUM)]
     for x in paths:
-        os.remove("{}".format(x)) if os.path.exists("{}".format(x)) else None
-        os.remove("{}.db".format(x)) if os.path.exists(
-            "{}.db".format(x)) else None
+        if os.path.exists("{}".format(x)):
+            os.remove("{}".format(x))
+        if os.path.exists("{}.db".format(x)):
+            os.remove("{}.db".format(x))
     writer = FileWriter(CV1_FILE_NAME, FILES_NUM)
     data = get_data(CV_DIR_NAME)
     cv_schema_json = {"id": {"type": "int32"},
@@ -779,7 +786,7 @@ def get_nlp_data(dir_name, vocab_file, num):
     """
     if not os.path.isdir(dir_name):
         raise IOError("Directory {} not exists".format(dir_name))
-    for root, dirs, files in os.walk(dir_name):
+    for root, _, files in os.walk(dir_name):
         for index, file_name_extension in enumerate(files):
             if index < num:
                 file_path = os.path.join(root, file_name_extension)
@@ -851,7 +858,7 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset():
     if os.path.exists("{}".format(mindrecord_file_name)):
         os.remove("{}".format(mindrecord_file_name))
     if os.path.exists("{}.db".format(mindrecord_file_name)):
-        os.remove("{}.db".format(x))
+        os.remove("{}.db".format(mindrecord_file_name))
     data = [{"file_name": "001.jpg", "label": 4,
              "image1": bytes("image1 bytes abc", encoding='UTF-8'),
              "image2": bytes("image1 bytes def", encoding='UTF-8'),
@@ -26,8 +26,10 @@ CV1_FILE_NAME = "./imagenet1.mindrecord"

 def create_cv_mindrecord(files_num):
     """tutorial for cv dataset writer."""
-    os.remove(CV_FILE_NAME) if os.path.exists(CV_FILE_NAME) else None
-    os.remove("{}.db".format(CV_FILE_NAME)) if os.path.exists("{}.db".format(CV_FILE_NAME)) else None
+    if os.path.exists(CV_FILE_NAME):
+        os.remove(CV_FILE_NAME)
+    if os.path.exists("{}.db".format(CV_FILE_NAME)):
+        os.remove("{}.db".format(CV_FILE_NAME))
     writer = FileWriter(CV_FILE_NAME, files_num)
     cv_schema_json = {"file_name": {"type": "string"}, "label": {"type": "int32"}, "data": {"type": "bytes"}}
     data = [{"file_name": "001.jpg", "label": 43, "data": bytes('0xffsafdafda', encoding='utf-8')}]
@@ -39,8 +41,10 @@ def create_cv_mindrecord(files_num):

 def create_diff_schema_cv_mindrecord(files_num):
     """tutorial for cv dataset writer."""
-    os.remove(CV1_FILE_NAME) if os.path.exists(CV1_FILE_NAME) else None
-    os.remove("{}.db".format(CV1_FILE_NAME)) if os.path.exists("{}.db".format(CV1_FILE_NAME)) else None
+    if os.path.exists(CV1_FILE_NAME):
+        os.remove(CV1_FILE_NAME)
+    if os.path.exists("{}.db".format(CV1_FILE_NAME)):
+        os.remove("{}.db".format(CV1_FILE_NAME))
     writer = FileWriter(CV1_FILE_NAME, files_num)
     cv_schema_json = {"file_name_1": {"type": "string"}, "label": {"type": "int32"}, "data": {"type": "bytes"}}
     data = [{"file_name_1": "001.jpg", "label": 43, "data": bytes('0xffsafdafda', encoding='utf-8')}]
@@ -52,8 +56,10 @@ def create_diff_schema_cv_mindrecord(files_num):

 def create_diff_page_size_cv_mindrecord(files_num):
     """tutorial for cv dataset writer."""
-    os.remove(CV1_FILE_NAME) if os.path.exists(CV1_FILE_NAME) else None
-    os.remove("{}.db".format(CV1_FILE_NAME)) if os.path.exists("{}.db".format(CV1_FILE_NAME)) else None
+    if os.path.exists(CV1_FILE_NAME):
+        os.remove(CV1_FILE_NAME)
+    if os.path.exists("{}.db".format(CV1_FILE_NAME)):
+        os.remove("{}.db".format(CV1_FILE_NAME))
     writer = FileWriter(CV1_FILE_NAME, files_num)
     writer.set_page_size(1 << 26) # 64MB
     cv_schema_json = {"file_name": {"type": "string"}, "label": {"type": "int32"}, "data": {"type": "bytes"}}
@@ -69,8 +75,8 @@ def test_cv_lack_json():
     create_cv_mindrecord(1)
     columns_list = ["data", "file_name", "label"]
     num_readers = 4
-    with pytest.raises(Exception) as err:
-        data_set = ds.MindDataset(CV_FILE_NAME, "no_exist.json", columns_list, num_readers)
+    with pytest.raises(Exception):
+        ds.MindDataset(CV_FILE_NAME, "no_exist.json", columns_list, num_readers)
     os.remove(CV_FILE_NAME)
     os.remove("{}.db".format(CV_FILE_NAME))

@@ -80,7 +86,7 @@ def test_cv_lack_mindrecord():
     columns_list = ["data", "file_name", "label"]
     num_readers = 4
     with pytest.raises(Exception, match="does not exist or permission denied"):
-        data_set = ds.MindDataset("no_exist.mindrecord", columns_list, num_readers)
+        _ = ds.MindDataset("no_exist.mindrecord", columns_list, num_readers)


 def test_invalid_mindrecord():
@@ -134,7 +140,7 @@ def test_cv_minddataset_pk_sample_exclusive_shuffle():
     data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers,
                               sampler=sampler, shuffle=False)
     num_iter = 0
-    for item in data_set.create_dict_iterator():
+    for _ in data_set.create_dict_iterator():
         num_iter += 1
     os.remove(CV_FILE_NAME)
     os.remove("{}.db".format(CV_FILE_NAME))
@@ -149,7 +155,7 @@ def test_cv_minddataset_reader_different_schema():
     data_set = ds.MindDataset([CV_FILE_NAME, CV1_FILE_NAME], columns_list,
                               num_readers)
     num_iter = 0
-    for item in data_set.create_dict_iterator():
+    for _ in data_set.create_dict_iterator():
         num_iter += 1
     os.remove(CV_FILE_NAME)
     os.remove("{}.db".format(CV_FILE_NAME))
@@ -166,7 +172,7 @@ def test_cv_minddataset_reader_different_page_size():
     data_set = ds.MindDataset([CV_FILE_NAME, CV1_FILE_NAME], columns_list,
                               num_readers)
     num_iter = 0
-    for item in data_set.create_dict_iterator():
+    for _ in data_set.create_dict_iterator():
         num_iter += 1
     os.remove(CV_FILE_NAME)
     os.remove("{}.db".format(CV_FILE_NAME))
@@ -181,7 +187,7 @@ def test_minddataset_invalidate_num_shards():
     with pytest.raises(Exception, match="shard_id is invalid, "):
         data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 0, 1)
         num_iter = 0
-        for item in data_set.create_dict_iterator():
+        for _ in data_set.create_dict_iterator():
             num_iter += 1
     os.remove(CV_FILE_NAME)
     os.remove("{}.db".format(CV_FILE_NAME))
@@ -194,7 +200,7 @@ def test_minddataset_invalidate_shard_id():
     with pytest.raises(Exception, match="shard_id is invalid, "):
         data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 1, -1)
         num_iter = 0
-        for item in data_set.create_dict_iterator():
+        for _ in data_set.create_dict_iterator():
            num_iter += 1
     os.remove(CV_FILE_NAME)
     os.remove("{}.db".format(CV_FILE_NAME))
@@ -207,13 +213,13 @@ def test_minddataset_shard_id_bigger_than_num_shard():
     with pytest.raises(Exception, match="shard_id is invalid, "):
         data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, 2)
         num_iter = 0
-        for item in data_set.create_dict_iterator():
+        for _ in data_set.create_dict_iterator():
             num_iter += 1

     with pytest.raises(Exception, match="shard_id is invalid, "):
         data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, 5)
         num_iter = 0
-        for item in data_set.create_dict_iterator():
+        for _ in data_set.create_dict_iterator():
             num_iter += 1

     os.remove(CV_FILE_NAME)
@@ -50,7 +50,7 @@ def test_cv_minddataset_reader_multi_image_and_ndarray_tutorial():
     assert os.path.exists(CV_FILE_NAME)
     assert os.path.exists(CV_FILE_NAME + ".db")

-    """tutorial for minderdataset."""
+    # tutorial for minderdataset.
     columns_list = ["id", "image_0", "image_2", "image_3", "image_4", "input_mask", "segments"]
     num_readers = 1
     data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers)
@@ -20,7 +20,6 @@ import pytest

 import mindspore.dataset as ds
 from mindspore import log as logger
-from mindspore.dataset.transforms.vision import Inter
 from mindspore.dataset.text import to_str
 from mindspore.mindrecord import FileWriter

@@ -39,7 +39,7 @@ def test_on_tokenized_line():
     res = np.array([[10, 1, 11, 1, 12, 1, 15, 1, 13, 1, 14],
                     [11, 1, 12, 1, 10, 1, 14, 1, 13, 1, 15]], dtype=np.int32)
     for i, d in enumerate(data.create_dict_iterator()):
-        np.testing.assert_array_equal(d["text"], res[i]), i
+        _ = (np.testing.assert_array_equal(d["text"], res[i]), i)


 if __name__ == '__main__':
@@ -199,7 +199,7 @@ def test_jieba_5():

 def gen():
     text = np.array("今天天气太好了我们一起去外面玩吧".encode("UTF8"), dtype='S')
-    yield text,
+    yield (text,)


 def pytoken_op(input_data):
@@ -109,10 +109,9 @@ def test_decode_op():
     data1 = data1.map(input_columns=["image"], operations=decode_op)

     num_iter = 0
-    image = None
     for item in data1.create_dict_iterator():
         logger.info("Looping inside iterator {}".format(num_iter))
-        image = item["image"]
+        _ = item["image"]
         # plt.subplot(131)
         # plt.imshow(image)
         # plt.title("DE image")
@@ -134,10 +133,9 @@ def test_decode_normalize_op():
     data1 = data1.map(input_columns=["image"], operations=[decode_op, normalize_op])

     num_iter = 0
-    image = None
     for item in data1.create_dict_iterator():
         logger.info("Looping inside iterator {}".format(num_iter))
-        image = item["image"]
+        _ = item["image"]
         # plt.subplot(131)
         # plt.imshow(image)
         # plt.title("DE image")
@@ -37,8 +37,7 @@ def test_case_0():

     data1 = data1.batch(2)

-    i = 0
-    for item in data1.create_dict_iterator():  # each data is a dictionary
+    for _ in data1.create_dict_iterator():  # each data is a dictionary
         pass


@@ -72,7 +72,7 @@ def test_pad_op():
 # pylint: disable=unnecessary-lambda
 def test_pad_grayscale():
     """
-    Tests that the pad works for grayscale images
+    Tests that the pad works for grayscale images
     """

     def channel_swap(image):
@@ -92,7 +92,7 @@ def test_pad_grayscale():
     data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
     data1 = data1.map(input_columns=["image"], operations=transform())

-    # if input is grayscale, the output dimensions should be single channel
+    # if input is grayscale, the output dimensions should be single channel
     pad_gray = c_vision.Pad(100, fill_value=(20, 20, 20))
     data1 = data1.map(input_columns=["image"], operations=pad_gray)
     dataset_shape_1 = []
@@ -100,11 +100,11 @@ def test_pad_grayscale():
         c_image = item1["image"]
         dataset_shape_1.append(c_image.shape)

-    # Dataset for comparison
+    # Dataset for comparison
     data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
     decode_op = c_vision.Decode()

-    # we use the same padding logic
+    # we use the same padding logic
     ctrans = [decode_op, pad_gray]
     dataset_shape_2 = []

@@ -119,7 +119,7 @@ def batch_padding_performance_3d():
     num_batches = 0
     for _ in data1.create_dict_iterator():
         num_batches += 1
-    res = "total number of batch:" + str(num_batches) + " time elapsed:" + str(time.time() - start_time)
+    _ = "total number of batch:" + str(num_batches) + " time elapsed:" + str(time.time() - start_time)
     # print(res)


@@ -135,7 +135,7 @@ def batch_padding_performance_1d():
     num_batches = 0
     for _ in data1.create_dict_iterator():
         num_batches += 1
-    res = "total number of batch:" + str(num_batches) + " time elapsed:" + str(time.time() - start_time)
+    _ = "total number of batch:" + str(num_batches) + " time elapsed:" + str(time.time() - start_time)
     # print(res)


@@ -151,7 +151,7 @@ def batch_pyfunc_padding_3d():
     num_batches = 0
     for _ in data1.create_dict_iterator():
         num_batches += 1
-    res = "total number of batch:" + str(num_batches) + " time elapsed:" + str(time.time() - start_time)
+    _ = "total number of batch:" + str(num_batches) + " time elapsed:" + str(time.time() - start_time)
     # print(res)


@@ -166,7 +166,7 @@ def batch_pyfunc_padding_1d():
     num_batches = 0
     for _ in data1.create_dict_iterator():
         num_batches += 1
-    res = "total number of batch:" + str(num_batches) + " time elapsed:" + str(time.time() - start_time)
+    _ = "total number of batch:" + str(num_batches) + " time elapsed:" + str(time.time() - start_time)
     # print(res)


@@ -58,7 +58,7 @@ def test_random_color(degrees=(0.1, 1.9), plot=False):

     ds_original = ds_original.batch(512)

-    for idx, (image, label) in enumerate(ds_original):
+    for idx, (image, _) in enumerate(ds_original):
         if idx == 0:
             images_original = np.transpose(image, (0, 2, 3, 1))
         else:
@@ -79,7 +79,7 @@ def test_random_color(degrees=(0.1, 1.9), plot=False):

     ds_random_color = ds_random_color.batch(512)

-    for idx, (image, label) in enumerate(ds_random_color):
+    for idx, (image, _) in enumerate(ds_random_color):
         if idx == 0:
             images_random_color = np.transpose(image, (0, 2, 3, 1))
         else:
@@ -256,7 +256,7 @@ def test_random_color_adjust_op_hue(plot=False):
 # pylint: disable=unnecessary-lambda
 def test_random_color_adjust_grayscale():
     """
-    Tests that the random color adjust works for grayscale images
+    Tests that the random color adjust works for grayscale images
     """

     def channel_swap(image):
@@ -284,7 +284,7 @@ def test_random_color_adjust_grayscale():
         for item1 in data1.create_dict_iterator():
             c_image = item1["image"]
             dataset_shape_1.append(c_image.shape)
-    except BaseException as e:
+    except Exception as e:
         logger.info("Got an exception in DE: {}".format(str(e)))


@@ -200,7 +200,7 @@ def test_random_crop_04_c():
         for item in data.create_dict_iterator():
             image = item["image"]
             image_list.append(image.shape)
-    except BaseException as e:
+    except Exception as e:
         logger.info("Got an exception in DE: {}".format(str(e)))

 def test_random_crop_04_py():
@@ -227,7 +227,7 @@ def test_random_crop_04_py():
         for item in data.create_dict_iterator():
             image = (item["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
             image_list.append(image.shape)
-    except BaseException as e:
+    except Exception as e:
         logger.info("Got an exception in DE: {}".format(str(e)))

 def test_random_crop_05_c():
@@ -439,7 +439,7 @@ def test_random_crop_09():
         for item in data.create_dict_iterator():
             image = item["image"]
             image_list.append(image.shape)
-    except BaseException as e:
+    except Exception as e:
         logger.info("Got an exception in DE: {}".format(str(e)))
         assert "should be PIL Image" in str(e)

@@ -60,7 +60,7 @@ def test_random_resize_op():

     num_iter = 0
     for item in data1.create_dict_iterator():
-        image_de_resized = item["image"]
+        _ = item["image"]
         # Uncomment below line if you want to visualize images
         # visualize(image_de_resized, image_np_resized, mse)
         num_iter += 1
@@ -58,7 +58,7 @@ def test_random_sharpness(degrees=(0.1, 1.9), plot=False):

     ds_original = ds_original.batch(512)

-    for idx, (image, label) in enumerate(ds_original):
+    for idx, (image, _) in enumerate(ds_original):
         if idx == 0:
             images_original = np.transpose(image, (0, 2, 3, 1))
         else:
@@ -79,7 +79,7 @@ def test_random_sharpness(degrees=(0.1, 1.9), plot=False):

     ds_random_sharpness = ds_random_sharpness.batch(512)

-    for idx, (image, label) in enumerate(ds_random_sharpness):
+    for idx, (image, _) in enumerate(ds_random_sharpness):
         if idx == 0:
             images_random_sharpness = np.transpose(image, (0, 2, 3, 1))
         else:
@@ -25,7 +25,7 @@ from mindspore import log as logger

 def test_sequential_sampler(print_res=False):
     manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
-    map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
+    map_ = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}

     def test_config(num_samples, num_repeats=None):
         sampler = ds.SequentialSampler()
@@ -36,7 +36,7 @@ def test_sequential_sampler(print_res=False):
         for item in data1.create_dict_iterator():
             logger.info("item[image].shape[0]: {}, item[label].item(): {}"
                         .format(item["image"].shape[0], item["label"].item()))
-            res.append(map[(item["image"].shape[0], item["label"].item())])
+            res.append(map_[(item["image"].shape[0], item["label"].item())])
         if print_res:
             logger.info("image.shapes and labels: {}".format(res))
         return res
@@ -48,7 +48,7 @@ def test_sequential_sampler(print_res=False):

 def test_random_sampler(print_res=False):
     manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
-    map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
+    map_ = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}

     def test_config(replacement, num_samples, num_repeats):
         sampler = ds.RandomSampler(replacement=replacement, num_samples=num_samples)
@@ -56,7 +56,7 @@ def test_random_sampler(print_res=False):
         data1 = data1.repeat(num_repeats)
         res = []
         for item in data1.create_dict_iterator():
-            res.append(map[(item["image"].shape[0], item["label"].item())])
+            res.append(map_[(item["image"].shape[0], item["label"].item())])
         if print_res:
             logger.info("image.shapes and labels: {}".format(res))
         return res
@@ -71,7 +71,7 @@ def test_random_sampler(print_res=False):

 def test_random_sampler_multi_iter(print_res=False):
     manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
-    map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
+    map_ = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}

     def test_config(replacement, num_samples, num_repeats, validate):
         sampler = ds.RandomSampler(replacement=replacement, num_samples=num_samples)
@@ -79,7 +79,7 @@ def test_random_sampler_multi_iter(print_res=False):
         while num_repeats > 0:
             res = []
             for item in data1.create_dict_iterator():
-                res.append(map[(item["image"].shape[0], item["label"].item())])
+                res.append(map_[(item["image"].shape[0], item["label"].item())])
             if print_res:
                 logger.info("image.shapes and labels: {}".format(res))
             if validate != sorted(res):
@@ -112,7 +112,7 @@ def test_sampler_py_api():

 def test_python_sampler():
     manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
-    map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
+    map_ = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}

     class Sp1(ds.Sampler):
         def __iter__(self):
@@ -138,7 +138,7 @@ def test_python_sampler():
         for item in data1.create_dict_iterator():
             logger.info("item[image].shape[0]: {}, item[label].item(): {}"
                         .format(item["image"].shape[0], item["label"].item()))
-            res.append(map[(item["image"].shape[0], item["label"].item())])
+            res.append(map_[(item["image"].shape[0], item["label"].item())])
         # print(res)
         return res

@@ -167,7 +167,7 @@ def test_python_sampler():

 def test_subset_sampler():
     manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
-    map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
+    map_ = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}

     def test_config(num_samples, start_index, subset_size):
         sampler = ds.SubsetSampler(start_index, subset_size)
@@ -175,7 +175,7 @@ def test_subset_sampler():

         res = []
         for item in d.create_dict_iterator():
-            res.append(map[(item["image"].shape[0], item["label"].item())])
+            res.append(map_[(item["image"].shape[0], item["label"].item())])

         return res

@@ -196,7 +196,7 @@ def test_subset_sampler():

 def test_sampler_chain():
     manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
-    map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
+    map_ = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}

     def test_config(num_shards, shard_id):
         sampler = ds.DistributedSampler(num_shards, shard_id, False)
@@ -209,7 +209,7 @@ def test_sampler_chain():
         for item in data1.create_dict_iterator():
             logger.info("item[image].shape[0]: {}, item[label].item(): {}"
                         .format(item["image"].shape[0], item["label"].item()))
-            res.append(map[(item["image"].shape[0], item["label"].item())])
+            res.append(map_[(item["image"].shape[0], item["label"].item())])
         return res

     assert test_config(2, 0) == [0, 2, 4]
@@ -222,7 +222,7 @@ def test_sampler_chain():

 def test_add_sampler_invalid_input():
     manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
-    map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
+    _ = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
     data1 = ds.ManifestDataset(manifest_file)

     with pytest.raises(TypeError) as info:
@@ -18,9 +18,8 @@ Testing dataset serialize and deserialize in DE
 import filecmp
 import glob
 import json
-import numpy as np
 import os
 import pytest
+import numpy as np
-
 import mindspore.dataset as ds
 import mindspore.dataset.transforms.c_transforms as c
@@ -28,6 +27,8 @@ import mindspore.dataset.transforms.vision.c_transforms as vision
 from mindspore import log as logger
 from mindspore.dataset.transforms.vision import Inter

+from test_minddataset_sampler import add_and_remove_cv_file, get_data, CV_DIR_NAME, CV_FILE_NAME
+

 def test_imagefolder(remove_json_files=True):
     """
@@ -186,7 +187,7 @@ def test_random_crop():
     # Serializing into python dictionary
     ds1_dict = ds.serialize(data1)
     # Serializing into json object
-    ds1_json = json.dumps(ds1_dict, indent=2)
+    _ = json.dumps(ds1_dict, indent=2)

     # Reconstruct dataset pipeline from its serialized form
     data1_1 = ds.deserialize(input_dict=ds1_dict)
@@ -198,7 +199,7 @@ def test_random_crop():
     for item1, item1_1, item2 in zip(data1.create_dict_iterator(), data1_1.create_dict_iterator(),
                                      data2.create_dict_iterator()):
         assert np.array_equal(item1['image'], item1_1['image'])
-        image2 = item2["image"]
+        _ = item2["image"]


 def validate_jsonfile(filepath):
@@ -221,10 +222,6 @@ def delete_json_files():


-# Test save load minddataset
-from test_minddataset_sampler import add_and_remove_cv_file, get_data, CV_DIR_NAME, CV_FILE_NAME, FILES_NUM, \
-    FileWriter, Inter
-

 def test_minddataset(add_and_remove_cv_file):
     """tutorial for cv minderdataset."""
     columns_list = ["data", "file_name", "label"]
@@ -247,7 +244,7 @@ def test_minddataset(add_and_remove_cv_file):

     assert ds1_json == ds2_json

-    data = get_data(CV_DIR_NAME)
+    _ = get_data(CV_DIR_NAME)
     assert data_set.get_dataset_size() == 5
     num_iter = 0
     for _ in data_set.create_dict_iterator():
@@ -152,7 +152,7 @@ def test_shuffle_exception_01():
         data1 = data1.shuffle(buffer_size=-1)
        sum([1 for _ in data1])

-    except BaseException as e:
+    except Exception as e:
         logger.info("Got an exception in DE: {}".format(str(e)))
         assert "buffer_size" in str(e)

@@ -170,7 +170,7 @@ def test_shuffle_exception_02():
         data1 = data1.shuffle(buffer_size=0)
         sum([1 for _ in data1])

-    except BaseException as e:
+    except Exception as e:
         logger.info("Got an exception in DE: {}".format(str(e)))
         assert "buffer_size" in str(e)

@@ -188,7 +188,7 @@ def test_shuffle_exception_03():
         data1 = data1.shuffle(buffer_size=1)
         sum([1 for _ in data1])

-    except BaseException as e:
+    except Exception as e:
         logger.info("Got an exception in DE: {}".format(str(e)))
         assert "buffer_size" in str(e)

@@ -206,7 +206,7 @@ def test_shuffle_exception_05():
         data1 = data1.shuffle()
         sum([1 for _ in data1])

-    except BaseException as e:
+    except Exception as e:
         logger.info("Got an exception in DE: {}".format(str(e)))
         assert "buffer_size" in str(e)

@@ -224,7 +224,7 @@ def test_shuffle_exception_06():
         data1 = data1.shuffle(buffer_size=False)
         sum([1 for _ in data1])

-    except BaseException as e:
+    except Exception as e:
         logger.info("Got an exception in DE: {}".format(str(e)))
         assert "buffer_size" in str(e)

@@ -242,7 +242,7 @@ def test_shuffle_exception_07():
         data1 = data1.shuffle(buffer_size=True)
         sum([1 for _ in data1])

-    except BaseException as e:
+    except Exception as e:
         logger.info("Got an exception in DE: {}".format(str(e)))
         assert "buffer_size" in str(e)

@@ -70,7 +70,6 @@ def test_skip_1():
     buf = []
     for data in ds1:
         buf.append(data[0][0])
-    assert len(buf) == 0
     assert buf == []


@@ -29,47 +29,47 @@ text_file_data = ["This is a text file.", "Another file.", "Be happy every day."

 def split_with_invalid_inputs(d):
     with pytest.raises(ValueError) as info:
-        s1, s2 = d.split([])
+        _, _ = d.split([])
     assert "sizes cannot be empty" in str(info.value)

     with pytest.raises(ValueError) as info:
-        s1, s2 = d.split([5, 0.6])
+        _, _ = d.split([5, 0.6])
     assert "sizes should be list of int or list of float" in str(info.value)

     with pytest.raises(ValueError) as info:
-        s1, s2 = d.split([-1, 6])
+        _, _ = d.split([-1, 6])
     assert "there should be no negative numbers" in str(info.value)

     with pytest.raises(RuntimeError) as info:
-        s1, s2 = d.split([3, 1])
+        _, _ = d.split([3, 1])
     assert "sum of split sizes 4 is not equal to dataset size 5" in str(info.value)

     with pytest.raises(RuntimeError) as info:
-        s1, s2 = d.split([5, 1])
+        _, _ = d.split([5, 1])
     assert "sum of split sizes 6 is not equal to dataset size 5" in str(info.value)

     with pytest.raises(RuntimeError) as info:
-        s1, s2 = d.split([0.15, 0.15, 0.15, 0.15, 0.15, 0.25])
+        _, _ = d.split([0.15, 0.15, 0.15, 0.15, 0.15, 0.25])
     assert "sum of calculated split sizes 6 is not equal to dataset size 5" in str(info.value)

     with pytest.raises(ValueError) as info:
-        s1, s2 = d.split([-0.5, 0.5])
+        _, _ = d.split([-0.5, 0.5])
     assert "there should be no numbers outside the range [0, 1]" in str(info.value)

     with pytest.raises(ValueError) as info:
-        s1, s2 = d.split([1.5, 0.5])
+        _, _ = d.split([1.5, 0.5])
     assert "there should be no numbers outside the range [0, 1]" in str(info.value)

     with pytest.raises(ValueError) as info:
-        s1, s2 = d.split([0.5, 0.6])
+        _, _ = d.split([0.5, 0.6])
     assert "percentages do not sum up to 1" in str(info.value)

     with pytest.raises(ValueError) as info:
-        s1, s2 = d.split([0.3, 0.6])
+        _, _ = d.split([0.3, 0.6])
     assert "percentages do not sum up to 1" in str(info.value)

     with pytest.raises(RuntimeError) as info:
-        s1, s2 = d.split([0.05, 0.95])
+        _, _ = d.split([0.05, 0.95])
     assert "percentage 0.05 is too small" in str(info.value)


@ -79,7 +79,7 @@ def test_unmappable_invalid_input():
|
|||
|
||||
d = ds.TextFileDataset(text_file_dataset_path, num_shards=2, shard_id=0)
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
s1, s2 = d.split([4, 1])
|
||||
_, _ = d.split([4, 1])
|
||||
assert "dataset should not be sharded before split" in str(info.value)
|
||||
|
||||
|
||||
|
@ -273,7 +273,7 @@ def test_mappable_invalid_input():
|
|||
|
||||
d = ds.ManifestDataset(manifest_file, num_shards=2, shard_id=0)
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
s1, s2 = d.split([4, 1])
|
||||
_, _ = d.split([4, 1])
|
||||
assert "dataset should not be sharded before split" in str(info.value)
|
||||
|
||||
|
||||
|
|
|
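
The split hunks replace `s1, s2 = ...` with `_, _ = ...` wherever the returned datasets are never read, which silences pylint's unused-variable warning without changing what the test asserts. A self-contained sketch of the idiom (the five-element generator is hypothetical; only the error-message check comes from the hunks above):

    import numpy as np
    import pytest
    import mindspore.dataset as ds

    def gen_five():
        for i in range(5):
            yield (np.array([i]),)

    def check_split_rejects_empty_sizes():
        d = ds.GeneratorDataset(gen_five, column_names=["col"])
        with pytest.raises(ValueError) as info:
            _, _ = d.split([])  # results are never used, so bind them to _
        assert "sizes cannot be empty" in str(info.value)
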
@@ -28,8 +28,8 @@ class Augment:
def __init__(self, loss):
self.loss = loss

def preprocess(self, input):
return input
def preprocess(self, input_):
return input_

def update(self, data):
self.loss = data["loss"]

@@ -143,7 +143,7 @@ def test_multiple_iterators():
dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
dataset = dataset.batch(batch_size, drop_remainder=True)
# 2nd dataset
# 2nd dataset
dataset2 = ds.GeneratorDataset(gen, column_names=["input"])

aug = Augment(0)

@@ -175,7 +175,7 @@ def test_sync_exception_01():

try:
dataset = dataset.shuffle(shuffle_size)
except BaseException as e:
except Exception as e:
assert "shuffle" in str(e)
dataset = dataset.batch(batch_size)

@@ -197,7 +197,7 @@ def test_sync_exception_02():

try:
dataset = dataset.sync_wait(num_batch=2, condition_name="every batch")
except BaseException as e:
except Exception as e:
assert "name" in str(e)
dataset = dataset.batch(batch_size)
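
In the sync_wait hunk the `preprocess` argument is renamed from `input` to `input_` because `input` shadows the Python builtin (pylint's redefined-builtin check). After the change the callback class reads as:

    class Augment:
        def __init__(self, loss):
            self.loss = loss

        def preprocess(self, input_):  # trailing underscore avoids shadowing input()
            return input_

        def update(self, data):
            self.loss = data["loss"]
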
@@ -46,7 +46,7 @@ def test_take_01():
data1 = data1.take(1)
data1 = data1.repeat(2)

# Here i refers to index, d refers to data element
# Here i refers to index, d refers to data element
for _, d in enumerate(data1):
assert d[0][0] == 0

@@ -63,7 +63,7 @@ def test_take_02():
data1 = data1.take(2)
data1 = data1.repeat(2)

# Here i refers to index, d refers to data element
# Here i refers to index, d refers to data element
for i, d in enumerate(data1):
assert i % 2 == d[0][0]

@@ -80,7 +80,7 @@ def test_take_03():
data1 = data1.take(3)
data1 = data1.repeat(2)

# Here i refers to index, d refers to data element
# Here i refers to index, d refers to data elements
for i, d in enumerate(data1):
assert i % 3 == d[0][0]
@@ -12,15 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore._c_dataengine as cde
import numpy as np
import pytest

from mindspore.dataset.text import to_str, to_bytes
import numpy as np

import mindspore.dataset as ds
import mindspore._c_dataengine as cde
import mindspore.common.dtype as mstype

from mindspore.dataset.text import to_str

# pylint: disable=comparison-with-itself
def test_basic():

@@ -34,7 +32,7 @@ def compare(strings):
arr = np.array(strings, dtype='S')

def gen():
yield arr,
(yield arr,)

data = ds.GeneratorDataset(gen, column_names=["col"])

@@ -50,7 +48,7 @@ def test_generator():

def test_batching_strings():
def gen():
yield np.array(["ab", "cde", "121"], dtype='S'),
yield (np.array(["ab", "cde", "121"], dtype='S'),)

data = ds.GeneratorDataset(gen, column_names=["col"]).batch(10)

@@ -62,7 +60,7 @@ def test_batching_strings():

def test_map():
def gen():
yield np.array(["ab cde 121"], dtype='S'),
yield (np.array(["ab cde 121"], dtype='S'),)

data = ds.GeneratorDataset(gen, column_names=["col"])

@@ -79,7 +77,7 @@ def test_map():

def test_map2():
def gen():
yield np.array(["ab cde 121"], dtype='S'),
yield (np.array(["ab cde 121"], dtype='S'),)

data = ds.GeneratorDataset(gen, column_names=["col"])
@@ -215,7 +215,7 @@ def test_case_tf_file_no_schema_columns_list():
assert row["col_sint16"] == [-32768]

with pytest.raises(KeyError) as info:
a = row["col_sint32"]
_ = row["col_sint32"]
assert "col_sint32" in str(info.value)

@@ -234,7 +234,7 @@ def test_tf_record_schema_columns_list():

assert row["col_sint16"] == [-32768]
with pytest.raises(KeyError) as info:
a = row["col_sint32"]
_ = row["col_sint32"]
assert "col_sint32" in str(info.value)

@@ -246,7 +246,7 @@ def test_case_invalid_files():
data = ds.TFRecordDataset(files, SCHEMA_FILE, shuffle=ds.Shuffle.FILES)

with pytest.raises(RuntimeError) as info:
row = data.create_dict_iterator().get_next()
_ = data.create_dict_iterator().get_next()
assert "cannot be opened" in str(info.value)
assert "not valid tfrecord files" in str(info.value)
assert valid_file not in str(info.value)
@@ -123,7 +123,7 @@ def test_to_type_03():
]
transform = py_vision.ComposeOp(transforms)
data = data.map(input_columns=["image"], operations=transform())
except BaseException as e:
except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "Numpy" in str(e)

@@ -145,7 +145,7 @@ def test_to_type_04():
]
transform = py_vision.ComposeOp(transforms)
data = data.map(input_columns=["image"], operations=transform())
except BaseException as e:
except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "missing" in str(e)

@@ -167,7 +167,7 @@ def test_to_type_05():
]
transform = py_vision.ComposeOp(transforms)
data = data.map(input_columns=["image"], operations=transform())
except BaseException as e:
except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "data type" in str(e)
@@ -59,7 +59,7 @@ def test_uniform_augment(plot=False, num_ops=2):

ds_original = ds_original.batch(512)

for idx, (image, label) in enumerate(ds_original):
for idx, (image, _) in enumerate(ds_original):
if idx == 0:
images_original = np.transpose(image, (0, 2, 3, 1))
else:

@@ -87,7 +87,7 @@ def test_uniform_augment(plot=False, num_ops=2):

ds_ua = ds_ua.batch(512)

for idx, (image, label) in enumerate(ds_ua):
for idx, (image, _) in enumerate(ds_ua):
if idx == 0:
images_ua = np.transpose(image, (0, 2, 3, 1))
else:

@@ -122,7 +122,7 @@ def test_cpp_uniform_augment(plot=False, num_ops=2):

ds_original = ds_original.batch(512)

for idx, (image, label) in enumerate(ds_original):
for idx, (image, _) in enumerate(ds_original):
if idx == 0:
images_original = np.transpose(image, (0, 2, 3, 1))
else:

@@ -149,7 +149,7 @@ def test_cpp_uniform_augment(plot=False, num_ops=2):

ds_ua = ds_ua.batch(512)

for idx, (image, label) in enumerate(ds_ua):
for idx, (image, _) in enumerate(ds_ua):
if idx == 0:
images_ua = np.transpose(image, (0, 2, 3, 1))
else:

@@ -180,9 +180,9 @@ def test_cpp_uniform_augment_exception_pyops(num_ops=2):
F.Invert()]

try:
uni_aug = C.UniformAugment(operations=transforms_ua, num_ops=num_ops)
_ = C.UniformAugment(operations=transforms_ua, num_ops=num_ops)

except BaseException as e:
except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "operations" in str(e)

@@ -200,9 +200,9 @@ def test_cpp_uniform_augment_exception_large_numops(num_ops=6):
C.RandomRotation(degrees=45)]

try:
uni_aug = C.UniformAugment(operations=transforms_ua, num_ops=num_ops)
_ = C.UniformAugment(operations=transforms_ua, num_ops=num_ops)

except BaseException as e:
except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "num_ops" in str(e)

@@ -220,9 +220,9 @@ def test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0):
C.RandomRotation(degrees=45)]

try:
uni_aug = C.UniformAugment(operations=transforms_ua, num_ops=num_ops)
_ = C.UniformAugment(operations=transforms_ua, num_ops=num_ops)

except BaseException as e:
except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "num_ops" in str(e)

@@ -239,9 +239,9 @@ def test_cpp_uniform_augment_exception_float_numops(num_ops=2.5):
C.RandomRotation(degrees=45)]

try:
uni_aug = C.UniformAugment(operations=transforms_ua, num_ops=num_ops)
_ = C.UniformAugment(operations=transforms_ua, num_ops=num_ops)

except BaseException as e:
except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "integer" in str(e)

@@ -250,7 +250,7 @@ def test_cpp_uniform_augment_random_crop_badinput(num_ops=1):
Test UniformAugment with greater crop size
"""
logger.info("Test CPP UniformAugment with random_crop bad input")
batch_size=2
batch_size = 2
cifar10_dir = "../data/dataset/testCifar10Data"
ds1 = de.Cifar10Dataset(cifar10_dir, shuffle=False) # shape = [32,32,3]

@@ -266,9 +266,9 @@ def test_cpp_uniform_augment_random_crop_badinput(num_ops=1):
ds1 = ds1.batch(batch_size, drop_remainder=True, num_parallel_workers=1)
num_batches = 0
try:
for data in ds1.create_dict_iterator():
for _ in ds1.create_dict_iterator():
num_batches += 1
except BaseException as e:
except Exception as e:
assert "Crop size" in str(e)
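
The uniform-augment hunks drop bindings that are never read: the unused `label` in the batch loops becomes `_`, and objects constructed only to trigger validation (the `UniformAugment` instances) are bound to `_` as well. A short sketch of the loop pattern, assuming the `else` branch appends to the accumulated batch as the surrounding test does:

    import numpy as np

    def collect_images(dataset):
        # Iterate (image, label) pairs but discard the label with _,
        # so pylint does not flag an unused variable.
        images = None
        for idx, (image, _) in enumerate(dataset):
            if idx == 0:
                images = np.transpose(image, (0, 2, 3, 1))
            else:
                images = np.append(images, np.transpose(image, (0, 2, 3, 1)), axis=0)
        return images
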
@@ -75,6 +75,7 @@ def test_variable_size_batch():
return batchInfo.get_epoch_num() + 1

def simple_copy(colList, batchInfo):
_ = batchInfo
return ([np.copy(arr) for arr in colList],)

def test_repeat_batch(gen_num, r, drop, func, res):

@@ -186,6 +187,7 @@ def test_batch_multi_col_map():
yield (np.array([i]), np.array([i ** 2]))

def col1_col2_add_num(col1, col2, batchInfo):
_ = batchInfo
return ([[np.copy(arr + 100) for arr in col1],
[np.copy(arr + 300) for arr in col2]])

@@ -287,11 +289,11 @@ def test_exception():

def bad_batch_size(batchInfo):
raise StopIteration
return batchInfo.get_batch_num()
#return batchInfo.get_batch_num()

def bad_map_func(col, batchInfo):
raise StopIteration
return (col,)
#return (col,)

data1 = ds.GeneratorDataset((lambda: gen(100)), ["num"]).batch(bad_batch_size)
try:
@@ -143,7 +143,7 @@ def test_zip_exception_01():
num_iter += 1
logger.info("Number of data in zipped dataz: {}".format(num_iter))

except BaseException as e:
except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e)))

@@ -164,7 +164,7 @@ def test_zip_exception_02():
num_iter += 1
logger.info("Number of data in zipped dataz: {}".format(num_iter))

except BaseException as e:
except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e)))

@@ -185,7 +185,7 @@ def test_zip_exception_03():
num_iter += 1
logger.info("Number of data in zipped dataz: {}".format(num_iter))

except BaseException as e:
except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e)))

@@ -205,7 +205,7 @@ def test_zip_exception_04():
num_iter += 1
logger.info("Number of data in zipped dataz: {}".format(num_iter))

except BaseException as e:
except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e)))

@@ -226,7 +226,7 @@ def test_zip_exception_05():
num_iter += 1
logger.info("Number of data in zipped dataz: {}".format(num_iter))

except BaseException as e:
except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e)))

@@ -246,7 +246,7 @@ def test_zip_exception_06():
num_iter += 1
logger.info("Number of data in zipped dataz: {}".format(num_iter))

except BaseException as e:
except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e)))
@@ -300,16 +300,16 @@ def test_mindpage_pageno_pagesize_not_int(fixture_cv_file):
info = reader.read_category_info()
logger.info("category info: {}".format(info))

with pytest.raises(ParamValueError) as err:
with pytest.raises(ParamValueError):
reader.read_at_page_by_id(0, "0", 1)

with pytest.raises(ParamValueError) as err:
with pytest.raises(ParamValueError):
reader.read_at_page_by_id(0, 0, "b")

with pytest.raises(ParamValueError) as err:
with pytest.raises(ParamValueError):
reader.read_at_page_by_name("822", "e", 1)

with pytest.raises(ParamValueError) as err:
with pytest.raises(ParamValueError):
reader.read_at_page_by_name("822", 0, "qwer")

with pytest.raises(MRMFetchDataError, match="Failed to fetch data by category."):

@@ -330,14 +330,14 @@ def test_mindpage_filename_not_exist(fixture_cv_file):
info = reader.read_category_info()
logger.info("category info: {}".format(info))

with pytest.raises(MRMFetchDataError) as err:
with pytest.raises(MRMFetchDataError):
reader.read_at_page_by_id(9999, 0, 1)

with pytest.raises(MRMFetchDataError) as err:
with pytest.raises(MRMFetchDataError):
reader.read_at_page_by_name("abc.jpg", 0, 1)

with pytest.raises(ParamValueError) as err:
with pytest.raises(ParamValueError):
reader.read_at_page_by_name(1, 0, 1)

paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
for x in range(FILES_NUM)]
_ = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
for x in range(FILES_NUM)]
@@ -14,10 +14,9 @@
"""test mnist to mindrecord tool"""
import gzip
import os
import numpy as np

import cv2
import pytest
import numpy as np
import cv2

from mindspore import log as logger
from mindspore.mindrecord import FileReader
@@ -14,12 +14,12 @@
# ============================================================================
"""utils for test"""

import collections
import json
import numpy as np
import os
import re
import string
import collections
import json
import numpy as np

from mindspore import log as logger

@@ -185,7 +185,7 @@ def get_nlp_data(dir_name, vocab_file, num):
"""
if not os.path.isdir(dir_name):
raise IOError("Directory {} not exists".format(dir_name))
for root, dirs, files in os.walk(dir_name):
for root, _, files in os.walk(dir_name):
for index, file_name_extension in enumerate(files):
if index < num:
file_path = os.path.join(root, file_name_extension)
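
The final two hunks regroup imports so standard-library modules come first, third-party packages second, and project imports last, separated by blank lines, which is the layout pylint's wrong-import-order check expects. A hypothetical module header following the same grouping:

    """Hypothetical test module showing the import grouping used above."""
    import collections
    import os

    import numpy as np

    from mindspore import log as logger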