forked from mindspore-Ecosystem/mindspore
Add python sampler support for CPP dataset
This commit is contained in:
parent
3ad73b7d71
commit
43a2e99833
|
@ -53,6 +53,7 @@
|
|||
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/python_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/tf_reader_op.h"
|
||||
#include "dataset/engine/jagged_connector.h"
|
||||
#include "dataset/kernels/data/to_float16_op.h"
|
||||
|
@ -415,6 +416,7 @@ void bindSamplerOps(py::module *m) {
|
|||
|
||||
(void)py::class_<SequentialSampler, Sampler, std::shared_ptr<SequentialSampler>>(*m, "SequentialSampler")
|
||||
.def(py::init<>());
|
||||
|
||||
(void)py::class_<SubsetRandomSampler, Sampler, std::shared_ptr<SubsetRandomSampler>>(*m, "SubsetRandomSampler")
|
||||
.def(py::init<std::vector<int64_t>>(), py::arg("indices"));
|
||||
|
||||
|
@ -425,6 +427,9 @@ void bindSamplerOps(py::module *m) {
|
|||
(void)py::class_<WeightedRandomSampler, Sampler, std::shared_ptr<WeightedRandomSampler>>(*m, "WeightedRandomSampler")
|
||||
.def(py::init<std::vector<double>, int64_t, bool>(), py::arg("weights"), py::arg("numSamples"),
|
||||
py::arg("replacement"));
|
||||
|
||||
(void)py::class_<PythonSampler, Sampler, std::shared_ptr<PythonSampler>>(*m, "PythonSampler")
|
||||
.def(py::init<py::object>(), py::arg("pySampler"));
|
||||
}
|
||||
|
||||
void bindInfoObjects(py::module *m) {
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
add_library(engine-datasetops-source-sampler OBJECT
|
||||
distributed_sampler.cc
|
||||
pk_sampler.cc
|
||||
python_sampler.cc
|
||||
random_sampler.cc
|
||||
sampler.cc
|
||||
sequential_sampler.cc
|
||||
|
|
|
@ -0,0 +1,83 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "dataset/engine/datasetops/source/sampler/python_sampler.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
// Constructor
// @param py::object py_sampler_instance - the Python sampler object this class forwards calls to
// @param int64_t samples_per_buffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
// Members are initialized in the order they are declared in the class (need_to_reset_ first) so the
// initializer list matches the real initialization order and does not trigger -Wreorder.
PythonSampler::PythonSampler(py::object py_sampler_instance, int64_t samples_per_buffer)
    : Sampler(samples_per_buffer), need_to_reset_(false), py_sampler_instance(py_sampler_instance) {}
|
||||
|
||||
// Fetches the sample ids for one epoch from the Python sampler.
// While need_to_reset_ is false, the call asks the Python object for its indices (_get_indices()),
// copies them into a tensor, wraps that tensor as the single row of the returned buffer and sets
// need_to_reset_. While need_to_reset_ is true, the call returns an empty EOE buffer instead.
// @param std::unique_ptr<DataBuffer> *out_buffer - receives the buffer described above
// @return Status - kPythonInterpreterFailure if the interpreter is finalized, kPyFuncException if the
//     Python call raises or its result cannot be cast to a numpy array, or the error reported while
//     converting the numpy array to a tensor
Status PythonSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
  if (need_to_reset_) {
    (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
  } else {
    std::shared_ptr<Tensor> sample_ids;
    {
      // The GIL is held only for the part that touches Python objects.
      py::gil_scoped_acquire gil_acquire;
      (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagNone);
      if (Py_IsInitialized() == 0) {
        return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
      }
      try {
        py::object py_ret = py_sampler_instance.attr("_get_indices")();
        py::array np_sample_ids = py_ret.cast<py::array>();
        // Copy numpy to tensor; propagate a failed conversion instead of continuing with a null tensor.
        RETURN_IF_NOT_OK(Tensor::CreateTensor(&sample_ids, np_sample_ids));
      } catch (const py::error_already_set &e) {
        return Status(StatusCode::kPyFuncException, e.what());
      } catch (const py::cast_error &e) {
        // cast<py::array>() throws cast_error (not error_already_set) when the result is not array-like.
        return Status(StatusCode::kPyFuncException, e.what());
      }
    }
    TensorRow row(1, sample_ids);
    (*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row));
    need_to_reset_ = true;
  }
  return Status::OK();
}
|
||||
|
||||
// Initialize the sampler: checks that num_rows_ is positive, then passes num_rows_ and num_samples_
// to the Python sampler through its _handshake() method.
// @return Status - kPythonInterpreterFailure if the interpreter is finalized, kPyFuncException if the
//     Python call raises
Status PythonSampler::InitSampler() {
  CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "ERROR num_rows_ should be greater than 0");
  {
    // Hold the GIL only while talking to the Python object.
    py::gil_scoped_acquire gil_acquire;
    if (!Py_IsInitialized()) {
      return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
    }
    try {
      py::object handshake_fn = py_sampler_instance.attr("_handshake");
      handshake_fn(num_rows_, num_samples_);
    } catch (const py::error_already_set &py_err) {
      return Status(StatusCode::kPyFuncException, py_err.what());
    }
  }
  return Status::OK();
}
|
||||
|
||||
// Prepares the sampler for the next epoch. Fails unless GetNextBuffer() has already handed out the ids
// of the current epoch (need_to_reset_ is set); otherwise clears that flag and calls the Python
// sampler's reset() method.
// @return Status - kPythonInterpreterFailure if the interpreter is finalized, kPyFuncException if the
//     Python call raises
Status PythonSampler::Reset() {
  CHECK_FAIL_RETURN_UNEXPECTED(need_to_reset_, "ERROR Reset() called not at end of an epoch");
  need_to_reset_ = false;

  py::gil_scoped_acquire gil_acquire;
  if (!Py_IsInitialized()) {
    return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
  }
  try {
    py::object reset_fn = py_sampler_instance.attr("reset");
    reset_fn();
  } catch (const py::error_already_set &py_err) {
    return Status(StatusCode::kPyFuncException, py_err.what());
  }
  return Status::OK();
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,58 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PYTHON_SAMPLER_H_
|
||||
#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PYTHON_SAMPLER_H_
|
||||
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
|
||||
#include "dataset/engine/datasetops/source/sampler/sampler.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
// Sampler whose sample ids come from a Python sampler object held as a py::object.
class PythonSampler : public Sampler {
 public:
  // Constructor
  // @param py::object py_sampler_instance - the Python sampler object to wrap
  // @param int64_t samples_per_buffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
  explicit PythonSampler(py::object py_sampler_instance,
                         int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());

  // Destructor.
  ~PythonSampler() = default;

  // Initialize the sampler.
  // @return Status
  Status InitSampler() override;

  // for next epoch of sampleIds
  // @return - The error code return
  Status Reset() override;

  // Op calls this to get next Buffer that contains all the sampleIds
  // @param std::unique_ptr<DataBuffer> *out_buffer - Buffer to be returned to StorageOp
  // @return - The error code return
  Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override;

 private:
  bool need_to_reset_;  // Whether Reset() should be called before calling GetNextBuffer()

  py::object py_sampler_instance;  // The handle to the py_sampler python object
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PYTHON_SAMPLER_H_
|
|
@ -48,9 +48,6 @@ Status Sampler::GetAllIdsThenReset(py::array *data) {
|
|||
std::unique_ptr<DataBuffer> db;
|
||||
std::shared_ptr<Tensor> sample_ids;
|
||||
|
||||
// check samples_per_buffer is properly set and doesn't overflow
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(samples_per_buffer_ + 1 > 1, "samples_per_buffer invalid");
|
||||
|
||||
// A call to derived class to get sample ids wrapped inside a buffer
|
||||
RETURN_IF_NOT_OK(GetNextBuffer(&db));
|
||||
// Get the only tensor inside the buffer that contains the actual SampleIds for the entire epoch
|
||||
|
|
|
@ -42,6 +42,7 @@ Status SequentialSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer)
|
|||
}
|
||||
|
||||
Status SequentialSampler::InitSampler() {
|
||||
num_samples_ = (num_samples_ <= 0) ? num_rows_ : num_samples_; // if num_samples < 0, try if num_rows is set
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && samples_per_buffer_ > 0, "Fail to init Sequential Sampler");
|
||||
samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
|
||||
return Status::OK();
|
||||
|
|
|
@ -23,7 +23,7 @@ from .engine.datasets import StorageDataset, TFRecordDataset, ImageFolderDataset
|
|||
GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CelebADataset, Schema, \
|
||||
Shuffle, zip
|
||||
from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \
|
||||
WeightedRandomSampler
|
||||
WeightedRandomSampler, Sampler
|
||||
from .engine.serializer_deserializer import serialize, deserialize, show
|
||||
|
||||
__all__ = ["config", "ImageFolderDatasetV2", "MnistDataset", "StorageDataset",
|
||||
|
|
|
@ -2032,7 +2032,7 @@ class GeneratorDataset(SourceDataset):
|
|||
if self.sampler is not None and hasattr(source, "__getitem__"):
|
||||
if isinstance(self.sampler, (samplers.SequentialSampler, samplers.DistributedSampler,
|
||||
samplers.RandomSampler, samplers.SubsetRandomSampler,
|
||||
samplers.WeightedRandomSampler)):
|
||||
samplers.WeightedRandomSampler, samplers.Sampler)):
|
||||
if num_samples is None:
|
||||
num_samples = len(source)
|
||||
sampler_instance = self.sampler.create()
|
||||
|
|
|
@ -16,11 +16,90 @@
|
|||
Sampler module provides several samplers to generate sampling data from dataset.
|
||||
The following samplers are available: DistributedSampler, PKSampler, RandomSampler,
|
||||
SequentialSampler, SubsetRandomSampler, WeightedRandomSampler.
|
||||
Users can also define a custom sampler by extending the Sampler class.
|
||||
"""
|
||||
|
||||
import mindspore._c_dataengine as cde
|
||||
import numpy as np
|
||||
|
||||
class DistributedSampler():
|
||||
|
||||
class Sampler:
    """
    Base class for user defined samplers.

    A user defined sampler can be used with any existing dataset with sampler support.

    The required __iter__() method must be overridden by the user to generate sample indices.
    The optional reset() method can be overridden to run custom logic on every repeat.

    dataset_size and num_samples will be set by the dataset once a dataset iterator is created.

    Examples:
        >>> import mindspore.dataset as ds
        >>>
        >>> class ReverseSampler(ds.Sampler):
        >>>     def __iter__(self):
        >>>         for i in range(self.dataset_size - 1, -1, -1):
        >>>             yield i
        >>>
        >>> data = ds.ImageFolderDatasetV2(path, sampler=ReverseSampler())
    """

    def __init__(self):
        # Both values are placeholders until _handshake() supplies the real ones.
        self.dataset_size = 0
        self.num_samples = 0

    def __iter__(self):
        """
        User defined iterator over sample indices, must be overridden.

        _handshake is guaranteed to be called prior to iterator construction.
        """
        raise NotImplementedError

    def reset(self):
        """
        Per repeat reset callback, override this method if necessary.
        """

    # Initialization handshake callback: records the dataset size and sample count.
    # Do not override this method!
    def _handshake(self, ds_size, num_samples):
        self.dataset_size = ds_size
        self.num_samples = num_samples

    # Indices fetcher: collects at most num_samples indices from a fresh iterator over self.
    # Do not override this method!
    def _get_indices(self):
        # range() comes first in zip() so that the sampler iterator is never advanced
        # past the num_samples-th element.
        picked = [idx for _, idx in zip(range(self.num_samples), self)]
        return np.array(picked)

    # Instance fetcher: wraps this object in the C++ PythonSampler.
    # Do not override this method!
    def create(self):
        return cde.PythonSampler(self)
|
||||
|
||||
|
||||
class BuiltinSampler:
    """
    Common base class of the samplers that ship with this module.

    Users should not extend this class.
    """

    def __init__(self):
        """Nothing to set up at this level."""

    def create(self):
        """Placeholder; this base implementation returns None."""
        return None
|
||||
|
||||
|
||||
class DistributedSampler(BuiltinSampler):
|
||||
"""
|
||||
Sampler that accesses a shard of the dataset.
|
||||
|
||||
|
@ -65,7 +144,7 @@ class DistributedSampler():
|
|||
return cde.DistributedSampler(self.num_shards, self.shard_id, self.shuffle, self.seed)
|
||||
|
||||
|
||||
class PKSampler():
|
||||
class PKSampler(BuiltinSampler):
|
||||
"""
|
||||
Samples K elements for each P class in the dataset.
|
||||
|
||||
|
@ -106,7 +185,7 @@ class PKSampler():
|
|||
return cde.PKSampler(self.num_val, self.shuffle)
|
||||
|
||||
|
||||
class RandomSampler():
|
||||
class RandomSampler(BuiltinSampler):
|
||||
"""
|
||||
Samples the elements randomly.
|
||||
|
||||
|
@ -147,7 +226,7 @@ class RandomSampler():
|
|||
return cde.RandomSampler(self.replacement, self.num_samples)
|
||||
|
||||
|
||||
class SequentialSampler():
|
||||
class SequentialSampler(BuiltinSampler):
|
||||
"""
|
||||
Samples the dataset elements sequentially, same as not having a sampler.
|
||||
|
||||
|
@ -165,7 +244,7 @@ class SequentialSampler():
|
|||
return cde.SequentialSampler()
|
||||
|
||||
|
||||
class SubsetRandomSampler():
|
||||
class SubsetRandomSampler(BuiltinSampler):
|
||||
"""
|
||||
Samples the elements randomly from a sequence of indices.
|
||||
|
||||
|
@ -196,7 +275,8 @@ class SubsetRandomSampler():
|
|||
def _create_for_minddataset(self):
|
||||
return cde.MindrecordSubsetRandomSampler(self.indices)
|
||||
|
||||
class WeightedRandomSampler():
|
||||
|
||||
class WeightedRandomSampler(BuiltinSampler):
|
||||
"""
|
||||
Samples the elements from [0, len(weights) - 1] randomly with the given weights (probabilities).
|
||||
|
||||
|
|
|
@ -297,9 +297,7 @@ def check_sampler_shuffle_shard_options(param_dict):
|
|||
shuffle, sampler = param_dict.get('shuffle'), param_dict.get('sampler')
|
||||
num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id')
|
||||
|
||||
if sampler is not None and not isinstance(sampler, (
|
||||
samplers.DistributedSampler, samplers.PKSampler, samplers.RandomSampler, samplers.SequentialSampler,
|
||||
samplers.SubsetRandomSampler, samplers.WeightedRandomSampler)):
|
||||
if sampler is not None and not isinstance(sampler, (samplers.BuiltinSampler, samplers.Sampler)):
|
||||
raise ValueError("sampler is not a valid Sampler type.")
|
||||
|
||||
if sampler is not None:
|
||||
|
@ -579,11 +577,11 @@ def check_generatordataset(method):
|
|||
raise ValueError("PKSampler is not supported by GeneratorDataset")
|
||||
if not isinstance(sampler, (samplers.SequentialSampler, samplers.DistributedSampler,
|
||||
samplers.RandomSampler, samplers.SubsetRandomSampler,
|
||||
samplers.WeightedRandomSampler)):
|
||||
samplers.WeightedRandomSampler, samplers.Sampler)):
|
||||
try:
|
||||
iter(sampler)
|
||||
except TypeError:
|
||||
raise TypeError("sampler should be either iterable or from dataset.samplers.py")
|
||||
raise TypeError("sampler should be either iterable or from mindspore.dataset.samplers")
|
||||
|
||||
return method(*args, **kwargs)
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
# ==============================================================================
|
||||
import mindspore.dataset as ds
|
||||
from mindspore import log as logger
|
||||
import numpy as np
|
||||
|
||||
|
||||
# test5trainimgs.json contains 5 images whose un-decoded shape is [83554, 54214, 65512, 54214, 64631]
|
||||
|
@ -107,8 +108,64 @@ def test_sampler_py_api():
|
|||
sampler.get_indices()
|
||||
|
||||
|
||||
def test_python_sampler():
    """
    Exercise user defined Python samplers (ds.Sampler subclasses) with ManifestDataset,
    with GeneratorDataset, and directly through the sampler's Python API.
    """
    manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
    # Maps (item["image"].shape[0], label) of a returned row to the sample index it is expected
    # to correspond to. Named index_of rather than map so the builtin is not shadowed.
    index_of = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}

    class Sp1(ds.Sampler):
        def __iter__(self):
            return iter(range(self.dataset_size))

    class Sp2(ds.Sampler):
        def __init__(self):
            super(Sp2, self).__init__()
            # at this stage, self.dataset_size and self.num_samples are not yet known
            self.cnt = 0

        def __iter__(self):  # first epoch, all 0, second epoch all 1, third all 2 etc.. ...
            return iter([self.cnt] * self.num_samples)

        def reset(self):
            self.cnt = (self.cnt + 1) % self.dataset_size

    def run_config(num_samples, num_repeats, sampler):
        """Iterate the manifest dataset with the given sampler and return the indices seen."""
        data1 = ds.ManifestDataset(manifest_file, num_samples=num_samples, sampler=sampler)
        if num_repeats is not None:
            data1 = data1.repeat(num_repeats)
        res = []
        for item in data1.create_dict_iterator():
            logger.info("item[image].shape[0]: {}, item[label].item(): {}"
                        .format(item["image"].shape[0], item["label"].item()))
            res.append(index_of[(item["image"].shape[0], item["label"].item())])
        return res

    def check_generator():
        """A reversing sampler over a 100-element source must yield 99, 98, ..., 0."""
        class MySampler(ds.Sampler):
            def __iter__(self):
                for i in range(99, -1, -1):
                    yield i

        data1 = ds.GeneratorDataset([(np.array(i),) for i in range(100)], ["data"], sampler=MySampler())
        expected = 99
        for data in data1:
            assert data[0] == (np.array(expected),)
            expected -= 1
        # Without this check the loop above would pass vacuously if no rows were produced.
        assert expected == -1

    assert run_config(5, 2, Sp1()) == [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
    assert run_config(2, 6, Sp2()) == [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 0, 0]
    check_generator()

    # Drive the C++ sampler wrapper directly through its Python API.
    sp1 = Sp1().create()
    sp1.set_num_rows(5)
    sp1.set_num_samples(5)
    sp1.initialize()
    assert list(sp1.get_indices()) == [0, 1, 2, 3, 4]
|
||||
|
||||
|
||||
# Allow running this test module directly as a script, without a test runner.
if __name__ == '__main__':
    test_sequential_sampler(True)
    test_random_sampler(True)
    test_random_sampler_multi_iter(True)
    test_sampler_py_api()
    test_python_sampler()
|
Loading…
Reference in New Issue