forked from mindspore-Ecosystem/mindspore
1. add more log info for dataset & mindrecord, 2. add two new testcase for MindDataset
This commit is contained in:
parent
c24252b2cc
commit
34e42bd6f9
|
@ -66,11 +66,10 @@ def _alter_node(node):
|
|||
|
||||
class Iterator:
|
||||
"""
|
||||
General Iterator over a dataset.
|
||||
|
||||
Attributes:
|
||||
dataset: Dataset to be iterated over
|
||||
General Iterator over a dataset.
|
||||
|
||||
Attributes:
|
||||
dataset: Dataset to be iterated over
|
||||
"""
|
||||
|
||||
def __init__(self, dataset):
|
||||
|
@ -86,6 +85,7 @@ class Iterator:
|
|||
root = self.__convert_node_postorder(self.dataset)
|
||||
self.depipeline.AssignRootNode(root)
|
||||
self.depipeline.LaunchTreeExec()
|
||||
self._index = 0
|
||||
|
||||
def __is_tree_node(self, node):
|
||||
"""Check if a node is tree node."""
|
||||
|
@ -185,10 +185,7 @@ class Iterator:
|
|||
Iterator.__print_local(input_op, level + 1)
|
||||
|
||||
def print(self):
|
||||
"""
|
||||
Print the dataset tree
|
||||
|
||||
"""
|
||||
"""Print the dataset tree"""
|
||||
self.__print_local(self.dataset, 0)
|
||||
|
||||
def release(self):
|
||||
|
@ -202,7 +199,10 @@ class Iterator:
|
|||
def __next__(self):
|
||||
data = self.get_next()
|
||||
if not data:
|
||||
if self._index == 0:
|
||||
logger.warning("No records available.")
|
||||
raise StopIteration
|
||||
self._index += 1
|
||||
return data
|
||||
|
||||
def get_output_shapes(self):
|
||||
|
@ -234,7 +234,7 @@ class DictIterator(Iterator):
|
|||
|
||||
def get_next(self):
|
||||
"""
|
||||
Returns the next record in the dataset as dictionary
|
||||
Returns the next record in the dataset as dictionary
|
||||
|
||||
Returns:
|
||||
Dict, the next record in the dataset.
|
||||
|
@ -260,7 +260,7 @@ class TupleIterator(Iterator):
|
|||
|
||||
def get_next(self):
|
||||
"""
|
||||
Returns the next record in the dataset as a list
|
||||
Returns the next record in the dataset as a list
|
||||
|
||||
Returns:
|
||||
List, the next record in the dataset.
|
||||
|
|
|
@ -328,13 +328,20 @@ class FileWriter:
|
|||
self._generator.build()
|
||||
self._generator.write_to_db()
|
||||
|
||||
mindrecord_files = []
|
||||
index_files = []
|
||||
# change the file mode to 600
|
||||
for item in self._paths:
|
||||
if os.path.exists(item):
|
||||
os.chmod(item, stat.S_IRUSR | stat.S_IWUSR)
|
||||
mindrecord_files.append(item)
|
||||
index_file = item + ".db"
|
||||
if os.path.exists(index_file):
|
||||
os.chmod(index_file, stat.S_IRUSR | stat.S_IWUSR)
|
||||
index_files.append(index_file)
|
||||
|
||||
logger.info("The list of mindrecord files created are: {}, and the list of index files are: {}".format(
|
||||
mindrecord_files, index_files))
|
||||
|
||||
return ret
|
||||
|
||||
|
|
|
@ -25,6 +25,7 @@ import mindspore.dataset.transforms.vision.c_transforms as vision
|
|||
import numpy as np
|
||||
import pytest
|
||||
from mindspore._c_dataengine import InterpolationMode
|
||||
from mindspore.dataset.transforms.vision import Inter
|
||||
from mindspore import log as logger
|
||||
|
||||
import mindspore.dataset as ds
|
||||
|
@ -151,6 +152,51 @@ def test_cv_minddataset_dataset_size(add_and_remove_cv_file):
|
|||
assert data_set.get_dataset_size() == 3
|
||||
|
||||
|
||||
def test_cv_minddataset_repeat_reshuffle(add_and_remove_cv_file):
|
||||
"""tutorial for cv minddataset."""
|
||||
columns_list = ["data", "label"]
|
||||
num_readers = 4
|
||||
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
|
||||
decode_op = vision.Decode()
|
||||
data_set = data_set.map(input_columns=["data"], operations=decode_op, num_parallel_workers=2)
|
||||
resize_op = vision.Resize((32, 32), interpolation=Inter.LINEAR)
|
||||
data_set = data_set.map(input_columns="data", operations=resize_op, num_parallel_workers=2)
|
||||
data_set = data_set.batch(2)
|
||||
data_set = data_set.repeat(2)
|
||||
num_iter = 0
|
||||
labels = []
|
||||
for item in data_set.create_dict_iterator():
|
||||
logger.info("-------------- get dataset size {} -----------------".format(num_iter))
|
||||
logger.info("-------------- item[label]: {} ---------------------".format(item["label"]))
|
||||
logger.info("-------------- item[data]: {} ----------------------".format(item["data"]))
|
||||
num_iter += 1
|
||||
labels.append(item["label"])
|
||||
assert num_iter == 10
|
||||
logger.info("repeat shuffle: {}".format(labels))
|
||||
assert len(labels) == 10
|
||||
assert labels[0:5] == labels[0:5]
|
||||
assert labels[0:5] != labels[5:5]
|
||||
|
||||
|
||||
def test_cv_minddataset_batch_size_larger_than_records(add_and_remove_cv_file):
|
||||
"""tutorial for cv minddataset."""
|
||||
columns_list = ["data", "label"]
|
||||
num_readers = 4
|
||||
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
|
||||
decode_op = vision.Decode()
|
||||
data_set = data_set.map(input_columns=["data"], operations=decode_op, num_parallel_workers=2)
|
||||
resize_op = vision.Resize((32, 32), interpolation=Inter.LINEAR)
|
||||
data_set = data_set.map(input_columns="data", operations=resize_op, num_parallel_workers=2)
|
||||
data_set = data_set.batch(32, drop_remainder=True)
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
logger.info("-------------- get dataset size {} -----------------".format(num_iter))
|
||||
logger.info("-------------- item[label]: {} ---------------------".format(item["label"]))
|
||||
logger.info("-------------- item[data]: {} ----------------------".format(item["data"]))
|
||||
num_iter += 1
|
||||
assert num_iter == 0
|
||||
|
||||
|
||||
def test_cv_minddataset_issue_888(add_and_remove_cv_file):
|
||||
"""issue 888 test."""
|
||||
columns_list = ["data", "label"]
|
||||
|
|
Loading…
Reference in New Issue