!29 Add some prompt information for ease of use

Merge pull request !29 from jonyguo/add_more_log_info_and_testcase
This commit is contained in:
mindspore-ci-bot 2020-03-30 22:27:19 +08:00 committed by Gitee
commit 4f5755003a
3 changed files with 63 additions and 10 deletions

View File

@ -66,11 +66,10 @@ def _alter_node(node):
class Iterator: class Iterator:
""" """
General Iterator over a dataset. General Iterator over a dataset.
Attributes:
dataset: Dataset to be iterated over
Attributes:
dataset: Dataset to be iterated over
""" """
def __init__(self, dataset): def __init__(self, dataset):
@ -86,6 +85,7 @@ class Iterator:
root = self.__convert_node_postorder(self.dataset) root = self.__convert_node_postorder(self.dataset)
self.depipeline.AssignRootNode(root) self.depipeline.AssignRootNode(root)
self.depipeline.LaunchTreeExec() self.depipeline.LaunchTreeExec()
self._index = 0
def __is_tree_node(self, node): def __is_tree_node(self, node):
"""Check if a node is tree node.""" """Check if a node is tree node."""
@ -185,10 +185,7 @@ class Iterator:
Iterator.__print_local(input_op, level + 1) Iterator.__print_local(input_op, level + 1)
def print(self): def print(self):
""" """Print the dataset tree"""
Print the dataset tree
"""
self.__print_local(self.dataset, 0) self.__print_local(self.dataset, 0)
def release(self): def release(self):
@ -202,7 +199,10 @@ class Iterator:
def __next__(self): def __next__(self):
data = self.get_next() data = self.get_next()
if not data: if not data:
if self._index == 0:
logger.warning("No records available.")
raise StopIteration raise StopIteration
self._index += 1
return data return data
def get_output_shapes(self): def get_output_shapes(self):
@ -234,7 +234,7 @@ class DictIterator(Iterator):
def get_next(self): def get_next(self):
""" """
Returns the next record in the dataset as dictionary Returns the next record in the dataset as dictionary
Returns: Returns:
Dict, the next record in the dataset. Dict, the next record in the dataset.
@ -260,7 +260,7 @@ class TupleIterator(Iterator):
def get_next(self): def get_next(self):
""" """
Returns the next record in the dataset as a list Returns the next record in the dataset as a list
Returns: Returns:
List, the next record in the dataset. List, the next record in the dataset.

View File

@ -328,13 +328,20 @@ class FileWriter:
self._generator.build() self._generator.build()
self._generator.write_to_db() self._generator.write_to_db()
mindrecord_files = []
index_files = []
# change the file mode to 600 # change the file mode to 600
for item in self._paths: for item in self._paths:
if os.path.exists(item): if os.path.exists(item):
os.chmod(item, stat.S_IRUSR | stat.S_IWUSR) os.chmod(item, stat.S_IRUSR | stat.S_IWUSR)
mindrecord_files.append(item)
index_file = item + ".db" index_file = item + ".db"
if os.path.exists(index_file): if os.path.exists(index_file):
os.chmod(index_file, stat.S_IRUSR | stat.S_IWUSR) os.chmod(index_file, stat.S_IRUSR | stat.S_IWUSR)
index_files.append(index_file)
logger.info("The list of mindrecord files created are: {}, and the list of index files are: {}".format(
mindrecord_files, index_files))
return ret return ret

View File

@ -25,6 +25,7 @@ import mindspore.dataset.transforms.vision.c_transforms as vision
import numpy as np import numpy as np
import pytest import pytest
from mindspore._c_dataengine import InterpolationMode from mindspore._c_dataengine import InterpolationMode
from mindspore.dataset.transforms.vision import Inter
from mindspore import log as logger from mindspore import log as logger
import mindspore.dataset as ds import mindspore.dataset as ds
@ -151,6 +152,51 @@ def test_cv_minddataset_dataset_size(add_and_remove_cv_file):
assert data_set.get_dataset_size() == 3 assert data_set.get_dataset_size() == 3
def test_cv_minddataset_repeat_reshuffle(add_and_remove_cv_file):
"""tutorial for cv minddataset."""
columns_list = ["data", "label"]
num_readers = 4
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
decode_op = vision.Decode()
data_set = data_set.map(input_columns=["data"], operations=decode_op, num_parallel_workers=2)
resize_op = vision.Resize((32, 32), interpolation=Inter.LINEAR)
data_set = data_set.map(input_columns="data", operations=resize_op, num_parallel_workers=2)
data_set = data_set.batch(2)
data_set = data_set.repeat(2)
num_iter = 0
labels = []
for item in data_set.create_dict_iterator():
logger.info("-------------- get dataset size {} -----------------".format(num_iter))
logger.info("-------------- item[label]: {} ---------------------".format(item["label"]))
logger.info("-------------- item[data]: {} ----------------------".format(item["data"]))
num_iter += 1
labels.append(item["label"])
assert num_iter == 10
logger.info("repeat shuffle: {}".format(labels))
assert len(labels) == 10
assert labels[0:5] == labels[0:5]
assert labels[0:5] != labels[5:5]
def test_cv_minddataset_batch_size_larger_than_records(add_and_remove_cv_file):
"""tutorial for cv minddataset."""
columns_list = ["data", "label"]
num_readers = 4
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
decode_op = vision.Decode()
data_set = data_set.map(input_columns=["data"], operations=decode_op, num_parallel_workers=2)
resize_op = vision.Resize((32, 32), interpolation=Inter.LINEAR)
data_set = data_set.map(input_columns="data", operations=resize_op, num_parallel_workers=2)
data_set = data_set.batch(32, drop_remainder=True)
num_iter = 0
for item in data_set.create_dict_iterator():
logger.info("-------------- get dataset size {} -----------------".format(num_iter))
logger.info("-------------- item[label]: {} ---------------------".format(item["label"]))
logger.info("-------------- item[data]: {} ----------------------".format(item["data"]))
num_iter += 1
assert num_iter == 0
def test_cv_minddataset_issue_888(add_and_remove_cv_file): def test_cv_minddataset_issue_888(add_and_remove_cv_file):
"""issue 888 test.""" """issue 888 test."""
columns_list = ["data", "label"] columns_list = ["data", "label"]