Add python sampler support for CPP dataset

This commit is contained in:
Junhan Hu 2020-04-16 13:42:16 -04:00
parent 3ad73b7d71
commit 43a2e99833
11 changed files with 296 additions and 16 deletions

View File

@ -53,6 +53,7 @@
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
#include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
#include "dataset/engine/datasetops/source/sampler/python_sampler.h"
#include "dataset/engine/datasetops/source/tf_reader_op.h"
#include "dataset/engine/jagged_connector.h"
#include "dataset/kernels/data/to_float16_op.h"
@ -415,6 +416,7 @@ void bindSamplerOps(py::module *m) {
(void)py::class_<SequentialSampler, Sampler, std::shared_ptr<SequentialSampler>>(*m, "SequentialSampler")
.def(py::init<>());
(void)py::class_<SubsetRandomSampler, Sampler, std::shared_ptr<SubsetRandomSampler>>(*m, "SubsetRandomSampler")
.def(py::init<std::vector<int64_t>>(), py::arg("indices"));
@ -425,6 +427,9 @@ void bindSamplerOps(py::module *m) {
(void)py::class_<WeightedRandomSampler, Sampler, std::shared_ptr<WeightedRandomSampler>>(*m, "WeightedRandomSampler")
.def(py::init<std::vector<double>, int64_t, bool>(), py::arg("weights"), py::arg("numSamples"),
py::arg("replacement"));
(void)py::class_<PythonSampler, Sampler, std::shared_ptr<PythonSampler>>(*m, "PythonSampler")
.def(py::init<py::object>(), py::arg("pySampler"));
}
void bindInfoObjects(py::module *m) {

View File

@ -1,6 +1,7 @@
add_library(engine-datasetops-source-sampler OBJECT
distributed_sampler.cc
pk_sampler.cc
python_sampler.cc
random_sampler.cc
sampler.cc
sequential_sampler.cc

View File

@ -0,0 +1,83 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/engine/datasetops/source/sampler/python_sampler.h"
#include <memory>
namespace mindspore {
namespace dataset {
// Constructor.
// @param py::object py_sampler_instance - handle to the Python sampler object; its _handshake, _get_indices and
//     reset attributes are called by InitSampler(), GetNextBuffer() and Reset() respectively
// @param int64_t samples_per_buffer - forwarded unchanged to the Sampler base class
// Note: the initializers are written in member declaration order (need_to_reset_ is declared before
// py_sampler_instance in the header), which is the order the members are actually constructed in.
PythonSampler::PythonSampler(py::object py_sampler_instance, int64_t samples_per_buffer)
    : Sampler(samples_per_buffer), need_to_reset_(false), py_sampler_instance(py_sampler_instance) {}
// Produce the next buffer of sample ids.
// When the epoch's ids have already been handed out (need_to_reset_ is set) an EOE buffer is returned.
// Otherwise the Python sampler's _get_indices() is called once, its result is copied into a tensor, and that
// tensor is returned as the single row of the buffer; need_to_reset_ is then set.
// @param std::unique_ptr<DataBuffer> *out_buffer - receives the produced buffer
// @return Status - kPythonInterpreterFailure if the interpreter is finalized, kPyFuncException if the Python
//     call raises or its result cannot be converted to a numpy array, the error reported by
//     Tensor::CreateTensor if the copy fails, OK otherwise
Status PythonSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
  if (need_to_reset_) {
    (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
    return Status::OK();
  }

  // Allocating the buffer involves no Python objects, so it is done before the GIL is taken.
  (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagNone);

  // Check the interpreter state before trying to take the GIL: acquiring the GIL of a finalized interpreter
  // is not safe, so the check must come first to be of any use.
  if (Py_IsInitialized() == 0) {
    return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
  }

  std::shared_ptr<Tensor> sample_ids;
  {
    py::gil_scoped_acquire gil_acquire;
    try {
      py::object py_ret = py_sampler_instance.attr("_get_indices")();
      py::array np_sample_ids = py_ret.cast<py::array>();
      // copy numpy to tensor; a failed copy must not be silently ignored
      RETURN_IF_NOT_OK(Tensor::CreateTensor(&sample_ids, np_sample_ids));
    } catch (const py::error_already_set &e) {
      return Status(StatusCode::kPyFuncException, e.what());
    } catch (const py::cast_error &e) {
      // cast<py::array>() reports a failed conversion with cast_error, which is not an error_already_set
      return Status(StatusCode::kPyFuncException, e.what());
    }
  }

  TensorRow row(1, sample_ids);
  (*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row));
  need_to_reset_ = true;
  return Status::OK();
}
// Initialize the sampler by passing num_rows_ and num_samples_ to the Python sampler's _handshake().
// @return Status - an error if num_rows_ is not positive, kPythonInterpreterFailure if the interpreter is
//     finalized, kPyFuncException if the Python call raises, OK otherwise
Status PythonSampler::InitSampler() {
  CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "ERROR num_rows_ should be greater than 0");
  // The interpreter check has to precede the GIL acquisition, otherwise it cannot protect it.
  if (Py_IsInitialized() == 0) {
    return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
  }
  {
    py::gil_scoped_acquire gil_acquire;
    try {
      py_sampler_instance.attr("_handshake")(num_rows_, num_samples_);
    } catch (const py::error_already_set &e) {
      return Status(StatusCode::kPyFuncException, e.what());
    }
  }
  return Status::OK();
}
// Prepare for the next epoch: clears need_to_reset_ and calls the Python sampler's reset().
// @return Status - an error if called while need_to_reset_ is not set, kPythonInterpreterFailure if the
//     interpreter is finalized, kPyFuncException if the Python call raises, OK otherwise
Status PythonSampler::Reset() {
  CHECK_FAIL_RETURN_UNEXPECTED(need_to_reset_, "ERROR Reset() called not at end of an epoch");
  need_to_reset_ = false;
  // The interpreter check has to precede the GIL acquisition, otherwise it cannot protect it.
  if (Py_IsInitialized() == 0) {
    return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
  }
  py::gil_scoped_acquire gil_acquire;
  try {
    py_sampler_instance.attr("reset")();
  } catch (const py::error_already_set &e) {
    return Status(StatusCode::kPyFuncException, e.what());
  }
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,58 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PYTHON_SAMPLER_H_
#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PYTHON_SAMPLER_H_
#include <limits>
#include <memory>
#include "dataset/engine/datasetops/source/sampler/sampler.h"
namespace mindspore {
namespace dataset {
// Sampler whose sample ids are produced by a Python sampler object held through a py::object handle.
class PythonSampler : public Sampler {
public:
// Constructor
// @param py::object py_sampler_instance - handle to the Python sampler object
// @param int64_t samples_per_buffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
//     (defaults to the largest int64_t value)
explicit PythonSampler(py::object py_sampler_instance,
int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
// Destructor.
~PythonSampler() = default;
// Initialize the sampler.
// @return Status
Status InitSampler() override;
// Reset the sampler for the next epoch of sampleIds
// @return - The error code return
Status Reset() override;
// Op calls this to get next Buffer that contains all the sampleIds
// @param std::unique_ptr<DataBuffer> *out_buffer - Buffer to be returned to the calling op
// @return - The error code return
Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override;
private:
// Members are constructed in the order declared here, regardless of the order used in a constructor's
// initializer list.
bool need_to_reset_; // Whether Reset() should be called before calling GetNextBuffer()
py::object py_sampler_instance; // The handle to the py_sampler python object
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PYTHON_SAMPLER_H_

View File

@ -48,9 +48,6 @@ Status Sampler::GetAllIdsThenReset(py::array *data) {
std::unique_ptr<DataBuffer> db;
std::shared_ptr<Tensor> sample_ids;
// check samples_per_buffer is properly set and doesn't overflow
CHECK_FAIL_RETURN_UNEXPECTED(samples_per_buffer_ + 1 > 1, "samples_per_buffer invalid");
// A call to derived class to get sample ids wrapped inside a buffer
RETURN_IF_NOT_OK(GetNextBuffer(&db));
// Get the only tensor inside the buffer that contains the actual SampleIds for the entire epoch

View File

@ -42,6 +42,7 @@ Status SequentialSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer)
}
Status SequentialSampler::InitSampler() {
num_samples_ = (num_samples_ <= 0) ? num_rows_ : num_samples_; // if num_samples < 0, try if num_rows is set
CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && samples_per_buffer_ > 0, "Fail to init Sequential Sampler");
samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
return Status::OK();

View File

@ -23,7 +23,7 @@ from .engine.datasets import StorageDataset, TFRecordDataset, ImageFolderDataset
GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CelebADataset, Schema, \
Shuffle, zip
from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \
WeightedRandomSampler
WeightedRandomSampler, Sampler
from .engine.serializer_deserializer import serialize, deserialize, show
__all__ = ["config", "ImageFolderDatasetV2", "MnistDataset", "StorageDataset",

View File

@ -2032,7 +2032,7 @@ class GeneratorDataset(SourceDataset):
if self.sampler is not None and hasattr(source, "__getitem__"):
if isinstance(self.sampler, (samplers.SequentialSampler, samplers.DistributedSampler,
samplers.RandomSampler, samplers.SubsetRandomSampler,
samplers.WeightedRandomSampler)):
samplers.WeightedRandomSampler, samplers.Sampler)):
if num_samples is None:
num_samples = len(source)
sampler_instance = self.sampler.create()

View File

@ -16,11 +16,90 @@
Sampler module provides several samplers to generate sampling data from dataset.
There are the following samplers: DistributedSampler, PKSampler, RandomSampler,
SequentialSampler, SubsetRandomSampler, WeightedRandomSampler.
Users can also define a custom sampler by extending the Sampler class.
"""
import mindspore._c_dataengine as cde
import numpy as np
class DistributedSampler():
class Sampler:
    """
    Base class for user defined sampler.

    User defined sampler can be used with any existing dataset with sampler support.

    A required __iter__() method should be overridden by user for sample index generation.
    An optional reset() method can be overridden for per repeat reset,

    dataset_size and num_samples will be set by dataset once a dataset iterator is created.

    Examples:
        >>> import mindspore.dataset as ds
        >>>
        >>> class ReverseSampler(ds.Sampler):
        >>>     def __iter__(self):
        >>>         for i in range(self.dataset_size - 1, -1, -1):
        >>>             yield i
        >>>
        >>> ds = ds.ImageFolderDatasetV2(path, sampler=ReverseSampler())
    """

    def __init__(self):
        # Both values are placeholders until _handshake() supplies the real ones.
        self.dataset_size = 0
        self.num_samples = 0

    def __iter__(self):
        """
        User defined iterator, must be overridden.
        _handshake is guaranteed to be called prior to iterator construction
        """
        raise NotImplementedError

    def reset(self):
        """
        Per repeat reset callback, override this method if necessary
        """

    # Initialization handshake callback
    # Do not override this method!
    def _handshake(self, ds_size, num_samples):
        """
        Record the dataset size and the number of samples to draw.

        Args:
            ds_size (int): value stored in self.dataset_size.
            num_samples (int): value stored in self.num_samples.
        """
        self.dataset_size = ds_size
        self.num_samples = num_samples

    # Indices fetcher
    # Do not override this method!
    def _get_indices(self):
        """
        Draw at most self.num_samples indices from a fresh iterator over this sampler.

        Returns:
            numpy.ndarray, the collected indices as an int64 array; shorter than
            self.num_samples if the iterator is exhausted early.
        """
        sampler_iter = iter(self)
        indices = []
        for _ in range(self.num_samples):
            try:
                indices.append(next(sampler_iter))
            except StopIteration:
                break
        # Pin the dtype: without it an empty result would be float64 and a non-empty one would
        # use the platform's default integer width.
        # NOTE(review): the C++ consumer presumably expects int64 sample ids - confirm.
        return np.array(indices, dtype=np.int64)

    # Instance fetcher
    # Do not override this method!
    def create(self):
        """
        Wrap this sampler in the C++ PythonSampler.

        Returns:
            cde.PythonSampler, constructed with this object as its argument.
        """
        return cde.PythonSampler(self)
class BuiltinSampler:
    """
    Common parent of the samplers shipped with this module.

    Not meant to be subclassed by users; custom samplers derive from Sampler instead.
    """

    def __init__(self):
        """The base class keeps no state."""

    def create(self):
        """Placeholder in the base class; yields None."""
        return None
class DistributedSampler(BuiltinSampler):
"""
Sampler that access a shard of the dataset.
@ -65,7 +144,7 @@ class DistributedSampler():
return cde.DistributedSampler(self.num_shards, self.shard_id, self.shuffle, self.seed)
class PKSampler():
class PKSampler(BuiltinSampler):
"""
Samples K elements for each P class in the dataset.
@ -106,7 +185,7 @@ class PKSampler():
return cde.PKSampler(self.num_val, self.shuffle)
class RandomSampler():
class RandomSampler(BuiltinSampler):
"""
Samples the elements randomly.
@ -147,7 +226,7 @@ class RandomSampler():
return cde.RandomSampler(self.replacement, self.num_samples)
class SequentialSampler():
class SequentialSampler(BuiltinSampler):
"""
Samples the dataset elements sequentially, same as not having a sampler.
@ -165,7 +244,7 @@ class SequentialSampler():
return cde.SequentialSampler()
class SubsetRandomSampler():
class SubsetRandomSampler(BuiltinSampler):
"""
Samples the elements randomly from a sequence of indices.
@ -196,7 +275,8 @@ class SubsetRandomSampler():
def _create_for_minddataset(self):
return cde.MindrecordSubsetRandomSampler(self.indices)
class WeightedRandomSampler():
class WeightedRandomSampler(BuiltinSampler):
"""
Samples the elements from [0, len(weights) - 1] randomly with the given weights (probabilities).

View File

@ -297,9 +297,7 @@ def check_sampler_shuffle_shard_options(param_dict):
shuffle, sampler = param_dict.get('shuffle'), param_dict.get('sampler')
num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id')
if sampler is not None and not isinstance(sampler, (
samplers.DistributedSampler, samplers.PKSampler, samplers.RandomSampler, samplers.SequentialSampler,
samplers.SubsetRandomSampler, samplers.WeightedRandomSampler)):
if sampler is not None and not isinstance(sampler, (samplers.BuiltinSampler, samplers.Sampler)):
raise ValueError("sampler is not a valid Sampler type.")
if sampler is not None:
@ -579,11 +577,11 @@ def check_generatordataset(method):
raise ValueError("PKSampler is not supported by GeneratorDataset")
if not isinstance(sampler, (samplers.SequentialSampler, samplers.DistributedSampler,
samplers.RandomSampler, samplers.SubsetRandomSampler,
samplers.WeightedRandomSampler)):
samplers.WeightedRandomSampler, samplers.Sampler)):
try:
iter(sampler)
except TypeError:
raise TypeError("sampler should be either iterable or from dataset.samplers.py")
raise TypeError("sampler should be either iterable or from mindspore.dataset.samplers")
return method(*args, **kwargs)

View File

@ -14,6 +14,7 @@
# ==============================================================================
import mindspore.dataset as ds
from mindspore import log as logger
import numpy as np
# test5trainimgs.json contains 5 images whose un-decoded shape is [83554, 54214, 65512, 54214, 64631]
@ -107,8 +108,64 @@ def test_sampler_py_api():
sampler.get_indices()
def test_python_sampler():
    """
    Check user-defined samplers (subclasses of ds.Sampler) against ManifestDataset,
    GeneratorDataset and the object returned by create().
    """
    manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
    # (item["image"].shape[0], label) -> index that the sample is expected to correspond to
    sample_index = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}

    class Sp1(ds.Sampler):
        """Yields 0 .. dataset_size - 1 in order."""

        def __iter__(self):
            return iter(list(range(self.dataset_size)))

    class Sp2(ds.Sampler):
        """Repeats one index for a whole epoch; reset() advances that index."""

        def __init__(self):
            super().__init__()
            # dataset_size and num_samples are still unknown when the sampler is constructed
            self.cnt = 0

        def __iter__(self):
            # first epoch all 0, second epoch all 1, third epoch all 2, ...
            return iter([self.cnt] * self.num_samples)

        def reset(self):
            self.cnt = (self.cnt + 1) % self.dataset_size

    def test_config(num_samples, num_repeats, sampler):
        dataset = ds.ManifestDataset(manifest_file, num_samples=num_samples, sampler=sampler)
        if num_repeats is not None:
            dataset = dataset.repeat(num_repeats)
        visited = []
        for item in dataset.create_dict_iterator():
            image_size = item["image"].shape[0]
            label = item["label"].item()
            logger.info("item[image].shape[0]: {}, item[label].item(): {}"
                        .format(image_size, label))
            visited.append(sample_index[(image_size, label)])
        return visited

    def test_generator():
        class MySampler(ds.Sampler):
            def __iter__(self):
                yield from range(99, -1, -1)

        source = [(np.array(value),) for value in range(100)]
        dataset = ds.GeneratorDataset(source, ["data"], sampler=MySampler())
        for offset, data in enumerate(dataset):
            assert data[0] == (np.array(99 - offset),)

    assert test_config(5, 2, Sp1()) == [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
    assert test_config(2, 6, Sp2()) == [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 0, 0]
    test_generator()

    # Drive the created sampler object directly, without a dataset around it
    created = Sp1().create()
    created.set_num_rows(5)
    created.set_num_samples(5)
    created.initialize()
    assert list(created.get_indices()) == [0, 1, 2, 3, 4]
# Script entry point: runs this file's sampler tests one after another when executed directly.
if __name__ == '__main__':
test_sequential_sampler(True)
test_random_sampler(True)
test_random_sampler_multi_iter(True)
test_sampler_py_api()
test_python_sampler()