diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc index 076f2ecc364..6bacd673961 100644 --- a/mindspore/ccsrc/dataset/api/python_bindings.cc +++ b/mindspore/ccsrc/dataset/api/python_bindings.cc @@ -53,6 +53,7 @@ #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" #include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h" #include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" +#include "dataset/engine/datasetops/source/sampler/python_sampler.h" #include "dataset/engine/datasetops/source/tf_reader_op.h" #include "dataset/engine/jagged_connector.h" #include "dataset/kernels/data/to_float16_op.h" @@ -415,6 +416,7 @@ void bindSamplerOps(py::module *m) { (void)py::class_>(*m, "SequentialSampler") .def(py::init<>()); + (void)py::class_>(*m, "SubsetRandomSampler") .def(py::init>(), py::arg("indices")); @@ -425,6 +427,9 @@ void bindSamplerOps(py::module *m) { (void)py::class_>(*m, "WeightedRandomSampler") .def(py::init, int64_t, bool>(), py::arg("weights"), py::arg("numSamples"), py::arg("replacement")); + + (void)py::class_>(*m, "PythonSampler") + .def(py::init(), py::arg("pySampler")); } void bindInfoObjects(py::module *m) { diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/CMakeLists.txt b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/CMakeLists.txt index 5d55c8276a5..b084e1c1254 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/CMakeLists.txt +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/CMakeLists.txt @@ -1,6 +1,7 @@ add_library(engine-datasetops-source-sampler OBJECT distributed_sampler.cc pk_sampler.cc + python_sampler.cc random_sampler.cc sampler.cc sequential_sampler.cc diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc new file mode 100644 index 00000000000..464717feb45 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc @@ -0,0 +1,83 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "dataset/engine/datasetops/source/sampler/python_sampler.h" + +#include + +namespace mindspore { +namespace dataset { + +PythonSampler::PythonSampler(py::object py_sampler_instance, int64_t samples_per_buffer) + : Sampler(samples_per_buffer), py_sampler_instance(py_sampler_instance), need_to_reset_(false) {} + +Status PythonSampler::GetNextBuffer(std::unique_ptr *out_buffer) { + if (need_to_reset_) { + (*out_buffer) = std::make_unique(0, DataBuffer::kDeBFlagEOE); + } else { + std::shared_ptr sample_ids; + { + py::gil_scoped_acquire gil_acquire; + (*out_buffer) = std::make_unique(0, DataBuffer::kDeBFlagNone); + if (Py_IsInitialized() == 0) { + return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); + } + try { + py::object py_ret = py_sampler_instance.attr("_get_indices")(); + py::array np_sample_ids = py_ret.cast(); + Tensor::CreateTensor(&sample_ids, np_sample_ids); // copy numpy to tensor + } catch (const py::error_already_set &e) { + return Status(StatusCode::kPyFuncException, e.what()); + } + } + TensorRow row(1, sample_ids); + (*out_buffer)->set_tensor_table(std::make_unique(1, row)); + need_to_reset_ = true; + } + return Status::OK(); +} + +Status PythonSampler::InitSampler() { + CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "ERROR num_rows_ should be greater than 0"); + { + py::gil_scoped_acquire gil_acquire; + if (Py_IsInitialized() == 0) { + return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); + } + try { + py_sampler_instance.attr("_handshake")(num_rows_, num_samples_); + } catch (const py::error_already_set &e) { + return Status(StatusCode::kPyFuncException, e.what()); + } + } + return Status::OK(); +} + +Status PythonSampler::Reset() { + CHECK_FAIL_RETURN_UNEXPECTED(need_to_reset_, "ERROR Reset() called not at end of an epoch"); + need_to_reset_ = false; + py::gil_scoped_acquire gil_acquire; + if (Py_IsInitialized() == 0) { + return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); + } + try { + py_sampler_instance.attr("reset")(); + } catch (const py::error_already_set &e) { + return Status(StatusCode::kPyFuncException, e.what()); + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.h new file mode 100644 index 00000000000..b8734fee6af --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.h @@ -0,0 +1,58 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PYTHON_SAMPLER_H_ +#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PYTHON_SAMPLER_H_ + +#include +#include + +#include "dataset/engine/datasetops/source/sampler/sampler.h" + +namespace mindspore { +namespace dataset { +class PythonSampler : public Sampler { + public: + // Constructor + // @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call + explicit PythonSampler(py::object py_sampler_instance, + int64_t samples_per_buffer = std::numeric_limits::max()); + + // Destructor. + ~PythonSampler() = default; + + // Initialize the sampler. + // @return Status + Status InitSampler() override; + + // for next epoch of sampleIds + // @return - The error code return + Status Reset() override; + + // Op calls this to get next Buffer that contains all the sampleIds + // @param std::unique_ptr pBuffer - Buffer to be returned to StorageOp + // @param int32_t workerId - not meant to be used + // @return - The error code return + Status GetNextBuffer(std::unique_ptr *out_buffer) override; + + private: + bool need_to_reset_; // Whether Reset() should be called before calling GetNextBuffer() + + py::object py_sampler_instance; // The handle to the py_sampler python object +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PYTHON_SAMPLER_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc index 3c3f5f48e8e..9fe752448ad 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc @@ -48,9 +48,6 @@ Status Sampler::GetAllIdsThenReset(py::array *data) { std::unique_ptr db; std::shared_ptr sample_ids; - // check samples_per_buffer is properly set and doesn't overflow - CHECK_FAIL_RETURN_UNEXPECTED(samples_per_buffer_ + 1 > 1, "samples_per_buffer invalid"); - // A call to derived class to get sample ids wrapped inside a buffer RETURN_IF_NOT_OK(GetNextBuffer(&db)); // Get the only tensor inside the buffer that contains the actual SampleIds for the entire epoch diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc index a3c4fe22561..6ed06b527fd 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc @@ -42,6 +42,7 @@ Status SequentialSampler::GetNextBuffer(std::unique_ptr *out_buffer) } Status SequentialSampler::InitSampler() { + num_samples_ = (num_samples_ <= 0) ? num_rows_ : num_samples_; // if num_samples < 0, try if num_rows is set CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && samples_per_buffer_ > 0, "Fail to init Sequential Sampler"); samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_; return Status::OK(); diff --git a/mindspore/dataset/__init__.py b/mindspore/dataset/__init__.py index 479c66045fe..bff23b7abfd 100644 --- a/mindspore/dataset/__init__.py +++ b/mindspore/dataset/__init__.py @@ -23,7 +23,7 @@ from .engine.datasets import StorageDataset, TFRecordDataset, ImageFolderDataset GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CelebADataset, Schema, \ Shuffle, zip from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \ - WeightedRandomSampler + WeightedRandomSampler, Sampler from .engine.serializer_deserializer import serialize, deserialize, show __all__ = ["config", "ImageFolderDatasetV2", "MnistDataset", "StorageDataset", diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 8de56a6dff2..71df50ac4a9 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -2032,7 +2032,7 @@ class GeneratorDataset(SourceDataset): if self.sampler is not None and hasattr(source, "__getitem__"): if isinstance(self.sampler, (samplers.SequentialSampler, samplers.DistributedSampler, samplers.RandomSampler, samplers.SubsetRandomSampler, - samplers.WeightedRandomSampler)): + samplers.WeightedRandomSampler, samplers.Sampler)): if num_samples is None: num_samples = len(source) sampler_instance = self.sampler.create() diff --git a/mindspore/dataset/engine/samplers.py b/mindspore/dataset/engine/samplers.py index 0bba559210c..421a03ab8de 100644 --- a/mindspore/dataset/engine/samplers.py +++ b/mindspore/dataset/engine/samplers.py @@ -16,11 +16,90 @@ Sampler module provides several samplers to generate sampling data from dataset. There are following samplers: DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, WeightedRandomSampler. +User can also define custom sampler by extending from Sampler class. """ import mindspore._c_dataengine as cde +import numpy as np -class DistributedSampler(): + +class Sampler: + """ + Base class for user defined sampler. + User defined sampler can be used with any existing dataset with sampler support. + + An required _iter_() method should by overridden by user for sample index generation. + An optional reset() method can be overridden for per repeat reset, + + dataset_size and num_samples will be set by dataset once a dataset iterator is created. + + Examples: + >>> import mindspore.dataset as ds + >>> + >>> class ReverseSampler(ds,Sampler): + >>> def __iter__(self): + >>> for i in range(self.dataset_size - 1, -1, -1): + >>> yield i + >>> + >>> ds = ds.ImageFolderDatasetV2(path, sampler=ReverseSampler()) + """ + + def __init__(self): + self.dataset_size = 0 + self.num_samples = 0 + + def __iter__(self): + """ + User defined iterator, must be overridden. + _handshake is guaranteed to be called prior to iterator construction + + """ + raise NotImplementedError + + def reset(self): + """ + Per repeat reset callback, override this method if necessary + """ + + # Initialization handshake callback + # Do not override this method! + def _handshake(self, ds_size, num_samples): + self.dataset_size = ds_size + self.num_samples = num_samples + + # Indices fetcher + # Do not override this method! + def _get_indices(self): + sampler_iter = iter(self) + ret = [] + for _ in range(self.num_samples): + try: + idx = next(sampler_iter) + ret.append(idx) + except StopIteration: + break + return np.array(ret) + + # Instance fetcher + # Do not override this method! + def create(self): + return cde.PythonSampler(self) + + +class BuiltinSampler: + """ + Base class for BuiltinSampler. + + User should not extend this class. + """ + def __init__(self): + pass + + def create(self): + pass + + +class DistributedSampler(BuiltinSampler): """ Sampler that access a shard of the dataset. @@ -65,7 +144,7 @@ class DistributedSampler(): return cde.DistributedSampler(self.num_shards, self.shard_id, self.shuffle, self.seed) -class PKSampler(): +class PKSampler(BuiltinSampler): """ Samples K elements for each P class in the dataset. @@ -106,7 +185,7 @@ class PKSampler(): return cde.PKSampler(self.num_val, self.shuffle) -class RandomSampler(): +class RandomSampler(BuiltinSampler): """ Samples the elements randomly. @@ -147,7 +226,7 @@ class RandomSampler(): return cde.RandomSampler(self.replacement, self.num_samples) -class SequentialSampler(): +class SequentialSampler(BuiltinSampler): """ Samples the dataset elements sequentially, same as not having a sampler. @@ -165,7 +244,7 @@ class SequentialSampler(): return cde.SequentialSampler() -class SubsetRandomSampler(): +class SubsetRandomSampler(BuiltinSampler): """ Samples the elements randomly from a sequence of indices. @@ -196,7 +275,8 @@ class SubsetRandomSampler(): def _create_for_minddataset(self): return cde.MindrecordSubsetRandomSampler(self.indices) -class WeightedRandomSampler(): + +class WeightedRandomSampler(BuiltinSampler): """ Samples the elements from [0, len(weights) - 1] randomly with the given weights (probabilities). diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index b74e913202f..ff56652bcbb 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -297,9 +297,7 @@ def check_sampler_shuffle_shard_options(param_dict): shuffle, sampler = param_dict.get('shuffle'), param_dict.get('sampler') num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id') - if sampler is not None and not isinstance(sampler, ( - samplers.DistributedSampler, samplers.PKSampler, samplers.RandomSampler, samplers.SequentialSampler, - samplers.SubsetRandomSampler, samplers.WeightedRandomSampler)): + if sampler is not None and not isinstance(sampler, (samplers.BuiltinSampler, samplers.Sampler)): raise ValueError("sampler is not a valid Sampler type.") if sampler is not None: @@ -579,11 +577,11 @@ def check_generatordataset(method): raise ValueError("PKSampler is not supported by GeneratorDataset") if not isinstance(sampler, (samplers.SequentialSampler, samplers.DistributedSampler, samplers.RandomSampler, samplers.SubsetRandomSampler, - samplers.WeightedRandomSampler)): + samplers.WeightedRandomSampler, samplers.Sampler)): try: iter(sampler) except TypeError: - raise TypeError("sampler should be either iterable or from dataset.samplers.py") + raise TypeError("sampler should be either iterable or from mindspore.dataset.samplers") return method(*args, **kwargs) diff --git a/tests/ut/python/dataset/test_sampler.py b/tests/ut/python/dataset/test_sampler.py index 7a58249f9c3..4efca6f8187 100644 --- a/tests/ut/python/dataset/test_sampler.py +++ b/tests/ut/python/dataset/test_sampler.py @@ -14,6 +14,7 @@ # ============================================================================== import mindspore.dataset as ds from mindspore import log as logger +import numpy as np # test5trainimgs.json contains 5 images whose un-decoded shape is [83554, 54214, 65512, 54214, 64631] @@ -107,8 +108,64 @@ def test_sampler_py_api(): sampler.get_indices() +def test_python_sampler(): + manifest_file = "../data/dataset/testManifestData/test5trainimgs.json" + map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4} + + class Sp1(ds.Sampler): + def __iter__(self): + return iter([i for i in range(self.dataset_size)]) + + class Sp2(ds.Sampler): + def __init__(self): + super(Sp2, self).__init__() + # at this stage, self.dataset_size and self.num_samples are not yet known + self.cnt = 0 + + def __iter__(self): # first epoch, all 0, second epoch all 1, third all 2 etc.. ... + return iter([self.cnt for i in range(self.num_samples)]) + + def reset(self): + self.cnt = (self.cnt + 1) % self.dataset_size + + def test_config(num_samples, num_repeats, sampler): + data1 = ds.ManifestDataset(manifest_file, num_samples=num_samples, sampler=sampler) + if num_repeats is not None: + data1 = data1.repeat(num_repeats) + res = [] + for item in data1.create_dict_iterator(): + logger.info("item[image].shape[0]: {}, item[label].item(): {}" + .format(item["image"].shape[0], item["label"].item())) + res.append(map[(item["image"].shape[0], item["label"].item())]) + # print(res) + return res + + def test_generator(): + class MySampler(ds.Sampler): + def __iter__(self): + for i in range(99, -1, -1): + yield i + + data1 = ds.GeneratorDataset([(np.array(i),) for i in range(100)], ["data"], sampler = MySampler()) + i = 99 + for data in data1: + assert data[0] == (np.array(i),) + i = i - 1 + + assert test_config(5, 2, Sp1()) == [0, 1, 2, 3, 4, 0, 1, 2, 3, 4] + assert test_config(2, 6, Sp2()) == [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 0, 0] + test_generator() + + sp1 = Sp1().create() + sp1.set_num_rows(5) + sp1.set_num_samples(5) + sp1.initialize() + assert list(sp1.get_indices()) == [0, 1, 2, 3, 4] + + if __name__ == '__main__': test_sequential_sampler(True) test_random_sampler(True) test_random_sampler_multi_iter(True) test_sampler_py_api() + test_python_sampler() \ No newline at end of file