forked from mindspore-Ecosystem/mindspore
!228 [MD] add subset random sampler in minddataset
Merge pull request !228 from liyong126/mindrecord_subsetrandom_sampler
This commit is contained in:
commit
d949c17a7e
|
@ -391,6 +391,30 @@ Status DEPipeline::CheckMindRecordPartitionInfo(const py::dict &args, std::vecto
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status DEPipeline::GetMindrecordSampler(const std::string &sampler_name, const py::dict &args,
|
||||||
|
std::shared_ptr<mindrecord::ShardOperator> *ptr) {
|
||||||
|
std::vector<int> indices;
|
||||||
|
for (auto &arg : args) {
|
||||||
|
std::string key = py::str(arg.first);
|
||||||
|
py::handle value = arg.second;
|
||||||
|
if (!value.is_none()) {
|
||||||
|
if (key == "indices") {
|
||||||
|
indices = ToIntVector(value);
|
||||||
|
} else {
|
||||||
|
std::string err_msg = "ERROR: parameter " + key + " is invalid.";
|
||||||
|
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (sampler_name == "SubsetRandomSampler") {
|
||||||
|
*ptr = std::make_shared<mindrecord::ShardSample>(indices);
|
||||||
|
} else {
|
||||||
|
std::string err_msg = "ERROR: parameter sampler_name is invalid.";
|
||||||
|
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
|
Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
|
||||||
if (args["dataset_file"].is_none()) {
|
if (args["dataset_file"].is_none()) {
|
||||||
std::string err_msg = "Error: at least one of dataset_files is missing";
|
std::string err_msg = "Error: at least one of dataset_files is missing";
|
||||||
|
@ -422,6 +446,13 @@ Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr<Datas
|
||||||
} else if (key == "global_shuffle" && ToBool(value) == true) {
|
} else if (key == "global_shuffle" && ToBool(value) == true) {
|
||||||
uint32_t seed = args["partitions"].is_none() ? GetSeed() : 0;
|
uint32_t seed = args["partitions"].is_none() ? GetSeed() : 0;
|
||||||
operators.push_back(std::make_shared<mindrecord::ShardShuffle>(seed));
|
operators.push_back(std::make_shared<mindrecord::ShardShuffle>(seed));
|
||||||
|
} else if (key == "sampler_name") {
|
||||||
|
std::shared_ptr<mindrecord::ShardOperator> sample_op;
|
||||||
|
auto ret = GetMindrecordSampler(ToString(value), args["sampler_params"], &sample_op);
|
||||||
|
if (Status::OK() != ret) {
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
operators.push_back(sample_op);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -145,6 +145,9 @@ class DEPipeline {
|
||||||
|
|
||||||
Status ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
|
Status ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
|
||||||
|
|
||||||
|
Status GetMindrecordSampler(const std::string &sampler_name, const py::dict &args,
|
||||||
|
std::shared_ptr<mindrecord::ShardOperator> *ptr);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Execution tree that links the dataset operators.
|
// Execution tree that links the dataset operators.
|
||||||
std::shared_ptr<ExecutionTree> tree_;
|
std::shared_ptr<ExecutionTree> tree_;
|
||||||
|
|
|
@ -68,6 +68,8 @@ enum ShardType {
|
||||||
kCV = 1,
|
kCV = 1,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
enum SamplerType { kCustomTopNSampler, kCustomTopPercentSampler, kSubsetRandomSampler, kPKSampler };
|
||||||
|
|
||||||
const double kEpsilon = 1e-7;
|
const double kEpsilon = 1e-7;
|
||||||
|
|
||||||
const int kThreadNumber = 14;
|
const int kThreadNumber = 14;
|
||||||
|
|
|
@ -17,7 +17,9 @@
|
||||||
#ifndef MINDRECORD_INCLUDE_SHARD_SAMPLE_H_
|
#ifndef MINDRECORD_INCLUDE_SHARD_SAMPLE_H_
|
||||||
#define MINDRECORD_INCLUDE_SHARD_SAMPLE_H_
|
#define MINDRECORD_INCLUDE_SHARD_SAMPLE_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
#include "mindrecord/include/shard_operator.h"
|
#include "mindrecord/include/shard_operator.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
@ -30,6 +32,8 @@ class ShardSample : public ShardOperator {
|
||||||
|
|
||||||
ShardSample(int num, int den, int par);
|
ShardSample(int num, int den, int par);
|
||||||
|
|
||||||
|
explicit ShardSample(const std::vector<int> &indices);
|
||||||
|
|
||||||
~ShardSample() override{};
|
~ShardSample() override{};
|
||||||
|
|
||||||
const std::pair<int, int> get_partitions() const;
|
const std::pair<int, int> get_partitions() const;
|
||||||
|
@ -41,6 +45,8 @@ class ShardSample : public ShardOperator {
|
||||||
int denominator_;
|
int denominator_;
|
||||||
int no_of_samples_;
|
int no_of_samples_;
|
||||||
int partition_id_;
|
int partition_id_;
|
||||||
|
std::vector<int> indices_;
|
||||||
|
SamplerType sampler_type_;
|
||||||
};
|
};
|
||||||
} // namespace mindrecord
|
} // namespace mindrecord
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -22,33 +22,37 @@ using mindspore::MsLogLevel::ERROR;
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace mindrecord {
|
namespace mindrecord {
|
||||||
ShardSample::ShardSample(int n) {
|
ShardSample::ShardSample(int n)
|
||||||
numerator_ = 0;
|
: numerator_(0),
|
||||||
denominator_ = 0;
|
denominator_(0),
|
||||||
no_of_samples_ = n;
|
no_of_samples_(n),
|
||||||
partition_id_ = 0;
|
partition_id_(0),
|
||||||
}
|
indices_({}),
|
||||||
|
sampler_type_(kCustomTopNSampler) {}
|
||||||
|
|
||||||
ShardSample::ShardSample(int num, int den) {
|
ShardSample::ShardSample(int num, int den)
|
||||||
if (num < 0 || den <= 0 || num > den) {
|
: numerator_(num),
|
||||||
no_of_samples_ = 5;
|
denominator_(den),
|
||||||
numerator_ = 0;
|
no_of_samples_(0),
|
||||||
denominator_ = 0;
|
partition_id_(0),
|
||||||
partition_id_ = 0;
|
indices_({}),
|
||||||
return;
|
sampler_type_(kCustomTopPercentSampler) {}
|
||||||
}
|
|
||||||
numerator_ = num;
|
|
||||||
denominator_ = den;
|
|
||||||
no_of_samples_ = 0;
|
|
||||||
partition_id_ = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
ShardSample::ShardSample(int num, int den, int par) {
|
ShardSample::ShardSample(int num, int den, int par)
|
||||||
numerator_ = num;
|
: numerator_(num),
|
||||||
denominator_ = den;
|
denominator_(den),
|
||||||
no_of_samples_ = 0;
|
no_of_samples_(0),
|
||||||
partition_id_ = par;
|
partition_id_(par),
|
||||||
}
|
indices_({}),
|
||||||
|
sampler_type_(kCustomTopPercentSampler) {}
|
||||||
|
|
||||||
|
ShardSample::ShardSample(const std::vector<int> &indices)
|
||||||
|
: numerator_(0),
|
||||||
|
denominator_(0),
|
||||||
|
no_of_samples_(0),
|
||||||
|
partition_id_(0),
|
||||||
|
indices_(indices),
|
||||||
|
sampler_type_(kSubsetRandomSampler) {}
|
||||||
|
|
||||||
const std::pair<int, int> ShardSample::get_partitions() const {
|
const std::pair<int, int> ShardSample::get_partitions() const {
|
||||||
if (numerator_ == 1 && denominator_ > 1) {
|
if (numerator_ == 1 && denominator_ > 1) {
|
||||||
|
@ -62,10 +66,15 @@ MSRStatus ShardSample::operator()(ShardTask &tasks) {
|
||||||
int total_no = static_cast<int>(tasks.Size());
|
int total_no = static_cast<int>(tasks.Size());
|
||||||
|
|
||||||
int taking = 0;
|
int taking = 0;
|
||||||
if (no_of_samples_ > 0) { // non sharding case constructor #1
|
if (sampler_type_ == kCustomTopNSampler) { // non sharding case constructor #1
|
||||||
no_of_samples_ = std::min(no_of_samples_, total_no);
|
no_of_samples_ = std::min(no_of_samples_, total_no);
|
||||||
taking = no_of_samples_ - no_of_samples_ % no_of_categories;
|
taking = no_of_samples_ - no_of_samples_ % no_of_categories;
|
||||||
} else { // constructor #2 & #3
|
} else if (sampler_type_ == kSubsetRandomSampler) {
|
||||||
|
if (indices_.size() > total_no) {
|
||||||
|
MS_LOG(ERROR) << "parameter indices's size is greater than dataset size.";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
} else { // constructor TopPercent
|
||||||
if (numerator_ > 0 && denominator_ > 0 && numerator_ <= denominator_) {
|
if (numerator_ > 0 && denominator_ > 0 && numerator_ <= denominator_) {
|
||||||
if (numerator_ == 1 && denominator_ > 1) { // sharding
|
if (numerator_ == 1 && denominator_ > 1) { // sharding
|
||||||
taking = (total_no / denominator_) + (total_no % denominator_ == 0 ? 0 : 1);
|
taking = (total_no / denominator_) + (total_no % denominator_ == 0 ? 0 : 1);
|
||||||
|
@ -82,9 +91,16 @@ MSRStatus ShardSample::operator()(ShardTask &tasks) {
|
||||||
if (tasks.permutation_.empty()) {
|
if (tasks.permutation_.empty()) {
|
||||||
ShardTask new_tasks;
|
ShardTask new_tasks;
|
||||||
total_no = static_cast<int>(tasks.Size());
|
total_no = static_cast<int>(tasks.Size());
|
||||||
|
if (sampler_type_ == kSubsetRandomSampler) {
|
||||||
|
for (int i = 0; i < indices_.size(); ++i) {
|
||||||
|
int index = ((indices_[i] % total_no) + total_no) % total_no;
|
||||||
|
new_tasks.InsertTask(tasks.get_task_by_id(index)); // different mod result between c and python
|
||||||
|
}
|
||||||
|
} else {
|
||||||
for (int i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) {
|
for (int i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) {
|
||||||
new_tasks.InsertTask(tasks.get_task_by_id(i % total_no)); // rounding up. if overflow, go back to start
|
new_tasks.InsertTask(tasks.get_task_by_id(i % total_no)); // rounding up. if overflow, go back to start
|
||||||
}
|
}
|
||||||
|
}
|
||||||
std::swap(tasks, new_tasks);
|
std::swap(tasks, new_tasks);
|
||||||
} else {
|
} else {
|
||||||
ShardTask new_tasks;
|
ShardTask new_tasks;
|
||||||
|
|
|
@ -1363,7 +1363,6 @@ def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
|
||||||
return samplers.SequentialSampler()
|
return samplers.SequentialSampler()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ImageFolderDatasetV2(SourceDataset):
|
class ImageFolderDatasetV2(SourceDataset):
|
||||||
"""
|
"""
|
||||||
A source dataset that reads images from a tree of directories.
|
A source dataset that reads images from a tree of directories.
|
||||||
|
@ -1621,6 +1620,9 @@ class MindDataset(SourceDataset):
|
||||||
shard_id (int, optional): The shard ID within num_shards (default=None). This
|
shard_id (int, optional): The shard ID within num_shards (default=None). This
|
||||||
argument should be specified only when num_shards is also specified.
|
argument should be specified only when num_shards is also specified.
|
||||||
block_reader (bool, optional): Whether read data by block mode (default=False).
|
block_reader (bool, optional): Whether read data by block mode (default=False).
|
||||||
|
sampler (Sampler, optional): Object used to choose samples from the
|
||||||
|
dataset (default=None, sampler is exclusive
|
||||||
|
with shuffle and block_reader). Support list: SubsetRandomSampler.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If num_shards is specified but shard_id is None.
|
ValueError: If num_shards is specified but shard_id is None.
|
||||||
|
@ -1630,14 +1632,16 @@ class MindDataset(SourceDataset):
|
||||||
|
|
||||||
@check_minddataset
|
@check_minddataset
|
||||||
def __init__(self, dataset_file, columns_list=None, num_parallel_workers=None,
|
def __init__(self, dataset_file, columns_list=None, num_parallel_workers=None,
|
||||||
shuffle=None, num_shards=None, shard_id=None, block_reader=False):
|
shuffle=None, num_shards=None, shard_id=None,
|
||||||
|
block_reader=False, sampler=None):
|
||||||
super().__init__(num_parallel_workers)
|
super().__init__(num_parallel_workers)
|
||||||
self.dataset_file = dataset_file
|
self.dataset_file = dataset_file
|
||||||
self.columns_list = columns_list
|
self.columns_list = columns_list
|
||||||
self.global_shuffle = not bool(shuffle is False)
|
self.global_shuffle = shuffle
|
||||||
self.distribution = ""
|
self.distribution = ""
|
||||||
|
self.sampler = sampler
|
||||||
|
|
||||||
if num_shards is None:
|
if num_shards is None or shard_id is None:
|
||||||
self.partitions = None
|
self.partitions = None
|
||||||
else:
|
else:
|
||||||
self.partitions = [num_shards, shard_id]
|
self.partitions = [num_shards, shard_id]
|
||||||
|
@ -1645,9 +1649,25 @@ class MindDataset(SourceDataset):
|
||||||
if block_reader is True and self.partitions is not None:
|
if block_reader is True and self.partitions is not None:
|
||||||
raise ValueError("block reader not allowed true when use partitions")
|
raise ValueError("block reader not allowed true when use partitions")
|
||||||
|
|
||||||
|
if block_reader is True and shuffle is True:
|
||||||
|
raise ValueError("block reader not allowed true when use shuffle")
|
||||||
|
|
||||||
if block_reader is True:
|
if block_reader is True:
|
||||||
logger.warning("WARN: global shuffle is not used.")
|
logger.warning("WARN: global shuffle is not used.")
|
||||||
|
|
||||||
|
if sampler is not None and isinstance(sampler, samplers.SubsetRandomSampler) is False:
|
||||||
|
raise ValueError("the sampler is not supported yet.")
|
||||||
|
|
||||||
|
# sampler exclusive
|
||||||
|
if block_reader is True and sampler is not None:
|
||||||
|
raise ValueError("block reader not allowed true when use sampler")
|
||||||
|
|
||||||
|
if shuffle is True and sampler is not None:
|
||||||
|
raise ValueError("shuffle not allowed true when use sampler")
|
||||||
|
|
||||||
|
if block_reader is False and sampler is None:
|
||||||
|
self.global_shuffle = not bool(shuffle is False)
|
||||||
|
|
||||||
self.num_shards = num_shards
|
self.num_shards = num_shards
|
||||||
self.shard_id = shard_id
|
self.shard_id = shard_id
|
||||||
self.block_reader = block_reader
|
self.block_reader = block_reader
|
||||||
|
@ -1661,6 +1681,9 @@ class MindDataset(SourceDataset):
|
||||||
args["block_reader"] = self.block_reader
|
args["block_reader"] = self.block_reader
|
||||||
args["num_shards"] = self.num_shards
|
args["num_shards"] = self.num_shards
|
||||||
args["shard_id"] = self.shard_id
|
args["shard_id"] = self.shard_id
|
||||||
|
if self.sampler:
|
||||||
|
args["sampler_name"] = self.sampler.__class__.__name__
|
||||||
|
args["sampler_params"] = self.sampler.__dict__
|
||||||
return args
|
return args
|
||||||
|
|
||||||
def get_dataset_size(self):
|
def get_dataset_size(self):
|
||||||
|
|
|
@ -0,0 +1,222 @@
|
||||||
|
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""
|
||||||
|
This is the test module for mindrecord
|
||||||
|
"""
|
||||||
|
import collections
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import string
|
||||||
|
|
||||||
|
import mindspore.dataset.transforms.vision.c_transforms as vision
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from mindspore.dataset.transforms.vision import Inter
|
||||||
|
from mindspore import log as logger
|
||||||
|
|
||||||
|
import mindspore.dataset as ds
|
||||||
|
from mindspore.mindrecord import FileWriter
|
||||||
|
|
||||||
|
FILES_NUM = 4
|
||||||
|
CV_FILE_NAME = "../data/mindrecord/imagenet.mindrecord"
|
||||||
|
CV_DIR_NAME = "../data/mindrecord/testImageNetData"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def add_and_remove_cv_file():
|
||||||
|
"""add/remove cv file"""
|
||||||
|
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
|
||||||
|
for x in range(FILES_NUM)]
|
||||||
|
for x in paths:
|
||||||
|
if os.path.exists("{}".format(x)):
|
||||||
|
os.remove("{}".format(x))
|
||||||
|
if os.path.exists("{}.db".format(x)):
|
||||||
|
os.remove("{}.db".format(x))
|
||||||
|
writer = FileWriter(CV_FILE_NAME, FILES_NUM)
|
||||||
|
data = get_data(CV_DIR_NAME)
|
||||||
|
cv_schema_json = {"id": {"type": "int32"},
|
||||||
|
"file_name": {"type": "string"},
|
||||||
|
"label": {"type": "int32"},
|
||||||
|
"data": {"type": "bytes"}}
|
||||||
|
writer.add_schema(cv_schema_json, "img_schema")
|
||||||
|
writer.add_index(["file_name", "label"])
|
||||||
|
writer.write_raw_data(data)
|
||||||
|
writer.commit()
|
||||||
|
yield "yield_cv_data"
|
||||||
|
for x in paths:
|
||||||
|
os.remove("{}".format(x))
|
||||||
|
os.remove("{}.db".format(x))
|
||||||
|
|
||||||
|
|
||||||
|
def test_cv_minddataset_subset_random_sample_basic(add_and_remove_cv_file):
|
||||||
|
"""tutorial for cv minderdataset."""
|
||||||
|
columns_list = ["data", "file_name", "label"]
|
||||||
|
num_readers = 4
|
||||||
|
indices = [1, 2, 3, 5, 7]
|
||||||
|
sampler = ds.SubsetRandomSampler(indices)
|
||||||
|
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
|
||||||
|
sampler=sampler)
|
||||||
|
data = get_data(CV_DIR_NAME)
|
||||||
|
assert data_set.get_dataset_size() == 10
|
||||||
|
num_iter = 0
|
||||||
|
for item in data_set.create_dict_iterator():
|
||||||
|
logger.info(
|
||||||
|
"-------------- cv reader basic: {} ------------------------".format(num_iter))
|
||||||
|
logger.info(
|
||||||
|
"-------------- item[data]: {} -----------------------------".format(item["data"]))
|
||||||
|
logger.info(
|
||||||
|
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
||||||
|
logger.info(
|
||||||
|
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||||
|
assert data[indices[num_iter]]['file_name'] == "".join(
|
||||||
|
[chr(x) for x in item['file_name']])
|
||||||
|
num_iter += 1
|
||||||
|
assert num_iter == 5
|
||||||
|
|
||||||
|
|
||||||
|
def test_cv_minddataset_subset_random_sample_replica(add_and_remove_cv_file):
|
||||||
|
"""tutorial for cv minderdataset."""
|
||||||
|
columns_list = ["data", "file_name", "label"]
|
||||||
|
num_readers = 4
|
||||||
|
indices = [1, 2, 2, 5, 7, 9]
|
||||||
|
sampler = ds.SubsetRandomSampler(indices)
|
||||||
|
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
|
||||||
|
sampler=sampler)
|
||||||
|
data = get_data(CV_DIR_NAME)
|
||||||
|
assert data_set.get_dataset_size() == 10
|
||||||
|
num_iter = 0
|
||||||
|
for item in data_set.create_dict_iterator():
|
||||||
|
logger.info(
|
||||||
|
"-------------- cv reader basic: {} ------------------------".format(num_iter))
|
||||||
|
logger.info(
|
||||||
|
"-------------- item[data]: {} -----------------------------".format(item["data"]))
|
||||||
|
logger.info(
|
||||||
|
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
||||||
|
logger.info(
|
||||||
|
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||||
|
assert data[indices[num_iter]]['file_name'] == "".join(
|
||||||
|
[chr(x) for x in item['file_name']])
|
||||||
|
num_iter += 1
|
||||||
|
assert num_iter == 6
|
||||||
|
|
||||||
|
|
||||||
|
def test_cv_minddataset_subset_random_sample_empty(add_and_remove_cv_file):
|
||||||
|
"""tutorial for cv minderdataset."""
|
||||||
|
columns_list = ["data", "file_name", "label"]
|
||||||
|
num_readers = 4
|
||||||
|
indices = []
|
||||||
|
sampler = ds.SubsetRandomSampler(indices)
|
||||||
|
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
|
||||||
|
sampler=sampler)
|
||||||
|
data = get_data(CV_DIR_NAME)
|
||||||
|
assert data_set.get_dataset_size() == 10
|
||||||
|
num_iter = 0
|
||||||
|
for item in data_set.create_dict_iterator():
|
||||||
|
logger.info(
|
||||||
|
"-------------- cv reader basic: {} ------------------------".format(num_iter))
|
||||||
|
logger.info(
|
||||||
|
"-------------- item[data]: {} -----------------------------".format(item["data"]))
|
||||||
|
logger.info(
|
||||||
|
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
||||||
|
logger.info(
|
||||||
|
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||||
|
assert data[indices[num_iter]]['file_name'] == "".join(
|
||||||
|
[chr(x) for x in item['file_name']])
|
||||||
|
num_iter += 1
|
||||||
|
assert num_iter == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_cv_minddataset_subset_random_sample_out_range(add_and_remove_cv_file):
|
||||||
|
"""tutorial for cv minderdataset."""
|
||||||
|
columns_list = ["data", "file_name", "label"]
|
||||||
|
num_readers = 4
|
||||||
|
indices = [1, 2, 4, 11, 13]
|
||||||
|
sampler = ds.SubsetRandomSampler(indices)
|
||||||
|
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
|
||||||
|
sampler=sampler)
|
||||||
|
data = get_data(CV_DIR_NAME)
|
||||||
|
assert data_set.get_dataset_size() == 10
|
||||||
|
num_iter = 0
|
||||||
|
for item in data_set.create_dict_iterator():
|
||||||
|
logger.info(
|
||||||
|
"-------------- cv reader basic: {} ------------------------".format(num_iter))
|
||||||
|
logger.info(
|
||||||
|
"-------------- item[data]: {} -----------------------------".format(item["data"]))
|
||||||
|
logger.info(
|
||||||
|
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
||||||
|
logger.info(
|
||||||
|
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||||
|
assert data[indices[num_iter] % len(data)]['file_name'] == "".join([
|
||||||
|
chr(x) for x in item['file_name']])
|
||||||
|
num_iter += 1
|
||||||
|
assert num_iter == 5
|
||||||
|
|
||||||
|
|
||||||
|
def test_cv_minddataset_subset_random_sample_negative(add_and_remove_cv_file):
|
||||||
|
"""tutorial for cv minderdataset."""
|
||||||
|
columns_list = ["data", "file_name", "label"]
|
||||||
|
num_readers = 4
|
||||||
|
indices = [1, 2, 4, -1, -2]
|
||||||
|
sampler = ds.SubsetRandomSampler(indices)
|
||||||
|
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
|
||||||
|
sampler=sampler)
|
||||||
|
data = get_data(CV_DIR_NAME)
|
||||||
|
assert data_set.get_dataset_size() == 10
|
||||||
|
num_iter = 0
|
||||||
|
for item in data_set.create_dict_iterator():
|
||||||
|
logger.info(
|
||||||
|
"-------------- cv reader basic: {} ------------------------".format(num_iter))
|
||||||
|
logger.info(
|
||||||
|
"-------------- item[data]: {} -----------------------------".format(item["data"]))
|
||||||
|
logger.info(
|
||||||
|
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
||||||
|
logger.info(
|
||||||
|
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||||
|
assert data[indices[num_iter] % len(data)]['file_name'] == "".join([
|
||||||
|
chr(x) for x in item['file_name']])
|
||||||
|
num_iter += 1
|
||||||
|
assert num_iter == 5
|
||||||
|
|
||||||
|
|
||||||
|
def get_data(dir_name):
|
||||||
|
"""
|
||||||
|
usage: get data from imagenet dataset
|
||||||
|
params:
|
||||||
|
dir_name: directory containing folder images and annotation information
|
||||||
|
|
||||||
|
"""
|
||||||
|
if not os.path.isdir(dir_name):
|
||||||
|
raise IOError("Directory {} not exists".format(dir_name))
|
||||||
|
img_dir = os.path.join(dir_name, "images")
|
||||||
|
ann_file = os.path.join(dir_name, "annotation.txt")
|
||||||
|
with open(ann_file, "r") as file_reader:
|
||||||
|
lines = file_reader.readlines()
|
||||||
|
|
||||||
|
data_list = []
|
||||||
|
for i, line in enumerate(lines):
|
||||||
|
try:
|
||||||
|
filename, label = line.split(",")
|
||||||
|
label = label.strip("\n")
|
||||||
|
with open(os.path.join(img_dir, filename), "rb") as file_reader:
|
||||||
|
img = file_reader.read()
|
||||||
|
data_json = {"id": i,
|
||||||
|
"file_name": filename,
|
||||||
|
"data": img,
|
||||||
|
"label": int(label)}
|
||||||
|
data_list.append(data_json)
|
||||||
|
except FileNotFoundError:
|
||||||
|
continue
|
||||||
|
return data_list
|
Loading…
Reference in New Issue