forked from mindspore-Ecosystem/mindspore
Fix dataset serdes for MindDataset
This commit is contained in:
parent
c0c0b0985e
commit
ea297c0889
|
@ -127,9 +127,12 @@ def serialize_operations(node_repr, key, val):
|
|||
|
||||
def serialize_sampler(node_repr, val):
    """Serialize sampler object to dictionary.

    Args:
        node_repr (dict): serialized representation of the dataset node;
            gains a 'sampler' entry.
        val: the sampler instance attached to the node, or None when the
            pipeline has no sampler.

    Stores None under 'sampler' when there is no sampler; otherwise stores
    a copy of the sampler's attributes plus the module and class names
    needed by construct_sampler to rebuild it.
    """
    if val is None:
        node_repr['sampler'] = None
    else:
        # Copy the attribute dict: assigning val.__dict__ directly would
        # alias it, so the bookkeeping keys below would pollute the live
        # sampler object's own __dict__.
        node_repr['sampler'] = dict(val.__dict__)
        node_repr['sampler']['sampler_module'] = type(val).__module__
        node_repr['sampler']['sampler_name'] = type(val).__name__
|
||||
|
||||
|
||||
def traverse(node):
|
||||
|
@ -253,9 +256,10 @@ def create_node(node):
|
|||
node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id'))
|
||||
|
||||
elif dataset_op == 'MindDataset':
|
||||
pyobj = pyclass(node['dataset_file'], node.get('column_list'),
|
||||
sampler = construct_sampler(node.get('sampler'))
|
||||
pyobj = pyclass(node['dataset_file'], node.get('columns_list'),
|
||||
node.get('num_parallel_workers'), node.get('seed'), node.get('num_shards'),
|
||||
node.get('shard_id'), node.get('block_reader'))
|
||||
node.get('shard_id'), node.get('block_reader'), sampler)
|
||||
|
||||
elif dataset_op == 'TFRecordDataset':
|
||||
pyobj = pyclass(node['dataset_files'], node.get('schema'), node.get('column_list'),
|
||||
|
@ -341,24 +345,25 @@ def create_node(node):
|
|||
|
||||
def construct_sampler(in_sampler):
    """Instantiate Sampler object based on the information from dictionary['sampler'].

    Args:
        in_sampler (dict or None): serialized sampler produced by
            serialize_sampler; None means the node had no sampler attached.

    Returns:
        The reconstructed sampler instance, or None when in_sampler is None.

    Raises:
        ValueError: if the serialized sampler name is not a known sampler type.
    """
    sampler = None
    if in_sampler is not None:
        sampler_name = in_sampler['sampler_name']
        sampler_module = in_sampler['sampler_module']
        # Resolve the class by name in its defining module so each sampler
        # type only needs a dispatch branch below, not an explicit import.
        sampler_class = getattr(sys.modules[sampler_module], sampler_name)
        if sampler_name == 'DistributedSampler':
            sampler = sampler_class(in_sampler['num_shards'], in_sampler['shard_id'], in_sampler.get('shuffle'))
        elif sampler_name == 'PKSampler':
            # Fixed: was `in_sampler('shuffle')`, which called the dict and
            # raised TypeError instead of fetching the optional key.
            sampler = sampler_class(in_sampler['num_val'], in_sampler.get('num_class'), in_sampler.get('shuffle'))
        elif sampler_name == 'RandomSampler':
            sampler = sampler_class(in_sampler.get('replacement'), in_sampler.get('num_samples'))
        elif sampler_name == 'SequentialSampler':
            sampler = sampler_class()
        elif sampler_name == 'SubsetRandomSampler':
            sampler = sampler_class(in_sampler['indices'])
        elif sampler_name == 'WeightedRandomSampler':
            sampler = sampler_class(in_sampler['weights'], in_sampler['num_samples'], in_sampler.get('replacement'))
        else:
            raise ValueError("Sampler type is unknown: " + sampler_name)
    return sampler
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ import filecmp
|
|||
import glob
|
||||
import json
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
import mindspore.dataset as ds
|
||||
|
@ -28,7 +28,6 @@ import mindspore.dataset.transforms.vision.c_transforms as vision
|
|||
from mindspore.dataset.transforms.vision import Inter
|
||||
from mindspore import log as logger
|
||||
|
||||
|
||||
def test_imagefolder(remove_json_files=True):
|
||||
"""
|
||||
Test simulating resnet50 dataset pipeline.
|
||||
|
@ -217,6 +216,38 @@ def delete_json_files():
|
|||
except IOError:
|
||||
logger.info("Error while deleting: {}".format(f))
|
||||
|
||||
# Test save load minddataset
|
||||
from test_minddataset_sampler import add_and_remove_cv_file, get_data, CV_DIR_NAME, CV_FILE_NAME, FILES_NUM, \
|
||||
FileWriter, Inter
|
||||
|
||||
def test_minddataset(add_and_remove_cv_file):
    """Round-trip a MindDataset pipeline (with a SubsetRandomSampler)
    through serialize/deserialize and check the result is stable."""
    sampler = ds.SubsetRandomSampler([1, 2, 3, 5, 7])
    data_set = ds.MindDataset(CV_FILE_NAME + "0", ["data", "file_name", "label"], 4,
                              sampler=sampler)

    # Serialize the pipeline into a python dictionary, then into JSON text.
    first_repr = ds.serialize(data_set)
    first_json = json.dumps(first_repr, sort_keys=True)

    # Rebuild the pipeline from its serialized form and serialize it again;
    # the two JSON renderings must match exactly.
    data_set = ds.deserialize(input_dict=first_repr)
    second_json = json.dumps(ds.serialize(data_set), sort_keys=True)

    assert first_json == second_json

    data = get_data(CV_DIR_NAME)
    assert data_set.get_dataset_size() == 10
    num_iter = sum(1 for _ in data_set.create_dict_iterator())
    assert num_iter == 5
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Reference in New Issue