diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.cc b/mindspore/ccsrc/dataset/api/de_pipeline.cc
index 1812c0421a7..cf7050450b3 100644
--- a/mindspore/ccsrc/dataset/api/de_pipeline.cc
+++ b/mindspore/ccsrc/dataset/api/de_pipeline.cc
@@ -391,6 +391,30 @@ Status DEPipeline::CheckMindRecordPartitionInfo(const py::dict &args, std::vecto
   return Status::OK();
 }
 
+Status DEPipeline::GetMindrecordSampler(const std::string &sampler_name, const py::dict &args,
+                                        std::shared_ptr<mindrecord::ShardOperator> *ptr) {
+  std::vector<int> indices;
+  for (auto &arg : args) {
+    std::string key = py::str(arg.first);
+    py::handle value = arg.second;
+    if (!value.is_none()) {
+      if (key == "indices") {
+        indices = ToIntVector(value);
+      } else {
+        std::string err_msg = "ERROR: parameter " + key + " is invalid.";
+        RETURN_STATUS_UNEXPECTED(err_msg);
+      }
+    }
+  }
+  if (sampler_name == "SubsetRandomSampler") {
+    *ptr = std::make_shared<mindrecord::ShardSample>(indices);
+  } else {
+    std::string err_msg = "ERROR: parameter sampler_name is invalid.";
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+  return Status::OK();
+}
+
 Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
   if (args["dataset_file"].is_none()) {
     std::string err_msg = "Error: at least one of dataset_files is missing";
@@ -422,6 +446,13 @@ Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr
         operators.push_back(std::make_shared<mindrecord::ShardShuffle>(seed));
+      } else if (key == "sampler_name") {
+        std::shared_ptr<mindrecord::ShardOperator> sample_op;
+        auto ret = GetMindrecordSampler(ToString(value), args["sampler_params"], &sample_op);
+        if (Status::OK() != ret) {
+          return ret;
+        }
+        operators.push_back(sample_op);
       }
     }
   }
diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.h b/mindspore/ccsrc/dataset/api/de_pipeline.h
index acffc390cc0..491a75390e0 100644
--- a/mindspore/ccsrc/dataset/api/de_pipeline.h
+++ b/mindspore/ccsrc/dataset/api/de_pipeline.h
@@ -145,6 +145,9 @@ class DEPipeline {
 
   Status ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
 
+  Status GetMindrecordSampler(const std::string &sampler_name, const py::dict &args,
+                              std::shared_ptr<mindrecord::ShardOperator> *ptr);
+
  private:
   // Execution tree that links the dataset operators.
   std::shared_ptr<ExecutionTree> tree_;
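A note on the Python-to-C++ handoff: the `args` dict that GetMindrecordSampler walks is the `sampler_params` entry that `MindDataset.get_args()` builds from the sampler's `__dict__` (see the datasets.py hunk below). A minimal sketch of that handoff, with illustrative index values:

```python
import mindspore.dataset as ds

sampler = ds.SubsetRandomSampler([1, 2, 3, 5, 7])

# What MindDataset.get_args() serializes for the C++ parser:
args = {
    "sampler_name": sampler.__class__.__name__,  # "SubsetRandomSampler"
    "sampler_params": sampler.__dict__,          # e.g. {"indices": [1, 2, 3, 5, 7]}
}
```

Any key in `sampler_params` other than `indices` makes GetMindrecordSampler fail with an unexpected-parameter error, so the sampler's attribute set is effectively part of this contract.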
diff --git a/mindspore/ccsrc/mindrecord/include/common/shard_utils.h b/mindspore/ccsrc/mindrecord/include/common/shard_utils.h
index 55319cabfe5..e18cbb75b98 100644
--- a/mindspore/ccsrc/mindrecord/include/common/shard_utils.h
+++ b/mindspore/ccsrc/mindrecord/include/common/shard_utils.h
@@ -68,6 +68,8 @@ enum ShardType {
   kCV = 1,
 };
 
+enum SamplerType { kCustomTopNSampler, kCustomTopPercentSampler, kSubsetRandomSampler, kPKSampler };
+
 const double kEpsilon = 1e-7;
 
 const int kThreadNumber = 14;
diff --git a/mindspore/ccsrc/mindrecord/include/shard_sample.h b/mindspore/ccsrc/mindrecord/include/shard_sample.h
index f6b074a65d2..aeb3374f281 100644
--- a/mindspore/ccsrc/mindrecord/include/shard_sample.h
+++ b/mindspore/ccsrc/mindrecord/include/shard_sample.h
@@ -17,7 +17,9 @@
 #ifndef MINDRECORD_INCLUDE_SHARD_SAMPLE_H_
 #define MINDRECORD_INCLUDE_SHARD_SAMPLE_H_
 
+#include <memory>
 #include <utility>
+#include <vector>
 #include "mindrecord/include/shard_operator.h"
 
 namespace mindspore {
@@ -30,6 +32,8 @@ class ShardSample : public ShardOperator {
 
   ShardSample(int num, int den, int par);
 
+  explicit ShardSample(const std::vector<int> &indices);
+
   ~ShardSample() override{};
 
   const std::pair<int, int> get_partitions() const;
@@ -41,6 +45,8 @@ class ShardSample : public ShardOperator {
   int denominator_;
   int no_of_samples_;
   int partition_id_;
+  std::vector<int> indices_;
+  SamplerType sampler_type_;
 };
 }  // namespace mindrecord
 }  // namespace mindspore
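Taken together, the enum and the widened field set make ShardSample a small tagged union: `sampler_type_` decides which fields are meaningful. A rough Python sketch of the selection semantics each tag requests (illustrative only; it ignores the sharding and category-rounding details of the real operator(), shown in the next file):

```python
from enum import Enum

class SamplerType(Enum):
    CUSTOM_TOP_N = 0        # ShardSample(n): take the first n tasks
    CUSTOM_TOP_PERCENT = 1  # ShardSample(num, den[, par]): take num/den of the tasks
    SUBSET_RANDOM = 2       # ShardSample(indices): take tasks by index
    PK = 3                  # kPKSampler: reserved, no constructor in this patch

def select(tasks, sampler_type, *, n=0, num=0, den=1, indices=()):
    """Sketch of the dispatch that ShardSample::operator() performs."""
    total = len(tasks)
    if sampler_type is SamplerType.CUSTOM_TOP_N:
        return tasks[:min(n, total)]
    if sampler_type is SamplerType.SUBSET_RANDOM:
        return [tasks[i % total] for i in indices]
    return tasks[:total * num // den]  # crude stand-in for the top-percent math
```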
diff --git a/mindspore/ccsrc/mindrecord/meta/shard_sample.cc b/mindspore/ccsrc/mindrecord/meta/shard_sample.cc
index ea365a0e2ac..367c7a5cf9d 100644
--- a/mindspore/ccsrc/mindrecord/meta/shard_sample.cc
+++ b/mindspore/ccsrc/mindrecord/meta/shard_sample.cc
@@ -22,33 +22,37 @@ using mindspore::MsLogLevel::ERROR;
 
 namespace mindspore {
 namespace mindrecord {
-ShardSample::ShardSample(int n) {
-  numerator_ = 0;
-  denominator_ = 0;
-  no_of_samples_ = n;
-  partition_id_ = 0;
-}
+ShardSample::ShardSample(int n)
+    : numerator_(0),
+      denominator_(0),
+      no_of_samples_(n),
+      partition_id_(0),
+      indices_({}),
+      sampler_type_(kCustomTopNSampler) {}
 
-ShardSample::ShardSample(int num, int den) {
-  if (num < 0 || den <= 0 || num > den) {
-    no_of_samples_ = 5;
-    numerator_ = 0;
-    denominator_ = 0;
-    partition_id_ = 0;
-    return;
-  }
-  numerator_ = num;
-  denominator_ = den;
-  no_of_samples_ = 0;
-  partition_id_ = 0;
-}
+ShardSample::ShardSample(int num, int den)
+    : numerator_(num),
+      denominator_(den),
+      no_of_samples_(0),
+      partition_id_(0),
+      indices_({}),
+      sampler_type_(kCustomTopPercentSampler) {}
 
-ShardSample::ShardSample(int num, int den, int par) {
-  numerator_ = num;
-  denominator_ = den;
-  no_of_samples_ = 0;
-  partition_id_ = par;
-}
+ShardSample::ShardSample(int num, int den, int par)
+    : numerator_(num),
+      denominator_(den),
+      no_of_samples_(0),
+      partition_id_(par),
+      indices_({}),
+      sampler_type_(kCustomTopPercentSampler) {}
+
+ShardSample::ShardSample(const std::vector<int> &indices)
+    : numerator_(0),
+      denominator_(0),
+      no_of_samples_(0),
+      partition_id_(0),
+      indices_(indices),
+      sampler_type_(kSubsetRandomSampler) {}
 
 const std::pair<int, int> ShardSample::get_partitions() const {
   if (numerator_ == 1 && denominator_ > 1) {
@@ -62,10 +66,15 @@ MSRStatus ShardSample::operator()(ShardTask &tasks) {
   int total_no = static_cast<int>(tasks.Size());
 
   int taking = 0;
-  if (no_of_samples_ > 0) {  // non sharding case constructor #1
+  if (sampler_type_ == kCustomTopNSampler) {  // non sharding case constructor #1
     no_of_samples_ = std::min(no_of_samples_, total_no);
     taking = no_of_samples_ - no_of_samples_ % no_of_categories;
-  } else {  // constructor #2 & #3
+  } else if (sampler_type_ == kSubsetRandomSampler) {
+    if (static_cast<int>(indices_.size()) > total_no) {
+      MS_LOG(ERROR) << "parameter indices' size is greater than the dataset size.";
+      return FAILED;
+    }
+  } else {  // constructor TopPercent
     if (numerator_ > 0 && denominator_ > 0 && numerator_ <= denominator_) {
       if (numerator_ == 1 && denominator_ > 1) {  // sharding
         taking = (total_no / denominator_) + (total_no % denominator_ == 0 ? 0 : 1);
@@ -82,8 +91,15 @@ MSRStatus ShardSample::operator()(ShardTask &tasks) {
   if (tasks.permutation_.empty()) {
     ShardTask new_tasks;
     total_no = static_cast<int>(tasks.Size());
-    for (int i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) {
-      new_tasks.InsertTask(tasks.get_task_by_id(i % total_no));  // rounding up. if overflow, go back to start
+    if (sampler_type_ == kSubsetRandomSampler) {
+      for (size_t i = 0; i < indices_.size(); ++i) {
+        int index = ((indices_[i] % total_no) + total_no) % total_no;
+        new_tasks.InsertTask(tasks.get_task_by_id(index));  // C++ and Python compute % of negatives differently
+      }
+    } else {
+      for (int i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) {
+        new_tasks.InsertTask(tasks.get_task_by_id(i % total_no));  // rounding up. if overflow, go back to start
+      }
     }
     std::swap(tasks, new_tasks);
   } else {
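The double modulo in the kSubsetRandomSampler branch is worth a second look: C++'s % truncates toward zero, so a negative index such as -1 gives -1, while Python's % always returns a value in [0, total_no) for a positive modulus. Adding total_no and reducing again makes the C++ result match the Python-side expectations in the tests further below. A quick self-check of the normalization, assuming a hypothetical dataset of 10 tasks:

```python
def cpp_mod(a, n):
    # C++ % truncates toward zero: -1 % 10 == -1 in C++, unlike Python's 9.
    r = abs(a) % n
    return -r if a < 0 else r

total_no = 10
for idx in [1, -1, -2, 11, 13]:
    normalized = (cpp_mod(idx, total_no) + total_no) % total_no
    assert normalized == idx % total_no  # agrees with Python semantics
```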
diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py
index 2058bbf8264..3d660d58a8d 100644
--- a/mindspore/dataset/engine/datasets.py
+++ b/mindspore/dataset/engine/datasets.py
@@ -1363,7 +1363,6 @@ def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
 
     return samplers.SequentialSampler()
 
-
 class ImageFolderDatasetV2(SourceDataset):
     """
     A source dataset that reads images from a tree of directories.
@@ -1621,6 +1620,9 @@ class MindDataset(SourceDataset):
        shard_id (int, optional): The shard ID within num_shards (default=None). This
            argument should be specified only when num_shards is also specified.
        block_reader (bool, optional): Whether read data by block mode (default=False).
+       sampler (Sampler, optional): Object used to choose samples from the
+           dataset (default=None; sampler is mutually exclusive with shuffle
+           and block_reader). Supported samplers: SubsetRandomSampler.
 
     Raises:
         ValueError: If num_shards is specified but shard_id is None.
@@ -1630,14 +1632,16 @@ class MindDataset(SourceDataset):
 
     @check_minddataset
     def __init__(self, dataset_file, columns_list=None, num_parallel_workers=None,
-                 shuffle=None, num_shards=None, shard_id=None, block_reader=False):
+                 shuffle=None, num_shards=None, shard_id=None,
+                 block_reader=False, sampler=None):
         super().__init__(num_parallel_workers)
         self.dataset_file = dataset_file
         self.columns_list = columns_list
-        self.global_shuffle = not bool(shuffle is False)
+        self.global_shuffle = shuffle
         self.distribution = ""
+        self.sampler = sampler
 
-        if num_shards is None:
+        if num_shards is None or shard_id is None:
             self.partitions = None
         else:
             self.partitions = [num_shards, shard_id]
@@ -1645,9 +1649,25 @@ class MindDataset(SourceDataset):
         if block_reader is True and self.partitions is not None:
             raise ValueError("block reader not allowed true when use partitions")
 
+        if block_reader is True and shuffle is True:
+            raise ValueError("block reader not allowed true when use shuffle")
+
         if block_reader is True:
             logger.warning("WARN: global shuffle is not used.")
 
+        if sampler is not None and isinstance(sampler, samplers.SubsetRandomSampler) is False:
+            raise ValueError("the sampler is not supported yet.")
+
+        # sampler is mutually exclusive with block_reader and shuffle
+        if block_reader is True and sampler is not None:
+            raise ValueError("block reader not allowed true when use sampler")
+
+        if shuffle is True and sampler is not None:
+            raise ValueError("shuffle not allowed true when use sampler")
+
+        if block_reader is False and sampler is None:
+            self.global_shuffle = not bool(shuffle is False)
+
         self.num_shards = num_shards
         self.shard_id = shard_id
         self.block_reader = block_reader
@@ -1661,6 +1681,9 @@ class MindDataset(SourceDataset):
         args["block_reader"] = self.block_reader
         args["num_shards"] = self.num_shards
         args["shard_id"] = self.shard_id
+        if self.sampler:
+            args["sampler_name"] = self.sampler.__class__.__name__
+            args["sampler_params"] = self.sampler.__dict__
         return args
 
     def get_dataset_size(self):
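Putting the Python-side changes together, a minimal usage sketch (the .mindrecord path is a placeholder): sampler is mutually exclusive with shuffle and block_reader, and only SubsetRandomSampler is accepted. Note that, despite the sampler's name, this MindRecord path visits the given indices in order (ShardSample inserts tasks in indices order and no extra shuffle runs afterwards), which is what lets the tests below assert an exact output order.

```python
import mindspore.dataset as ds

sampler = ds.SubsetRandomSampler([1, 2, 3, 5, 7])
data_set = ds.MindDataset("/path/to/imagenet.mindrecord0",   # placeholder path
                          columns_list=["data", "file_name", "label"],
                          num_parallel_workers=4,
                          sampler=sampler)

# These combinations now raise ValueError:
#   ds.MindDataset(..., sampler=sampler, shuffle=True)
#   ds.MindDataset(..., sampler=sampler, block_reader=True)

for item in data_set.create_dict_iterator():
    print(item["file_name"])
```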
diff --git a/tests/ut/python/dataset/test_minddataset_sampler.py b/tests/ut/python/dataset/test_minddataset_sampler.py
new file mode 100644
index 00000000000..7662a0e3900
--- /dev/null
+++ b/tests/ut/python/dataset/test_minddataset_sampler.py
@@ -0,0 +1,222 @@
+# Copyright 2019 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+This is the test module for the mindrecord sampler
+"""
+import collections
+import json
+import os
+import re
+import string
+
+import mindspore.dataset.transforms.vision.c_transforms as vision
+import numpy as np
+import pytest
+from mindspore.dataset.transforms.vision import Inter
+from mindspore import log as logger
+
+import mindspore.dataset as ds
+from mindspore.mindrecord import FileWriter
+
+FILES_NUM = 4
+CV_FILE_NAME = "../data/mindrecord/imagenet.mindrecord"
+CV_DIR_NAME = "../data/mindrecord/testImageNetData"
+
+
+@pytest.fixture
+def add_and_remove_cv_file():
+    """add/remove cv file"""
+    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
+             for x in range(FILES_NUM)]
+    for x in paths:
+        if os.path.exists("{}".format(x)):
+            os.remove("{}".format(x))
+        if os.path.exists("{}.db".format(x)):
+            os.remove("{}.db".format(x))
+    writer = FileWriter(CV_FILE_NAME, FILES_NUM)
+    data = get_data(CV_DIR_NAME)
+    cv_schema_json = {"id": {"type": "int32"},
+                      "file_name": {"type": "string"},
+                      "label": {"type": "int32"},
+                      "data": {"type": "bytes"}}
+    writer.add_schema(cv_schema_json, "img_schema")
+    writer.add_index(["file_name", "label"])
+    writer.write_raw_data(data)
+    writer.commit()
+    yield "yield_cv_data"
+    for x in paths:
+        os.remove("{}".format(x))
+        os.remove("{}.db".format(x))
+
+
+def test_cv_minddataset_subset_random_sample_basic(add_and_remove_cv_file):
+    """tutorial for cv minddataset."""
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+    indices = [1, 2, 3, 5, 7]
+    sampler = ds.SubsetRandomSampler(indices)
+    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
+                              sampler=sampler)
+    data = get_data(CV_DIR_NAME)
+    assert data_set.get_dataset_size() == 10
+    num_iter = 0
+    for item in data_set.create_dict_iterator():
+        logger.info(
+            "-------------- cv reader basic: {} ------------------------".format(num_iter))
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        assert data[indices[num_iter]]['file_name'] == "".join(
+            [chr(x) for x in item['file_name']])
+        num_iter += 1
+    assert num_iter == 5
+
+
+def test_cv_minddataset_subset_random_sample_replica(add_and_remove_cv_file):
+    """tutorial for cv minddataset."""
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+    indices = [1, 2, 2, 5, 7, 9]
+    sampler = ds.SubsetRandomSampler(indices)
+    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
+                              sampler=sampler)
+    data = get_data(CV_DIR_NAME)
+    assert data_set.get_dataset_size() == 10
+    num_iter = 0
+    for item in data_set.create_dict_iterator():
+        logger.info(
+            "-------------- cv reader basic: {} ------------------------".format(num_iter))
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        assert data[indices[num_iter]]['file_name'] == "".join(
+            [chr(x) for x in item['file_name']])
+        num_iter += 1
+    assert num_iter == 6
+
+
+def test_cv_minddataset_subset_random_sample_empty(add_and_remove_cv_file):
+    """tutorial for cv minddataset."""
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+    indices = []
+    sampler = ds.SubsetRandomSampler(indices)
+    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
+                              sampler=sampler)
+    data = get_data(CV_DIR_NAME)
+    assert data_set.get_dataset_size() == 10
+    num_iter = 0
+    for item in data_set.create_dict_iterator():
+        logger.info(
+            "-------------- cv reader basic: {} ------------------------".format(num_iter))
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        assert data[indices[num_iter]]['file_name'] == "".join(
+            [chr(x) for x in item['file_name']])
+        num_iter += 1
+    assert num_iter == 0
+
+
+def test_cv_minddataset_subset_random_sample_out_range(add_and_remove_cv_file):
+    """tutorial for cv minddataset."""
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+    indices = [1, 2, 4, 11, 13]
+    sampler = ds.SubsetRandomSampler(indices)
+    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
+                              sampler=sampler)
+    data = get_data(CV_DIR_NAME)
+    assert data_set.get_dataset_size() == 10
+    num_iter = 0
+    for item in data_set.create_dict_iterator():
+        logger.info(
+            "-------------- cv reader basic: {} ------------------------".format(num_iter))
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        assert data[indices[num_iter] % len(data)]['file_name'] == "".join([
+            chr(x) for x in item['file_name']])
+        num_iter += 1
+    assert num_iter == 5
+
+
+def test_cv_minddataset_subset_random_sample_negative(add_and_remove_cv_file):
+    """tutorial for cv minddataset."""
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+    indices = [1, 2, 4, -1, -2]
+    sampler = ds.SubsetRandomSampler(indices)
+    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
+                              sampler=sampler)
+    data = get_data(CV_DIR_NAME)
+    assert data_set.get_dataset_size() == 10
+    num_iter = 0
+    for item in data_set.create_dict_iterator():
+        logger.info(
+            "-------------- cv reader basic: {} ------------------------".format(num_iter))
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        assert data[indices[num_iter] % len(data)]['file_name'] == "".join([
+            chr(x) for x in item['file_name']])
+        num_iter += 1
+    assert num_iter == 5
+
+
+def get_data(dir_name):
+    """
+    usage: get data from the imagenet dataset
+    params:
+    dir_name: directory containing an images folder and annotation information
+
+    """
+    if not os.path.isdir(dir_name):
+        raise IOError("Directory {} does not exist".format(dir_name))
+    img_dir = os.path.join(dir_name, "images")
+    ann_file = os.path.join(dir_name, "annotation.txt")
+    with open(ann_file, "r") as file_reader:
+        lines = file_reader.readlines()
+
+    data_list = []
+    for i, line in enumerate(lines):
+        try:
+            filename, label = line.split(",")
+            label = label.strip("\n")
+            with open(os.path.join(img_dir, filename), "rb") as file_reader:
+                img = file_reader.read()
+            data_json = {"id": i,
+                         "file_name": filename,
+                         "data": img,
+                         "label": int(label)}
+            data_list.append(data_json)
+        except FileNotFoundError:
+            continue
+    return data_list
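For reference, get_data assumes annotation.txt holds one `filename,label` pair per line; a tiny illustration of that parsing contract, with made-up lines:

```python
lines = ["image_00001.jpg,490\n", "image_00002.jpg,361\n"]  # hypothetical content
for i, line in enumerate(lines):
    filename, label = line.split(",")
    data_json = {"id": i, "file_name": filename, "label": int(label.strip("\n"))}
    print(data_json)
```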