forked from mindspore-Ecosystem/mindspore
!1369 dataset: delete StorageDataset
Merge pull request !1369 from ms_yan/del_storage
Commit 41456ac824
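The pull request removes the long-deprecated StorageDataset reader and keeps TFRecordDataset as its replacement. A hedged migration sketch (paths, schema, and column names below are placeholders, not taken from this repository):

```python
import mindspore.dataset as ds

# Placeholder inputs for illustration only.
DATA_FILES = ["/path/to/train-00001-of-01024.data"]
SCHEMA_FILE = "/path/to/datasetSchema.json"

# Before this commit (now removed):
#   data_set = ds.StorageDataset(DATA_FILES, schema=SCHEMA_FILE,
#                                columns_list=["image/encoded", "image/class/label"])

# After: TFRecordDataset takes the same file list, schema, and column selection.
data_set = ds.TFRecordDataset(DATA_FILES, schema=SCHEMA_FILE,
                              columns_list=["image/encoded", "image/class/label"])

for item in data_set.create_dict_iterator():  # each item is a dict of column -> numpy array
    pass
```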
@@ -19,7 +19,7 @@ can also create samplers with this module to sample data.
 """
 from .core.configuration import config
-from .engine.datasets import StorageDataset, TFRecordDataset, ImageFolderDatasetV2, MnistDataset, MindDataset, \
+from .engine.datasets import TFRecordDataset, ImageFolderDatasetV2, MnistDataset, MindDataset, \
     GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CelebADataset, TextFileDataset, \
     Schema, Shuffle, zip, RandomDataset
 from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \
@@ -27,7 +27,7 @@ from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, Seque
 from .engine.serializer_deserializer import serialize, deserialize, show
 from .engine.graphdata import GraphData
 
-__all__ = ["config", "ImageFolderDatasetV2", "MnistDataset", "StorageDataset",
+__all__ = ["config", "ImageFolderDatasetV2", "MnistDataset",
            "MindDataset", "GeneratorDataset", "TFRecordDataset",
            "ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset",
            "VOCDataset", "TextFileDataset", "Schema", "DistributedSampler", "PKSampler", "RandomSampler",
@@ -29,7 +29,7 @@ from .samplers import *
 from ..core.configuration import config, ConfigurationManager
 
 
-__all__ = ["config", "ConfigurationManager", "zip", "StorageDataset",
+__all__ = ["config", "ConfigurationManager", "zip",
            "ImageFolderDatasetV2", "MnistDataset",
            "MindDataset", "GeneratorDataset", "TFRecordDataset",
            "ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset",
@@ -22,7 +22,6 @@ import glob
 import json
 import math
 import os
-import random
 import uuid
 import multiprocessing
 import queue
@@ -40,7 +39,7 @@ from mindspore._c_expression import typing
 from mindspore import log as logger
 from . import samplers
 from .iterators import DictIterator, TupleIterator
-from .validators import check, check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \
+from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \
     check_rename, \
     check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
     check_tfrecorddataset, check_vocdataset, check_celebadataset, check_minddataset, check_generatordataset, \
@@ -480,7 +479,7 @@ class Dataset:
         If input_columns not provided or empty, all columns will be used.
 
         Args:
-            predicate(callable): python callable which returns a boolean value.
+            predicate(callable): python callable which returns a boolean value, if False then filter the element.
            input_columns: (list[str], optional): List of names of the input columns, when
                default=None, the predicate will be applied on all columns in the dataset.
            num_parallel_workers (int, optional): Number of workers to process the Dataset
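The reworded predicate description spells out the filter semantics: rows for which the predicate returns False are dropped. A minimal usage sketch, assuming a small GeneratorDataset source invented here for illustration:

```python
import numpy as np
import mindspore.dataset as ds

# Hypothetical in-memory source, for illustration only.
def gen():
    for i in range(10):
        yield (np.array(i),)

data = ds.GeneratorDataset(gen, column_names=["number"])

# Rows where the predicate returns False are filtered out; only even numbers remain.
data = data.filter(predicate=lambda number: int(number) % 2 == 0, input_columns=["number"])

for item in data.create_dict_iterator():
    print(item["number"])  # 0, 2, 4, 6, 8
```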
@@ -899,7 +898,7 @@ class Dataset:
 
         def get_distribution(output_dataset):
             dev_id = 0
-            if isinstance(output_dataset, (StorageDataset, MindDataset)):
+            if isinstance(output_dataset, (MindDataset)):
                 return output_dataset.distribution, dev_id
             if isinstance(output_dataset, (Cifar10Dataset, Cifar100Dataset, GeneratorDataset, ImageFolderDatasetV2,
                                            ManifestDataset, MnistDataset, VOCDataset, CelebADataset)):
@@ -984,57 +983,6 @@ class Dataset:
         """Create an Iterator over the dataset."""
         return self.create_tuple_iterator()
 
-    @staticmethod
-    def read_dir(dir_path, schema, columns_list=None, num_parallel_workers=None,
-                 deterministic_output=True, prefetch_size=None, shuffle=False, seed=None, distribution=""):
-        """
-        Append the path of all files in the dir_path to StorageDataset.
-
-        Args:
-            dir_path (str): Path to the directory that contains the dataset.
-            schema (str): Path to the json schema file.
-            columns_list (list[str], optional): List of columns to be read (default=None).
-                If not provided, read all columns.
-            num_parallel_workers (int, optional): Number of workers to process the Dataset in parallel
-                (default=None).
-            deterministic_output (bool, optional): Whether the result of this dataset can be reproduced
-                or not (default=True). If True, performance might be affected.
-            prefetch_size (int, optional): Prefetch number of records ahead of the
-                user's request (default=None).
-            shuffle (bool, optional): Shuffle the list of files in the directory (default=False).
-            seed (int, optional): Create a random generator with a fixed seed. If set to None,
-                create a random seed (default=None).
-            distribution (str, optional): The path of distribution config file (default="").
-
-        Returns:
-            StorageDataset.
-
-        Raises:
-            ValueError: If dataset folder does not exist.
-            ValueError: If dataset folder permission denied.
-        """
-        logger.warning("WARN_DEPRECATED: The usage of read_dir is deprecated, please use TFRecordDataset with GLOB.")
-
-        list_files = []
-
-        if not os.path.isdir(dir_path):
-            raise ValueError("The dataset folder does not exist!")
-        if not os.access(dir_path, os.R_OK):
-            raise ValueError("The dataset folder permission denied!")
-
-        for root, _, files in os.walk(dir_path):
-            for file in files:
-                list_files.append(os.path.join(root, file))
-
-        list_files.sort()
-
-        if shuffle:
-            rand = random.Random(seed)
-            rand.shuffle(list_files)
-
-        return StorageDataset(list_files, schema, distribution, columns_list, num_parallel_workers,
-                              deterministic_output, prefetch_size)
-
     @property
     def input_indexs(self):
         return self._input_indexs
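The removed read_dir helper pointed callers at "TFRecordDataset with GLOB". A hedged sketch of that replacement, with hypothetical paths; the glob/os plumbing below is illustrative and not part of this diff:

```python
import glob
import os

import mindspore.dataset as ds

# Hypothetical inputs mirroring the removed Dataset.read_dir() arguments.
dir_path = "/path/to/dataset_dir"
schema = "/path/to/datasetSchema.json"

# Gather every regular file under dir_path, sorted, much like read_dir() did with os.walk().
list_files = sorted(f for f in glob.glob(os.path.join(dir_path, "**", "*"), recursive=True)
                    if os.path.isfile(f))

# Hand the explicit file list (or a glob pattern string) to TFRecordDataset.
data = ds.TFRecordDataset(list_files, schema=schema, columns_list=None, shuffle=False)
```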
@@ -1818,7 +1766,7 @@ class FilterDataset(DatasetOp):
 
     Args:
         input_dataset: Input Dataset to be mapped.
-        predicate: python callable which returns a boolean value.
+        predicate: python callable which returns a boolean value, if False then filter the element.
         input_columns: (list[str]): List of names of the input columns, when
             default=None, the predicate will be applied all columns in the dataset.
         num_parallel_workers (int, optional): Number of workers to process the Dataset
@@ -2157,123 +2105,6 @@ class TransferDataset(DatasetOp):
         self.iterator = TupleIterator(self)
 
 
-class StorageDataset(SourceDataset):
-    """
-    A source dataset that reads and parses datasets stored on disk in various formats, including TFData format.
-
-    Args:
-        dataset_files (list[str]): List of files to be read.
-        schema (str): Path to the json schema file. If numRows(parsed from schema) is not exist, read the full dataset.
-        distribution (str, optional): Path of distribution config file (default="").
-        columns_list (list[str], optional): List of columns to be read (default=None, read all columns).
-        num_parallel_workers (int, optional): Number of parallel working threads (default=None).
-        deterministic_output (bool, optional): Whether the result of this dataset can be reproduced
-            or not (default=True). If True, performance might be affected.
-        prefetch_size (int, optional): Prefetch number of records ahead of the user's request (default=None).
-
-    Raises:
-        RuntimeError: If schema file failed to read.
-        RuntimeError: If distribution file path is given but failed to read.
-    """
-
-    @check
-    def __init__(self, dataset_files, schema, distribution="", columns_list=None, num_parallel_workers=None,
-                 deterministic_output=None, prefetch_size=None):
-        super().__init__(num_parallel_workers)
-        logger.warning("WARN_DEPRECATED: The usage of StorageDataset is deprecated, please use TFRecordDataset.")
-        self.dataset_files = dataset_files
-        try:
-            with open(schema, 'r') as load_f:
-                json.load(load_f)
-        except json.decoder.JSONDecodeError:
-            raise RuntimeError("Json decode error when load schema file")
-        except Exception:
-            raise RuntimeError("Schema file failed to load")
-
-        if distribution != "":
-            try:
-                with open(distribution, 'r') as load_d:
-                    json.load(load_d)
-            except json.decoder.JSONDecodeError:
-                raise RuntimeError("Json decode error when load distribution file")
-            except Exception:
-                raise RuntimeError("Distribution file failed to load")
-        if self.dataset_files is None:
-            schema = None
-            distribution = None
-        self.schema = schema
-        self.distribution = distribution
-        self.columns_list = columns_list
-        self.deterministic_output = deterministic_output
-        self.prefetch_size = prefetch_size
-
-    def get_args(self):
-        args = super().get_args()
-        args["dataset_files"] = self.dataset_files
-        args["schema"] = self.schema
-        args["distribution"] = self.distribution
-        args["columns_list"] = self.columns_list
-        args["deterministic_output"] = self.deterministic_output
-        args["prefetch_size"] = self.prefetch_size
-        return args
-
-    def get_dataset_size(self):
-        """
-        Get the number of batches in an epoch.
-
-        Return:
-            Number, number of batches.
-        """
-        if self._dataset_size is None:
-            self._get_pipeline_info()
-        return self._dataset_size
-
-    # manually set dataset_size as a temporary solution.
-    def set_dataset_size(self, value):
-        logger.warning("WARN_DEPRECATED: This method is deprecated. Please use get_dataset_size directly.")
-        if value >= 0:
-            self._dataset_size = value
-        else:
-            raise ValueError('set dataset_size with negative value {}'.format(value))
-
-    def num_classes(self):
-        """
-        Get the number of classes in dataset.
-
-        Return:
-            Number, number of classes.
-
-        Raises:
-            ValueError: If dataset type is invalid.
-            ValueError: If dataset is not Imagenet dataset or manifest dataset.
-            RuntimeError: If schema file is given but failed to load.
-        """
-        cur_dataset = self
-        while cur_dataset.input:
-            cur_dataset = cur_dataset.input[0]
-        if not hasattr(cur_dataset, "schema"):
-            raise ValueError("Dataset type is invalid")
-        # Only IMAGENET/MANIFEST support numclass
-        try:
-            with open(cur_dataset.schema, 'r') as load_f:
-                load_dict = json.load(load_f)
-        except json.decoder.JSONDecodeError:
-            raise RuntimeError("Json decode error when load schema file")
-        except Exception:
-            raise RuntimeError("Schema file failed to load")
-        if load_dict["datasetType"] != "IMAGENET" and load_dict["datasetType"] != "MANIFEST":
-            raise ValueError("%s dataset does not support num_classes!" % (load_dict["datasetType"]))
-
-        if self._num_classes is None:
-            self._get_pipeline_info()
-        return self._num_classes
-
-    def is_shuffled(self):
-        return False
-
-    def is_sharded(self):
-        return False
-
-
 class RangeDataset(MappableDataset):
     """
@@ -168,8 +168,6 @@ class Iterator:
             op_type = OpName.SKIP
         elif isinstance(dataset, de.TakeDataset):
             op_type = OpName.TAKE
-        elif isinstance(dataset, de.StorageDataset):
-            op_type = OpName.STORAGE
         elif isinstance(dataset, de.ImageFolderDatasetV2):
             op_type = OpName.IMAGEFOLDER
         elif isinstance(dataset, de.GeneratorDataset):
@@ -230,11 +230,7 @@ def create_node(node):
     pyobj = None
     # Find a matching Dataset class and call the constructor with the corresponding args.
     # When a new Dataset class is introduced, another if clause and parsing code needs to be added.
-    if dataset_op == 'StorageDataset':
-        pyobj = pyclass(node['dataset_files'], node['schema'], node.get('distribution'),
-                        node.get('columns_list'), node.get('num_parallel_workers'))
-
-    elif dataset_op == 'ImageFolderDatasetV2':
+    if dataset_op == 'ImageFolderDatasetV2':
         sampler = construct_sampler(node.get('sampler'))
         pyobj = pyclass(node['dataset_dir'], node.get('num_samples'), node.get('num_parallel_workers'),
                         node.get('shuffle'), sampler, node.get('extensions'),
@@ -31,7 +31,7 @@ SCHEMA_DIR = "{0}/resnet_all_datasetSchema.json".format(data_path)
 
 def test_me_de_train_dataset():
     data_list = ["{0}/train-00001-of-01024.data".format(data_path)]
-    data_set = ds.StorageDataset(data_list, schema=SCHEMA_DIR,
+    data_set = ds.TFRecordDataset(data_list, schema=SCHEMA_DIR,
                                  columns_list=["image/encoded", "image/class/label"])
 
     resize_height = 224
@@ -24,11 +24,6 @@ DATA_DIR = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data",
 
 SCHEMA_DIR = "../data/dataset/test_tf_file_3_images2/datasetSchema.json"
 
-DISTRIBUTION_ALL_DIR = "../data/dataset/test_tf_file_3_images2/dataDistributionAll.json"
-DISTRIBUTION_UNIQUE_DIR = "../data/dataset/test_tf_file_3_images2/dataDistributionUnique.json"
-DISTRIBUTION_RANDOM_DIR = "../data/dataset/test_tf_file_3_images2/dataDistributionRandom.json"
-DISTRIBUTION_EQUAL_DIR = "../data/dataset/test_tf_file_3_images2/dataDistributionEqualRows.json"
-
 
 def test_tf_file_normal():
     # apply dataset operations
@@ -42,61 +37,6 @@ def test_tf_file_normal():
     assert num_iter == 12
 
 
-def test_tf_file_distribution_all():
-    # apply dataset operations
-    data1 = ds.StorageDataset(DATA_DIR, SCHEMA_DIR, DISTRIBUTION_ALL_DIR)
-    data1 = data1.repeat(2)
-    num_iter = 0
-    for item in data1.create_dict_iterator():  # each data is a dictionary
-        num_iter += 1
-
-    logger.info("Number of data in data1: {}".format(num_iter))
-    assert num_iter == 24
-
-
-def test_tf_file_distribution_unique():
-    data1 = ds.StorageDataset(DATA_DIR, SCHEMA_DIR, DISTRIBUTION_UNIQUE_DIR)
-    data1 = data1.repeat(1)
-    num_iter = 0
-    for item in data1.create_dict_iterator():  # each data is a dictionary
-        num_iter += 1
-
-    logger.info("Number of data in data1: {}".format(num_iter))
-    assert num_iter == 4
-
-
-def test_tf_file_distribution_random():
-    data1 = ds.StorageDataset(DATA_DIR, SCHEMA_DIR, DISTRIBUTION_RANDOM_DIR)
-    data1 = data1.repeat(1)
-    num_iter = 0
-    for item in data1.create_dict_iterator():  # each data is a dictionary
-        num_iter += 1
-
-    logger.info("Number of data in data1: {}".format(num_iter))
-    assert num_iter == 4
-
-
-def test_tf_file_distribution_equal_rows():
-    data1 = ds.StorageDataset(DATA_DIR, SCHEMA_DIR, DISTRIBUTION_EQUAL_DIR)
-    data1 = data1.repeat(2)
-    num_iter = 0
-    for item in data1.create_dict_iterator():  # each data is a dictionary
-        num_iter += 1
-
-    assert num_iter == 4
-
-
 if __name__ == '__main__':
     logger.info('=======test normal=======')
     test_tf_file_normal()
-
-    logger.info('=======test all=======')
-    test_tf_file_distribution_all()
-
-    logger.info('=======test unique=======')
-    test_tf_file_distribution_unique()
-
-    logger.info('=======test random=======')
-    test_tf_file_distribution_random()
-    logger.info('=======test equal rows=======')
-    test_tf_file_distribution_equal_rows()
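The deleted tests covered StorageDataset's distribution config files. There is no one-to-one replacement in this diff; as a hedged sketch, sharded reading is typically expressed through TFRecordDataset's num_shards/shard_id parameters (assumed available in this MindSpore version), with placeholder paths:

```python
import mindspore.dataset as ds

# Placeholder paths for illustration; not the test data referenced above.
DATA_FILES = ["/path/to/train-0000-of-0001.data"]
SCHEMA_FILE = "/path/to/datasetSchema.json"

# Read shard 0 of 4: each worker sees a disjoint slice of the rows, which is the
# role the dataDistribution*.json files played for StorageDataset.
data1 = ds.TFRecordDataset(DATA_FILES, SCHEMA_FILE, num_shards=4, shard_id=0)

num_iter = sum(1 for _ in data1.create_dict_iterator())
print("rows read by shard 0:", num_iter)
```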
@@ -1,69 +0,0 @@
-# Copyright 2019 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-import mindspore.dataset as ds
-from mindspore import log as logger
-
-DATA_DIR = "../data/dataset/test_tf_file_3_images/data"
-SCHEMA = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
-COLUMNS = ["label"]
-GENERATE_GOLDEN = False
-
-
-def test_case_0():
-    logger.info("Test 0 readdir")
-
-    # apply dataset operations
-    data1 = ds.engine.Dataset.read_dir(DATA_DIR, SCHEMA, columns_list=None, num_parallel_workers=None,
-                                       deterministic_output=True, prefetch_size=None, shuffle=False, seed=None)
-
-    i = 0
-    for item in data1.create_dict_iterator():  # each data is a dictionary
-        logger.info("item[label] is {}".format(item["label"]))
-        i = i + 1
-    assert (i == 3)
-
-
-def test_case_1():
-    logger.info("Test 1 readdir")
-
-    # apply dataset operations
-    data1 = ds.engine.Dataset.read_dir(DATA_DIR, SCHEMA, COLUMNS, num_parallel_workers=None,
-                                       deterministic_output=True, prefetch_size=None, shuffle=True, seed=None)
-
-    i = 0
-    for item in data1.create_dict_iterator():  # each data is a dictionary
-        logger.info("item[label] is {}".format(item["label"]))
-        i = i + 1
-    assert (i == 3)
-
-
-def test_case_2():
-    logger.info("Test 2 readdir")
-
-    # apply dataset operations
-    data1 = ds.engine.Dataset.read_dir(DATA_DIR, SCHEMA, columns_list=None, num_parallel_workers=2,
-                                       deterministic_output=False, prefetch_size=16, shuffle=True, seed=10)
-
-    i = 0
-    for item in data1.create_dict_iterator():  # each data is a dictionary
-        logger.info("item[label] is {}".format(item["label"]))
-        i = i + 1
-    assert (i == 3)
-
-
-if __name__ == "__main__":
-    test_case_0()
-    test_case_1()
-    test_case_2()
@@ -177,7 +177,7 @@ def test_random_crop():
     SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
 
     # First dataset
-    data1 = ds.StorageDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
+    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
     decode_op = vision.Decode()
     random_crop_op = vision.RandomCrop([512, 512], [200, 200, 200, 200])
     data1 = data1.map(input_columns="image", operations=decode_op)
@@ -192,7 +192,7 @@ def test_random_crop():
     data1_1 = ds.deserialize(input_dict=ds1_dict)
 
     # Second dataset
-    data2 = ds.StorageDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
+    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
     data2 = data2.map(input_columns="image", operations=decode_op)
 
     for item1, item1_1, item2 in zip(data1.create_dict_iterator(), data1_1.create_dict_iterator(),
@@ -1,51 +0,0 @@
-# Copyright 2019 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-from util import save_and_check
-
-import mindspore.dataset as ds
-from mindspore import log as logger
-
-DATA_DIR = ["../data/dataset/testTFTestAllTypes/test.data"]
-SCHEMA_DIR = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
-COLUMNS = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float",
-           "col_sint16", "col_sint32", "col_sint64"]
-GENERATE_GOLDEN = False
-
-
-def test_case_storage():
-    """
-    test StorageDataset
-    """
-    logger.info("Test Simple StorageDataset")
-    # define parameters
-    parameters = {"params": {}}
-
-    # apply dataset operations
-    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
-
-    filename = "storage_result.npz"
-    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
-
-
-def test_case_no_rows():
-    DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
-    SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetNoRowsSchema.json"
-
-    dataset = ds.StorageDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
-    assert dataset.get_dataset_size() == 3
-    count = 0
-    for data in dataset.create_tuple_iterator():
-        count += 1
-    assert count == 3