delete storageDataset Op API and its test case

This commit is contained in:
ms_yan 2020-05-22 17:07:21 +08:00
parent 6f733ec113
commit d5e896b51c
10 changed files with 11 additions and 366 deletions

View File

@ -19,7 +19,7 @@ can also create samplers with this module to sample data.
"""
from .core.configuration import config
from .engine.datasets import StorageDataset, TFRecordDataset, ImageFolderDatasetV2, MnistDataset, MindDataset, \
from .engine.datasets import TFRecordDataset, ImageFolderDatasetV2, MnistDataset, MindDataset, \
GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CelebADataset, TextFileDataset, \
Schema, Shuffle, zip, RandomDataset
from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \
@ -27,7 +27,7 @@ from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, Seque
from .engine.serializer_deserializer import serialize, deserialize, show
from .engine.graphdata import GraphData
__all__ = ["config", "ImageFolderDatasetV2", "MnistDataset", "StorageDataset",
__all__ = ["config", "ImageFolderDatasetV2", "MnistDataset",
"MindDataset", "GeneratorDataset", "TFRecordDataset",
"ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset",
"VOCDataset", "TextFileDataset", "Schema", "DistributedSampler", "PKSampler", "RandomSampler",

View File

@ -29,7 +29,7 @@ from .samplers import *
from ..core.configuration import config, ConfigurationManager
__all__ = ["config", "ConfigurationManager", "zip", "StorageDataset",
__all__ = ["config", "ConfigurationManager", "zip",
"ImageFolderDatasetV2", "MnistDataset",
"MindDataset", "GeneratorDataset", "TFRecordDataset",
"ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset",

View File

@ -22,7 +22,6 @@ import glob
import json
import math
import os
import random
import uuid
import multiprocessing
import queue
@ -40,7 +39,7 @@ from mindspore._c_expression import typing
from mindspore import log as logger
from . import samplers
from .iterators import DictIterator, TupleIterator
from .validators import check, check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \
from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \
check_rename, \
check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
check_tfrecorddataset, check_vocdataset, check_celebadataset, check_minddataset, check_generatordataset, \
@ -480,7 +479,7 @@ class Dataset:
If input_columns not provided or empty, all columns will be used.
Args:
predicate(callable): python callable which returns a boolean value.
predicate(callable): python callable which returns a boolean value, if False then filter the element.
input_columns: (list[str], optional): List of names of the input columns, when
default=None, the predicate will be applied on all columns in the dataset.
num_parallel_workers (int, optional): Number of workers to process the Dataset
@ -899,7 +898,7 @@ class Dataset:
def get_distribution(output_dataset):
dev_id = 0
if isinstance(output_dataset, (StorageDataset, MindDataset)):
if isinstance(output_dataset, (MindDataset)):
return output_dataset.distribution, dev_id
if isinstance(output_dataset, (Cifar10Dataset, Cifar100Dataset, GeneratorDataset, ImageFolderDatasetV2,
ManifestDataset, MnistDataset, VOCDataset, CelebADataset)):
@ -984,57 +983,6 @@ class Dataset:
"""Create an Iterator over the dataset."""
return self.create_tuple_iterator()
@staticmethod
def read_dir(dir_path, schema, columns_list=None, num_parallel_workers=None,
deterministic_output=True, prefetch_size=None, shuffle=False, seed=None, distribution=""):
"""
Append the path of all files in the dir_path to StorageDataset.
Args:
dir_path (str): Path to the directory that contains the dataset.
schema (str): Path to the json schema file.
columns_list (list[str], optional): List of columns to be read (default=None).
If not provided, read all columns.
num_parallel_workers (int, optional): Number of workers to process the Dataset in parallel
(default=None).
deterministic_output (bool, optional): Whether the result of this dataset can be reproduced
or not (default=True). If True, performance might be affected.
prefetch_size (int, optional): Prefetch number of records ahead of the
user's request (default=None).
shuffle (bool, optional): Shuffle the list of files in the directory (default=False).
seed (int, optional): Create a random generator with a fixed seed. If set to None,
create a random seed (default=None).
distribution (str, optional): The path of distribution config file (default="").
Returns:
StorageDataset.
Raises:
ValueError: If dataset folder does not exist.
ValueError: If dataset folder permission denied.
"""
logger.warning("WARN_DEPRECATED: The usage of read_dir is deprecated, please use TFRecordDataset with GLOB.")
list_files = []
if not os.path.isdir(dir_path):
raise ValueError("The dataset folder does not exist!")
if not os.access(dir_path, os.R_OK):
raise ValueError("The dataset folder permission denied!")
for root, _, files in os.walk(dir_path):
for file in files:
list_files.append(os.path.join(root, file))
list_files.sort()
if shuffle:
rand = random.Random(seed)
rand.shuffle(list_files)
return StorageDataset(list_files, schema, distribution, columns_list, num_parallel_workers,
deterministic_output, prefetch_size)
@property
def input_indexs(self):
return self._input_indexs
@ -1818,7 +1766,7 @@ class FilterDataset(DatasetOp):
Args:
input_dataset: Input Dataset to be mapped.
predicate: python callable which returns a boolean value.
predicate: python callable which returns a boolean value, if False then filter the element.
input_columns: (list[str]): List of names of the input columns, when
default=None, the predicate will be applied all columns in the dataset.
num_parallel_workers (int, optional): Number of workers to process the Dataset
@ -2157,123 +2105,6 @@ class TransferDataset(DatasetOp):
self.iterator = TupleIterator(self)
class StorageDataset(SourceDataset):
"""
A source dataset that reads and parses datasets stored on disk in various formats, including TFData format.
Args:
dataset_files (list[str]): List of files to be read.
schema (str): Path to the json schema file. If numRows(parsed from schema) is not exist, read the full dataset.
distribution (str, optional): Path of distribution config file (default="").
columns_list (list[str], optional): List of columns to be read (default=None, read all columns).
num_parallel_workers (int, optional): Number of parallel working threads (default=None).
deterministic_output (bool, optional): Whether the result of this dataset can be reproduced
or not (default=True). If True, performance might be affected.
prefetch_size (int, optional): Prefetch number of records ahead of the user's request (default=None).
Raises:
RuntimeError: If schema file failed to read.
RuntimeError: If distribution file path is given but failed to read.
"""
@check
def __init__(self, dataset_files, schema, distribution="", columns_list=None, num_parallel_workers=None,
deterministic_output=None, prefetch_size=None):
super().__init__(num_parallel_workers)
logger.warning("WARN_DEPRECATED: The usage of StorageDataset is deprecated, please use TFRecordDataset.")
self.dataset_files = dataset_files
try:
with open(schema, 'r') as load_f:
json.load(load_f)
except json.decoder.JSONDecodeError:
raise RuntimeError("Json decode error when load schema file")
except Exception:
raise RuntimeError("Schema file failed to load")
if distribution != "":
try:
with open(distribution, 'r') as load_d:
json.load(load_d)
except json.decoder.JSONDecodeError:
raise RuntimeError("Json decode error when load distribution file")
except Exception:
raise RuntimeError("Distribution file failed to load")
if self.dataset_files is None:
schema = None
distribution = None
self.schema = schema
self.distribution = distribution
self.columns_list = columns_list
self.deterministic_output = deterministic_output
self.prefetch_size = prefetch_size
def get_args(self):
args = super().get_args()
args["dataset_files"] = self.dataset_files
args["schema"] = self.schema
args["distribution"] = self.distribution
args["columns_list"] = self.columns_list
args["deterministic_output"] = self.deterministic_output
args["prefetch_size"] = self.prefetch_size
return args
def get_dataset_size(self):
"""
Get the number of batches in an epoch.
Return:
Number, number of batches.
"""
if self._dataset_size is None:
self._get_pipeline_info()
return self._dataset_size
# manually set dataset_size as a temporary solution.
def set_dataset_size(self, value):
logger.warning("WARN_DEPRECATED: This method is deprecated. Please use get_dataset_size directly.")
if value >= 0:
self._dataset_size = value
else:
raise ValueError('set dataset_size with negative value {}'.format(value))
def num_classes(self):
"""
Get the number of classes in dataset.
Return:
Number, number of classes.
Raises:
ValueError: If dataset type is invalid.
ValueError: If dataset is not Imagenet dataset or manifest dataset.
RuntimeError: If schema file is given but failed to load.
"""
cur_dataset = self
while cur_dataset.input:
cur_dataset = cur_dataset.input[0]
if not hasattr(cur_dataset, "schema"):
raise ValueError("Dataset type is invalid")
# Only IMAGENET/MANIFEST support numclass
try:
with open(cur_dataset.schema, 'r') as load_f:
load_dict = json.load(load_f)
except json.decoder.JSONDecodeError:
raise RuntimeError("Json decode error when load schema file")
except Exception:
raise RuntimeError("Schema file failed to load")
if load_dict["datasetType"] != "IMAGENET" and load_dict["datasetType"] != "MANIFEST":
raise ValueError("%s dataset does not support num_classes!" % (load_dict["datasetType"]))
if self._num_classes is None:
self._get_pipeline_info()
return self._num_classes
def is_shuffled(self):
return False
def is_sharded(self):
return False
class RangeDataset(MappableDataset):
"""

View File

@ -168,8 +168,6 @@ class Iterator:
op_type = OpName.SKIP
elif isinstance(dataset, de.TakeDataset):
op_type = OpName.TAKE
elif isinstance(dataset, de.StorageDataset):
op_type = OpName.STORAGE
elif isinstance(dataset, de.ImageFolderDatasetV2):
op_type = OpName.IMAGEFOLDER
elif isinstance(dataset, de.GeneratorDataset):

View File

@ -230,11 +230,7 @@ def create_node(node):
pyobj = None
# Find a matching Dataset class and call the constructor with the corresponding args.
# When a new Dataset class is introduced, another if clause and parsing code needs to be added.
if dataset_op == 'StorageDataset':
pyobj = pyclass(node['dataset_files'], node['schema'], node.get('distribution'),
node.get('columns_list'), node.get('num_parallel_workers'))
elif dataset_op == 'ImageFolderDatasetV2':
if dataset_op == 'ImageFolderDatasetV2':
sampler = construct_sampler(node.get('sampler'))
pyobj = pyclass(node['dataset_dir'], node.get('num_samples'), node.get('num_parallel_workers'),
node.get('shuffle'), sampler, node.get('extensions'),

View File

@ -31,7 +31,7 @@ SCHEMA_DIR = "{0}/resnet_all_datasetSchema.json".format(data_path)
def test_me_de_train_dataset():
data_list = ["{0}/train-00001-of-01024.data".format(data_path)]
data_set = ds.StorageDataset(data_list, schema=SCHEMA_DIR,
data_set = ds.TFRecordDataset(data_list, schema=SCHEMA_DIR,
columns_list=["image/encoded", "image/class/label"])
resize_height = 224

View File

@ -24,11 +24,6 @@ DATA_DIR = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data",
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images2/datasetSchema.json"
DISTRIBUTION_ALL_DIR = "../data/dataset/test_tf_file_3_images2/dataDistributionAll.json"
DISTRIBUTION_UNIQUE_DIR = "../data/dataset/test_tf_file_3_images2/dataDistributionUnique.json"
DISTRIBUTION_RANDOM_DIR = "../data/dataset/test_tf_file_3_images2/dataDistributionRandom.json"
DISTRIBUTION_EQUAL_DIR = "../data/dataset/test_tf_file_3_images2/dataDistributionEqualRows.json"
def test_tf_file_normal():
# apply dataset operations
@ -42,61 +37,6 @@ def test_tf_file_normal():
assert num_iter == 12
def test_tf_file_distribution_all():
# apply dataset operations
data1 = ds.StorageDataset(DATA_DIR, SCHEMA_DIR, DISTRIBUTION_ALL_DIR)
data1 = data1.repeat(2)
num_iter = 0
for item in data1.create_dict_iterator(): # each data is a dictionary
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 24
def test_tf_file_distribution_unique():
data1 = ds.StorageDataset(DATA_DIR, SCHEMA_DIR, DISTRIBUTION_UNIQUE_DIR)
data1 = data1.repeat(1)
num_iter = 0
for item in data1.create_dict_iterator(): # each data is a dictionary
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 4
def test_tf_file_distribution_random():
data1 = ds.StorageDataset(DATA_DIR, SCHEMA_DIR, DISTRIBUTION_RANDOM_DIR)
data1 = data1.repeat(1)
num_iter = 0
for item in data1.create_dict_iterator(): # each data is a dictionary
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 4
def test_tf_file_distribution_equal_rows():
data1 = ds.StorageDataset(DATA_DIR, SCHEMA_DIR, DISTRIBUTION_EQUAL_DIR)
data1 = data1.repeat(2)
num_iter = 0
for item in data1.create_dict_iterator(): # each data is a dictionary
num_iter += 1
assert num_iter == 4
if __name__ == '__main__':
logger.info('=======test normal=======')
test_tf_file_normal()
logger.info('=======test all=======')
test_tf_file_distribution_all()
logger.info('=======test unique=======')
test_tf_file_distribution_unique()
logger.info('=======test random=======')
test_tf_file_distribution_random()
logger.info('=======test equal rows=======')
test_tf_file_distribution_equal_rows()

View File

@ -1,69 +0,0 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore.dataset as ds
from mindspore import log as logger
DATA_DIR = "../data/dataset/test_tf_file_3_images/data"
SCHEMA = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
COLUMNS = ["label"]
GENERATE_GOLDEN = False
def test_case_0():
logger.info("Test 0 readdir")
# apply dataset operations
data1 = ds.engine.Dataset.read_dir(DATA_DIR, SCHEMA, columns_list=None, num_parallel_workers=None,
deterministic_output=True, prefetch_size=None, shuffle=False, seed=None)
i = 0
for item in data1.create_dict_iterator(): # each data is a dictionary
logger.info("item[label] is {}".format(item["label"]))
i = i + 1
assert (i == 3)
def test_case_1():
logger.info("Test 1 readdir")
# apply dataset operations
data1 = ds.engine.Dataset.read_dir(DATA_DIR, SCHEMA, COLUMNS, num_parallel_workers=None,
deterministic_output=True, prefetch_size=None, shuffle=True, seed=None)
i = 0
for item in data1.create_dict_iterator(): # each data is a dictionary
logger.info("item[label] is {}".format(item["label"]))
i = i + 1
assert (i == 3)
def test_case_2():
logger.info("Test 2 readdir")
# apply dataset operations
data1 = ds.engine.Dataset.read_dir(DATA_DIR, SCHEMA, columns_list=None, num_parallel_workers=2,
deterministic_output=False, prefetch_size=16, shuffle=True, seed=10)
i = 0
for item in data1.create_dict_iterator(): # each data is a dictionary
logger.info("item[label] is {}".format(item["label"]))
i = i + 1
assert (i == 3)
if __name__ == "__main__":
test_case_0()
test_case_1()
test_case_2()

View File

@ -177,7 +177,7 @@ def test_random_crop():
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
# First dataset
data1 = ds.StorageDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
decode_op = vision.Decode()
random_crop_op = vision.RandomCrop([512, 512], [200, 200, 200, 200])
data1 = data1.map(input_columns="image", operations=decode_op)
@ -192,7 +192,7 @@ def test_random_crop():
data1_1 = ds.deserialize(input_dict=ds1_dict)
# Second dataset
data2 = ds.StorageDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
data2 = data2.map(input_columns="image", operations=decode_op)
for item1, item1_1, item2 in zip(data1.create_dict_iterator(), data1_1.create_dict_iterator(),

View File

@ -1,51 +0,0 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from util import save_and_check
import mindspore.dataset as ds
from mindspore import log as logger
DATA_DIR = ["../data/dataset/testTFTestAllTypes/test.data"]
SCHEMA_DIR = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
COLUMNS = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float",
"col_sint16", "col_sint32", "col_sint64"]
GENERATE_GOLDEN = False
def test_case_storage():
"""
test StorageDataset
"""
logger.info("Test Simple StorageDataset")
# define parameters
parameters = {"params": {}}
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
filename = "storage_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
def test_case_no_rows():
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetNoRowsSchema.json"
dataset = ds.StorageDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
assert dataset.get_dataset_size() == 3
count = 0
for data in dataset.create_tuple_iterator():
count += 1
assert count == 3