add TensorDataset and its ut

ms_yan 2020-05-25 12:06:24 +08:00
parent 8de8289cfd
commit bc22c172b8
6 changed files with 411 additions and 3 deletions

View File

@@ -19,7 +19,7 @@ can also create samplers with this module to sample data.
"""
from .core.configuration import config
-from .engine.datasets import TFRecordDataset, ImageFolderDatasetV2, MnistDataset, MindDataset, \
+from .engine.datasets import TFRecordDataset, ImageFolderDatasetV2, MnistDataset, MindDataset, NumpySlicesDataset, \
    GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CocoDataset, CelebADataset,\
    TextFileDataset, Schema, Shuffle, zip, RandomDataset
from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \
@@ -29,6 +29,6 @@ from .engine.graphdata import GraphData
__all__ = ["config", "ImageFolderDatasetV2", "MnistDataset",
           "MindDataset", "GeneratorDataset", "TFRecordDataset",
-          "ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset",
+          "ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset", "NumpySlicesDataset",
           "VOCDataset", "CocoDataset", "TextFileDataset", "Schema", "DistributedSampler", "PKSampler", "RandomSampler",
           "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler", "zip", "GraphData"]

View File

@@ -40,7 +40,7 @@ from mindspore import log as logger
from . import samplers
from .iterators import DictIterator, TupleIterator
from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \
-    check_rename, \
+    check_rename, check_numpyslicesdataset, \
    check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
    check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset,\
    check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat,\
@@ -4377,3 +4377,158 @@ class TextFileDataset(SourceDataset):
            return self.num_shards > 1
        return False


class _NumpySlicesDataset:
    """
    Mainly for dealing with several kinds of Python data formats, returning one row of data each time.
    """
    def __init__(self, data, column_list=None):
        self.column_list = None
        # Convert dict data (or a tuple of dicts) into tuple format
        if isinstance(data, dict) or isinstance(data[0], dict):
            data = self.process_dict(data)

        if isinstance(data[0], tuple) or isinstance(data, tuple):
            self.is_tuple = True
            # Copy into a list so the per-column conversion below can assign in place
            # (a plain tuple would raise TypeError on item assignment)
            self.data = list(data)
            if isinstance(data[0], tuple):
                for i in range(len(self.data)):
                    self.data[i] = np.array(self.data[i])
        else:
            self.is_tuple = False
            self.data = np.array(data)

        # Init column_list: an explicit argument wins, then names collected by
        # process_dict; otherwise generate default names column_0, column_1, ...
        if column_list is not None:
            self.column_list = column_list
        elif self.column_list is None:
            self.column_list = []
            column_num = len(self.data) if self.is_tuple else 1
            for i in range(column_num):
                self.column_list.append("column_" + str(i))
    def __getitem__(self, index):
        if self.is_tuple:
            data_row = []
            for i in range(len(self.data)):
                data_row.append(self.data[i][index, ...])
            data_res = tuple(data_row)
        else:
            data_row = self.data[index, ...]
            data_row = [data_row]
            data_res = tuple(data_row)
        return data_res

    def __len__(self):
        if self.is_tuple:
            return len(self.data[0])
        return len(self.data)
    def process_dict(self, input_data):
        """
        Convert dict-like data into tuple format; when the input is a tuple of dicts, merge them into one dict first.
        """
        # When input is a tuple of dicts, merge them into a single dict
        if isinstance(input_data, tuple) and isinstance(input_data[0], dict):
            data_dict = {}
            for d in input_data:
                data_dict.update(d)
            input_data = data_dict

        # Convert a pandas-like dict (whose columns expose a "values" attribute) into a general dict
        data_keys = list(input_data.keys())
        data_col = input_data[data_keys[0]]
        if hasattr(data_col, "values"):
            new_dict = {}
            for key in data_keys:
                item1 = input_data.pop(key)
                new_dict[key] = item1.values
            input_data = new_dict

        # Convert the data in the dict into one tuple per column, recording the keys as column names
        data = []
        self.column_list = []
        for key in input_data.keys():
            self.column_list.append(key)
            value = input_data[key]
            data.append(tuple(value))
        return data
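
# For intuition, a sketch of what the helper above yields (hypothetical direct
# use of the internal class; in practice it is only constructed by
# NumpySlicesDataset below). Column-major tuple input is sliced into per-index
# rows, and dict keys become the column names:
#
#     helper = _NumpySlicesDataset(((1, 2), (3, 4)))         # two columns -> two rows
#     assert len(helper) == 2
#     assert helper.column_list == ["column_0", "column_1"]  # generated default names
#     print(helper[0])                                        # first row: (1, 3)
#     dict_helper = _NumpySlicesDataset({"a": [1, 2], "b": [3, 4]})
#     assert dict_helper.column_list == ["a", "b"]            # keys become column names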


class NumpySlicesDataset(GeneratorDataset):
    """
    Create a dataset with given data slices, mainly for loading Python data into a dataset.

    This dataset can take in a sampler. 'sampler' and 'shuffle' are mutually exclusive. The table
    below shows what input arguments are allowed and their expected behavior.
    .. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
       :widths: 25 25 50
       :header-rows: 1

       * - Parameter 'sampler'
         - Parameter 'shuffle'
         - Expected Order Behavior
       * - None
         - None
         - random order
       * - None
         - True
         - random order
       * - None
         - False
         - sequential order
       * - Sampler object
         - None
         - order defined by sampler
       * - Sampler object
         - True
         - not allowed
       * - Sampler object
         - False
         - not allowed

    Args:
        data (list, tuple or dict): Input of given data, supported data types include list, tuple, dict and other
            NumPy formats. Input data will be sliced along the first dimension to generate many rows; loading large
            data this way is not recommended since all of it is loaded into memory.
        column_names (list[str], optional): List of column names of the dataset (default=None). If column_names is
            not provided, then when data is a dict its keys are used as the column names, otherwise names like
            column_0, column_1 ... are generated.
        num_samples (int, optional): The number of samples to be included in the dataset (default=None, all samples).
        num_parallel_workers (int, optional): Number of subprocesses used to fetch the dataset in parallel (default=1).
        shuffle (bool, optional): Whether or not to perform shuffle on the dataset. Random accessible input is
            required (default=None, expected order behavior shown in the table).
        sampler (Sampler/Iterable, optional): Object used to choose samples from the dataset. Random accessible input
            is required (default=None, expected order behavior shown in the table).
        num_shards (int, optional): Number of shards that the dataset should be divided into (default=None).
            This argument should be specified only when num_samples is None. Random accessible input is required.
        shard_id (int, optional): The shard ID within num_shards (default=None). This argument should be specified
            only when num_shards is also specified. Random accessible input is required.

    Examples:
        >>> import mindspore.dataset as ds
        >>> # 1) Input data can be a list
        >>> data = [1, 2, 3]
        >>> dataset1 = ds.NumpySlicesDataset(data, column_names=["column_1"])
        >>> # 2) Input data can be a dict, and column_names will be its keys
        >>> data = {"a": [1, 2], "b": [3, 4]}
        >>> dataset2 = ds.NumpySlicesDataset(data)
        >>> # 3) Input data can be a tuple (or list of tuples), where each tuple element holds the data of one column
        >>> data = ((1, 2), (3, 4), (5, 6))
        >>> dataset3 = ds.NumpySlicesDataset(data, column_names=["column_1", "column_2", "column_3"])
        >>> # 4) Load data from a csv file
        >>> import pandas as pd
        >>> df = pd.read_csv("file.csv")
        >>> dataset4 = ds.NumpySlicesDataset(dict(df), shuffle=False)
    """

    @check_numpyslicesdataset
    def __init__(self, data, column_names=None, num_samples=None, num_parallel_workers=1, shuffle=None,
                 sampler=None, num_shards=None, shard_id=None):
        dataset = _NumpySlicesDataset(data, column_names)
        super().__init__(dataset, column_names=dataset.column_list, num_samples=num_samples,
                         num_parallel_workers=num_parallel_workers, shuffle=shuffle, sampler=sampler,
                         num_shards=num_shards, shard_id=shard_id)
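
A short usage sketch building on the docstring examples above (a sketch, not part of the commit; create_dict_iterator is the same iteration API the unit tests below rely on, and the returned values are NumPy arrays):

    import mindspore.dataset as ds

    # Dict input: keys become column names, rows are slices along the first dimension
    data = {"a": [1, 2], "b": [3, 4]}
    dataset = ds.NumpySlicesDataset(data, shuffle=False)
    for row in dataset.create_dict_iterator():
        print(row["a"], row["b"])   # prints 1 3, then 2 4

Note that per the table above, passing both a sampler and a shuffle flag is not allowed; pick one.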

View File

@@ -1356,3 +1356,48 @@ def check_gnn_get_node_feature(method):
        return method(*args, **kwargs)

    return new_method


def check_numpyslicesdataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset (NumpySlicesDataset)."""
    @wraps(method)
    def new_method(*args, **kwargs):
        param_dict = make_param_dict(method, args, kwargs)

        # check data; required argument
        data = param_dict.get('data')
        if not isinstance(data, (list, tuple, dict, np.ndarray)):
            raise TypeError("Unsupported data type: {}, only common Python data types such as list, tuple, dict "
                            "and numpy array are supported.".format(type(data)))
        if not data:
            raise ValueError("Input data is empty.")

        # check column_names
        column_names = param_dict.get('column_names')
        if column_names is not None:
            check_columns(column_names, "column_names")

            # check that the number of names in column_names matches the number of data columns
            column_num = 1 if isinstance(column_names, str) else len(column_names)
            if isinstance(data, dict):
                data_column = len(list(data.keys()))
                if column_num != data_column:
                    raise ValueError("Num of columns is {0}, but {1} are required.".format(column_num, data_column))
            # Consider input being a tuple of dicts
            elif isinstance(data[0], dict):
                data_column = sum(len(list(data[i].keys())) for i in range(len(data)))
                if column_num != data_column:
                    raise ValueError("Num of columns is {0}, but {1} are required.".format(column_num, data_column))
            elif isinstance(data[0], tuple) or isinstance(data, tuple):
                if column_num != len(data):
                    raise ValueError("Num of columns is {0}, but {1} are required.".format(column_num, len(data)))
            else:
                if column_num != 1:
                    raise ValueError("Num of columns is {0}, but 1 is required as data is a list.".format(column_num))

        return method(*args, **kwargs)

    return new_method
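
A quick sketch of what the checker enforces (illustrative values, not part of the commit): a column_names list whose length disagrees with the number of data columns is rejected before the dataset is constructed.

    import mindspore.dataset as ds

    data = {"a": [1, 2], "b": [3, 4]}                           # two columns
    dataset = ds.NumpySlicesDataset(data, column_names=["a", "b"])  # OK: counts match
    try:
        ds.NumpySlicesDataset(data, column_names=["only_one"])  # 1 name vs 2 columns
    except ValueError as err:
        print(err)   # column-count mismatch raised by check_numpyslicesdataset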

View File

@@ -12,3 +12,4 @@ setuptools >= 40.8.0
matplotlib >= 3.1.3 # for ut test
opencv-python >= 4.2.0.32 # for ut test
sklearn >= 0.0 # for st test
+pandas >= 1.0.2 # for ut test

View File

@@ -0,0 +1,6 @@
age,sex,height,weight,slope,state,target
65,0,161,45,93,fixed,1
72,1,164,60,86,good,0
45,0,174,70,79,bad,1
73,1,173,65,70,good,1
55,1,182,80,104,good,0

View File

@@ -0,0 +1,201 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import pandas as pd

import mindspore.dataset as de
import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore import log as logger


def test_numpy_slices_list_1():
    logger.info("Test slicing a 1D list.")
    np_data = [1, 2, 3]
    ds = de.NumpySlicesDataset(np_data, shuffle=False)
    for i, data in enumerate(ds):
        assert data[0] == np_data[i]


def test_numpy_slices_list_2():
    logger.info("Test slicing a 2D list into 1D lists.")
    np_data = [[1, 2], [3, 4]]
    ds = de.NumpySlicesDataset(np_data, column_names=["col1"], shuffle=False)
    for i, data in enumerate(ds):
        assert np.equal(data[0], np_data[i]).all()


def test_numpy_slices_list_3():
    logger.info("Test slicing a list in the first dimension.")
    np_data = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
    ds = de.NumpySlicesDataset(np_data, column_names=["col1"], shuffle=False)
    for i, data in enumerate(ds):
        assert np.equal(data[0], np_data[i]).all()


def test_numpy_slices_list_append():
    logger.info("Test reading data from an image list.")
    DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
    resize_height, resize_width = 2, 2

    data1 = de.TFRecordDataset(DATA_DIR)
    resize_op = vision.Resize((resize_height, resize_width))
    data1 = data1.map(input_columns=["image"], operations=[vision.Decode(True), resize_op])

    res = []
    for data in data1.create_dict_iterator():
        res.append(data["image"])

    ds = de.NumpySlicesDataset(res, column_names=["col1"], shuffle=False)
    for i, data in enumerate(ds):
        assert np.equal(data, res[i]).all()
def test_numpy_slices_dict_1():
    logger.info("Test dictionary-structured data.")
    np_data = {"a": [1, 2], "b": [3, 4]}
    ds = de.NumpySlicesDataset(np_data, shuffle=False)
    res = [[1, 3], [2, 4]]
    for i, data in enumerate(ds):
        assert data[0] == res[i][0]
        assert data[1] == res[i][1]


def test_numpy_slices_dict_2():
    logger.info("Test input data that is a tuple of dictionaries.")
    data1, data2 = {"a": [1, 2]}, {"b": [3, 4]}
    ds = de.NumpySlicesDataset((data1, data2), column_names=["col1", "col2"], shuffle=False)
    res = [[1, 3], [2, 4]]
    for i, data in enumerate(ds):
        assert data[0] == res[i][0]
        assert data[1] == res[i][1]
def test_numpy_slices_tuple_1():
    logger.info("Test slicing a list of tuples.")
    np_data = [([1, 2], [3, 4]), ([11, 12], [13, 14]), ([21, 22], [23, 24])]
    # Each input tuple is one column, so the dataset has 3 columns and 2 rows
    res = [[[1, 2], [11, 12], [21, 22]], [[3, 4], [13, 14], [23, 24]]]
    ds = de.NumpySlicesDataset(np_data, shuffle=False)
    for i, data in enumerate(ds):
        assert np.equal(data[0], res[i][0]).all()
        assert np.equal(data[1], res[i][1]).all()
        assert np.equal(data[2], res[i][2]).all()

    assert sum([1 for _ in ds]) == 2


def test_numpy_slices_tuple_2():
    logger.info("Test reading tuple data of different dimensions.")
    features, labels = np.random.sample((5, 2)), np.random.sample((5, 1))
    data = (features, labels)
    ds = de.NumpySlicesDataset(data, column_names=["col1", "col2"], shuffle=False)
    for i, data in enumerate(ds):
        assert np.equal(data[0], features[i]).all()
        assert data[1] == labels[i]
def test_numpy_slices_csv_value():
    logger.info("Test loading the values of a csv file.")
    csv_file = "../data/dataset/testNumpySlicesDataset/heart.csv"
    df = pd.read_csv(csv_file)
    target = df.pop("target")
    df.pop("state")
    np_data = (df.values, target.values)

    ds = de.NumpySlicesDataset(np_data, column_names=["col1", "col2"], shuffle=False)
    for i, data in enumerate(ds):
        assert np.equal(np_data[0][i], data[0]).all()
        assert np.equal(np_data[1][i], data[1]).all()


def test_numpy_slices_csv_dict():
    logger.info("Test loading a csv file as a dict.")
    csv_file = "../data/dataset/testNumpySlicesDataset/heart.csv"
    df = pd.read_csv(csv_file)
    df.pop("state")
    res = df.values

    ds = de.NumpySlicesDataset(dict(df), shuffle=False)
    for i, data in enumerate(ds):
        assert np.equal(data, res[i]).all()
def test_numpy_slices_num_samples():
    logger.info("Test num_samples.")
    np_data = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]
    ds = de.NumpySlicesDataset(np_data, shuffle=False, num_samples=2)
    for i, data in enumerate(ds):
        assert np.equal(data[0], np_data[i]).all()

    assert sum([1 for _ in ds]) == 2


def test_numpy_slices_distributed_sampler():
    logger.info("Test distributed sampler.")
    np_data = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]
    ds = de.NumpySlicesDataset(np_data, shuffle=False, shard_id=0, num_shards=4)
    for i, data in enumerate(ds):
        assert np.equal(data[0], np_data[i * 4]).all()

    assert sum([1 for _ in ds]) == 2


def test_numpy_slices_sequential_sampler():
    logger.info("Test NumpySlicesDataset with a SequentialSampler and repeat.")
    np_data = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]
    ds = de.NumpySlicesDataset(np_data, sampler=de.SequentialSampler()).repeat(2)
    for i, data in enumerate(ds):
        assert np.equal(data[0], np_data[i % 8]).all()


if __name__ == "__main__":
    test_numpy_slices_list_1()
    test_numpy_slices_list_2()
    test_numpy_slices_list_3()
    test_numpy_slices_list_append()
    test_numpy_slices_dict_1()
    test_numpy_slices_dict_2()
    test_numpy_slices_tuple_1()
    test_numpy_slices_tuple_2()
    test_numpy_slices_csv_value()
    test_numpy_slices_csv_dict()
    test_numpy_slices_num_samples()
    test_numpy_slices_distributed_sampler()
    test_numpy_slices_sequential_sampler()