forked from mindspore-Ecosystem/mindspore
!183 Mindspore.dataset CPP sampler for GeneratorDataset
Merge pull request !183 from JunhanHu/cpp_sampler
This commit is contained in:
commit
cf026096a6
|
@ -517,7 +517,7 @@ Status DEPipeline::ParseGeneratorOp(const py::dict &args, std::shared_ptr<Datase
|
||||||
std::string key = py::str(arg.first);
|
std::string key = py::str(arg.first);
|
||||||
py::handle value = arg.second;
|
py::handle value = arg.second;
|
||||||
if (!value.is_none()) {
|
if (!value.is_none()) {
|
||||||
if (key == "generator_function") {
|
if (key == "source") {
|
||||||
py::object obj = py::cast(&value);
|
py::object obj = py::cast(&value);
|
||||||
if (!py::isinstance<py::function>(obj)) {
|
if (!py::isinstance<py::function>(obj)) {
|
||||||
std::string err_msg = "Error: generator is invalid or not set.";
|
std::string err_msg = "Error: generator is invalid or not set.";
|
||||||
|
|
|
@ -388,7 +388,16 @@ void bindTensorOps4(py::module *m) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void bindSamplerOps(py::module *m) {
|
void bindSamplerOps(py::module *m) {
|
||||||
(void)py::class_<Sampler, std::shared_ptr<Sampler>>(*m, "Sampler");
|
(void)py::class_<Sampler, std::shared_ptr<Sampler>>(*m, "Sampler")
|
||||||
|
.def("set_num_rows", [](Sampler &self, int64_t rows) { THROW_IF_ERROR(self.SetNumRowsInDataset(rows)); })
|
||||||
|
.def("set_num_samples", [](Sampler &self, int64_t samples) { THROW_IF_ERROR(self.SetNumSamples(samples)); })
|
||||||
|
.def("initialize", [](Sampler &self) { THROW_IF_ERROR(self.InitSampler()); })
|
||||||
|
.def("get_indices", [](Sampler &self) {
|
||||||
|
py::array ret;
|
||||||
|
THROW_IF_ERROR(self.GetAllIdsThenReset(&ret));
|
||||||
|
return ret;
|
||||||
|
});
|
||||||
|
|
||||||
(void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(*m, "ShardOperator");
|
(void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(*m, "ShardOperator");
|
||||||
|
|
||||||
(void)py::class_<DistributedSampler, Sampler, std::shared_ptr<DistributedSampler>>(*m, "DistributedSampler")
|
(void)py::class_<DistributedSampler, Sampler, std::shared_ptr<DistributedSampler>>(*m, "DistributedSampler")
|
||||||
|
|
|
@ -491,6 +491,8 @@ Status Tensor::GetItemAt(T *o, const std::vector<dsize_t> &index) const {
|
||||||
|
|
||||||
// return data as numpy, should return status
|
// return data as numpy, should return status
|
||||||
Status Tensor::GetDataAsNumpy(py::array *data) {
|
Status Tensor::GetDataAsNumpy(py::array *data) {
|
||||||
|
RETURN_UNEXPECTED_IF_NULL(data_);
|
||||||
|
RETURN_UNEXPECTED_IF_NULL(data);
|
||||||
if (type_ == DataType::DE_BOOL) {
|
if (type_ == DataType::DE_BOOL) {
|
||||||
*data = py::array_t<bool>(shape_.AsVector(), reinterpret_cast<bool *>(data_));
|
*data = py::array_t<bool>(shape_.AsVector(), reinterpret_cast<bool *>(data_));
|
||||||
} else if (type_ == DataType::DE_INT8) {
|
} else if (type_ == DataType::DE_INT8) {
|
||||||
|
|
|
@ -100,7 +100,7 @@ Status CelebAOp::LaunchThreadsAndInitOp() {
|
||||||
RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CelebAOp::WorkerEntry, this, std::placeholders::_1)));
|
RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CelebAOp::WorkerEntry, this, std::placeholders::_1)));
|
||||||
TaskManager::FindMe()->Post();
|
TaskManager::FindMe()->Post();
|
||||||
RETURN_IF_NOT_OK(ParseImageAttrInfo());
|
RETURN_IF_NOT_OK(ParseImageAttrInfo());
|
||||||
RETURN_IF_NOT_OK(sampler_->Init(this));
|
RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this));
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
@ -240,7 +240,7 @@ Status CifarOp::Reset() {
|
||||||
|
|
||||||
// hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows
|
// hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows
|
||||||
Status CifarOp::InitSampler() {
|
Status CifarOp::InitSampler() {
|
||||||
RETURN_IF_NOT_OK(sampler_->Init(this));
|
RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -258,7 +258,7 @@ Status ImageFolderOp::Reset() {
|
||||||
|
|
||||||
// hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows
|
// hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows
|
||||||
Status ImageFolderOp::InitSampler() {
|
Status ImageFolderOp::InitSampler() {
|
||||||
RETURN_IF_NOT_OK(sampler_->Init(this));
|
RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -254,7 +254,7 @@ Status ManifestOp::Reset() {
|
||||||
|
|
||||||
// hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows
|
// hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows
|
||||||
Status ManifestOp::InitSampler() {
|
Status ManifestOp::InitSampler() {
|
||||||
RETURN_IF_NOT_OK(sampler_->Init(this));
|
RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -205,7 +205,7 @@ Status MnistOp::Reset() {
|
||||||
|
|
||||||
// hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows
|
// hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows
|
||||||
Status MnistOp::InitSampler() {
|
Status MnistOp::InitSampler() {
|
||||||
RETURN_IF_NOT_OK(sampler_->Init(this));
|
RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -31,8 +31,9 @@ DistributedSampler::DistributedSampler(int64_t num_dev, int64_t dev_id, bool shu
|
||||||
num_devices_(num_dev),
|
num_devices_(num_dev),
|
||||||
shuffle_(shuffle) {}
|
shuffle_(shuffle) {}
|
||||||
|
|
||||||
Status DistributedSampler::Init(const RandomAccessOp *op) {
|
Status DistributedSampler::InitSampler() {
|
||||||
RETURN_IF_NOT_OK(Sampler::Init(op));
|
CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0, "num_samples <= 0\n");
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows <= 0\n");
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(device_id_ < num_devices_ && device_id_ >= 0 && num_rows_ > 0 && num_samples_ > 0,
|
CHECK_FAIL_RETURN_UNEXPECTED(device_id_ < num_devices_ && device_id_ >= 0 && num_rows_ > 0 && num_samples_ > 0,
|
||||||
"fail to init DistributedSampler");
|
"fail to init DistributedSampler");
|
||||||
rnd_.seed(seed_++);
|
rnd_.seed(seed_++);
|
||||||
|
|
|
@ -41,10 +41,8 @@ class DistributedSampler : public Sampler {
|
||||||
// @return - The error code return
|
// @return - The error code return
|
||||||
Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override;
|
Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override;
|
||||||
|
|
||||||
// first handshake between StorageOp and Sampler
|
// Init sampler, called by base class or python
|
||||||
// @param op - StorageOp pointer, pass in so Sampler can call GetNumSamples() and get ClassIds()
|
Status InitSampler() override;
|
||||||
// @return
|
|
||||||
Status Init(const RandomAccessOp *) override;
|
|
||||||
|
|
||||||
// for next epoch of sampleIds
|
// for next epoch of sampleIds
|
||||||
// @return - The error code return
|
// @return - The error code return
|
||||||
|
|
|
@ -28,9 +28,7 @@ PKSampler::PKSampler(int64_t val, bool shuffle, int64_t samples_per_buffer)
|
||||||
num_pk_samples_(0),
|
num_pk_samples_(0),
|
||||||
samples_per_class_(val) {}
|
samples_per_class_(val) {}
|
||||||
|
|
||||||
Status PKSampler::Init(const RandomAccessOp *op) {
|
Status PKSampler::InitSampler() {
|
||||||
RETURN_UNEXPECTED_IF_NULL(op);
|
|
||||||
RETURN_IF_NOT_OK(op->GetClassIds(&label_to_ids_));
|
|
||||||
labels_.reserve(label_to_ids_.size());
|
labels_.reserve(label_to_ids_.size());
|
||||||
for (const auto &pair : label_to_ids_) {
|
for (const auto &pair : label_to_ids_) {
|
||||||
if (pair.second.empty() == false) {
|
if (pair.second.empty() == false) {
|
||||||
|
@ -79,5 +77,13 @@ Status PKSampler::Reset() {
|
||||||
rnd_.seed(seed_++);
|
rnd_.seed(seed_++);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status PKSampler::HandshakeRandomAccessOp(const RandomAccessOp *op) {
|
||||||
|
RETURN_UNEXPECTED_IF_NULL(op);
|
||||||
|
RETURN_IF_NOT_OK(op->GetClassIds(&label_to_ids_));
|
||||||
|
RETURN_IF_NOT_OK(InitSampler());
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -45,7 +45,10 @@ class PKSampler : public Sampler { // NOT YET FINISHED
|
||||||
// first handshake between StorageOp and Sampler
|
// first handshake between StorageOp and Sampler
|
||||||
// @param op - StorageOp pointer, pass in so Sampler can call GetNumSamples() and get ClassIds()
|
// @param op - StorageOp pointer, pass in so Sampler can call GetNumSamples() and get ClassIds()
|
||||||
// @return
|
// @return
|
||||||
Status Init(const RandomAccessOp *op) override;
|
Status HandshakeRandomAccessOp(const RandomAccessOp *op) override;
|
||||||
|
|
||||||
|
// init sampler, to be called by python or Handshake
|
||||||
|
Status InitSampler() override;
|
||||||
|
|
||||||
// for next epoch of sampleIds
|
// for next epoch of sampleIds
|
||||||
// @return - The error code return
|
// @return - The error code return
|
||||||
|
|
|
@ -49,10 +49,9 @@ Status RandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status RandomSampler::Init(const RandomAccessOp *op) {
|
Status RandomSampler::InitSampler() {
|
||||||
RETURN_IF_NOT_OK(Sampler::Init(op));
|
|
||||||
num_samples_ = (user_num_samples_ < num_samples_) ? user_num_samples_ : num_samples_;
|
num_samples_ = (user_num_samples_ < num_samples_) ? user_num_samples_ : num_samples_;
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && num_rows_ > 0, "Fail to init RandomSampler");
|
CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && num_rows_ > 0, "both num_samples & num_rows need to be positive");
|
||||||
samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
|
samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
|
||||||
if (replacement_ == false) {
|
if (replacement_ == false) {
|
||||||
shuffled_ids_.reserve(num_rows_);
|
shuffled_ids_.reserve(num_rows_);
|
||||||
|
|
|
@ -42,10 +42,8 @@ class RandomSampler : public Sampler {
|
||||||
// @return - The error code return
|
// @return - The error code return
|
||||||
Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override;
|
Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override;
|
||||||
|
|
||||||
// first handshake between StorageOp and Sampler
|
// meant to be called by base class or python
|
||||||
// @param op - StorageOp pointer, pass in so Sampler can call GetNumSamples() and get ClassIds()
|
Status InitSampler() override;
|
||||||
// @return
|
|
||||||
Status Init(const RandomAccessOp *op) override;
|
|
||||||
|
|
||||||
// for next epoch of sampleIds
|
// for next epoch of sampleIds
|
||||||
// @return - The error code return
|
// @return - The error code return
|
||||||
|
|
|
@ -20,12 +20,13 @@ namespace dataset {
|
||||||
Sampler::Sampler(int64_t samples_per_buffer)
|
Sampler::Sampler(int64_t samples_per_buffer)
|
||||||
: DatasetOp(0), num_rows_(0), num_samples_(0), samples_per_buffer_(samples_per_buffer), col_desc_(nullptr) {}
|
: DatasetOp(0), num_rows_(0), num_samples_(0), samples_per_buffer_(samples_per_buffer), col_desc_(nullptr) {}
|
||||||
|
|
||||||
Status Sampler::Init(const RandomAccessOp *op) {
|
Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) {
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr && samples_per_buffer_ > 0, "Fail to init Sampler()\n");
|
CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "RandomAccessOp is nullptr\n");
|
||||||
RETURN_IF_NOT_OK(op->GetNumSamples(&num_samples_));
|
RETURN_IF_NOT_OK(op->GetNumSamples(&num_samples_));
|
||||||
RETURN_IF_NOT_OK(op->GetNumRowsInDataset(&num_rows_));
|
RETURN_IF_NOT_OK(op->GetNumRowsInDataset(&num_rows_));
|
||||||
// It's up to the derived class to check the validity of the two args
|
// It's up to the derived class to check the validity of the two args
|
||||||
// Because some sampler only needs one of the arg (weighted_random_sampler)
|
// Because some sampler only needs one of the arg (weighted_random_sampler)
|
||||||
|
RETURN_IF_NOT_OK(InitSampler()); // init sampler after callback
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -42,5 +43,49 @@ Status Sampler::CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t
|
||||||
(void)(*sample_ids)->StartAddr(); // allocate memory in case user forgets!
|
(void)(*sample_ids)->StartAddr(); // allocate memory in case user forgets!
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status Sampler::GetAllIdsThenReset(py::array *data) {
|
||||||
|
std::unique_ptr<DataBuffer> db;
|
||||||
|
std::shared_ptr<Tensor> sample_ids;
|
||||||
|
|
||||||
|
// check samples_per_buffer is properly set and doesn't overflow
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(samples_per_buffer_ + 1 > 1, "samples_per_buffer invalid");
|
||||||
|
|
||||||
|
// A call to derived class to get sample ids wrapped inside a buffer
|
||||||
|
RETURN_IF_NOT_OK(GetNextBuffer(&db));
|
||||||
|
// Get the only tensor inside the buffer that contains the actual SampleIds for the entire epoch
|
||||||
|
RETURN_IF_NOT_OK(db->GetTensor(&sample_ids, 0, 0));
|
||||||
|
// check this buffer is not a ctrl buffer
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(db->buffer_flags() == DataBuffer::kDeBFlagNone, "ERROR ctrl buffer received");
|
||||||
|
{
|
||||||
|
py::gil_scoped_acquire gil_acquire;
|
||||||
|
if (Py_IsInitialized() == 0) {
|
||||||
|
return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
RETURN_IF_NOT_OK(sample_ids->GetDataAsNumpy(data));
|
||||||
|
} catch (const std::runtime_error &e) {
|
||||||
|
return Status(StatusCode::kPyFuncException, e.what());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Error check: the next buffer is supposed to be EOE, since the previous one already contains all ids for the current epoch
|
||||||
|
RETURN_IF_NOT_OK(GetNextBuffer(&db));
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(db->eoe(), "ERROR Non EOE received");
|
||||||
|
// Reset Sampler since this is the end of the epoch
|
||||||
|
RETURN_IF_NOT_OK(Reset());
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status Sampler::SetNumSamples(int64_t num_samples) {
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(num_samples > 0, "num_samples is negative or 0");
|
||||||
|
num_samples_ = num_samples;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status Sampler::SetNumRowsInDataset(int64_t num_rows) {
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(num_rows > 0, "num_rows is negative or 0");
|
||||||
|
num_rows_ = num_rows;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -78,14 +78,26 @@ class Sampler : public DatasetOp {
|
||||||
// @return - The error code return
|
// @return - The error code return
|
||||||
Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override = 0;
|
Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override = 0;
|
||||||
|
|
||||||
|
// return all ids in one epoch as a numpy array, then call reset
|
||||||
|
Status GetAllIdsThenReset(py::array *data);
|
||||||
|
|
||||||
// for next epoch of sampleIds
|
// for next epoch of sampleIds
|
||||||
// @return - The error code return
|
// @return - The error code return
|
||||||
Status Reset() override = 0;
|
Status Reset() override = 0;
|
||||||
|
|
||||||
// first handshake between StorageOp and Sampler. Base class init will call both GetNumRows and GetNumSamples
|
// setter function for num_rows_
|
||||||
// @param op - StorageOp pointer, pass in so Sampler can call GetNumSamples() and get ClassIds()
|
Status SetNumRowsInDataset(int64_t num_rows);
|
||||||
|
|
||||||
|
// setter function for num_samples_
|
||||||
|
Status SetNumSamples(int64_t num_samples);
|
||||||
|
|
||||||
|
// first handshake between StorageOp and Sampler. This func will call getNumRows and getNumSamples
|
||||||
|
// @param op - StorageOp pointer, pass in so Sampler can call getNumSamples() and get ClassIds()
|
||||||
// @return
|
// @return
|
||||||
virtual Status Init(const RandomAccessOp *op);
|
virtual Status HandshakeRandomAccessOp(const RandomAccessOp *op);
|
||||||
|
|
||||||
|
// initialize sampler and perform checks on certain vars
|
||||||
|
virtual Status InitSampler() { return Status::OK(); }
|
||||||
|
|
||||||
// Not meant to be called
|
// Not meant to be called
|
||||||
// @return
|
// @return
|
||||||
|
|
|
@ -41,9 +41,7 @@ Status SequentialSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer)
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status SequentialSampler::Init(const RandomAccessOp *op) {
|
Status SequentialSampler::InitSampler() {
|
||||||
RETURN_UNEXPECTED_IF_NULL(op);
|
|
||||||
RETURN_IF_NOT_OK(op->GetNumSamples(&num_samples_));
|
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && samples_per_buffer_ > 0, "Fail to init Sequential Sampler");
|
CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && samples_per_buffer_ > 0, "Fail to init Sequential Sampler");
|
||||||
samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
|
samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
|
|
@ -32,10 +32,8 @@ class SequentialSampler : public Sampler {
|
||||||
// Destructor.
|
// Destructor.
|
||||||
~SequentialSampler() = default;
|
~SequentialSampler() = default;
|
||||||
|
|
||||||
// Initialize the sampler.
|
// init sampler, called by python
|
||||||
// @param op
|
Status InitSampler() override;
|
||||||
// @return Status
|
|
||||||
Status Init(const RandomAccessOp *op) override;
|
|
||||||
|
|
||||||
// for next epoch of sampleIds
|
// for next epoch of sampleIds
|
||||||
// @return - The error code return
|
// @return - The error code return
|
||||||
|
|
|
@ -31,9 +31,8 @@ SubsetRandomSampler::SubsetRandomSampler(const std::vector<int64_t> &indices, in
|
||||||
: Sampler(samples_per_buffer), indices_(indices), sample_id_(0), buffer_id_(0) {}
|
: Sampler(samples_per_buffer), indices_(indices), sample_id_(0), buffer_id_(0) {}
|
||||||
|
|
||||||
// Initialize this Sampler.
|
// Initialize this Sampler.
|
||||||
Status SubsetRandomSampler::Init(const RandomAccessOp *op) {
|
Status SubsetRandomSampler::InitSampler() {
|
||||||
// Calling base class init.
|
CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows <= 0\n");
|
||||||
RETURN_IF_NOT_OK(Sampler::Init(op));
|
|
||||||
|
|
||||||
// Initialize random generator with seed from config manager
|
// Initialize random generator with seed from config manager
|
||||||
rand_gen_.seed(GetSeed());
|
rand_gen_.seed(GetSeed());
|
||||||
|
|
|
@ -38,9 +38,8 @@ class SubsetRandomSampler : public Sampler {
|
||||||
~SubsetRandomSampler() = default;
|
~SubsetRandomSampler() = default;
|
||||||
|
|
||||||
// Initialize the sampler.
|
// Initialize the sampler.
|
||||||
// @param op (Not used in this sampler)
|
|
||||||
// @return Status
|
// @return Status
|
||||||
Status Init(const RandomAccessOp *op) override;
|
Status InitSampler() override;
|
||||||
|
|
||||||
// Reset the internal variable to the initial state and reshuffle the indices.
|
// Reset the internal variable to the initial state and reshuffle the indices.
|
||||||
// @return Status
|
// @return Status
|
||||||
|
|
|
@ -29,21 +29,21 @@ namespace dataset {
|
||||||
// Constructor.
|
// Constructor.
|
||||||
WeightedRandomSampler::WeightedRandomSampler(const std::vector<double> &weights, int64_t num_samples, bool replacement,
|
WeightedRandomSampler::WeightedRandomSampler(const std::vector<double> &weights, int64_t num_samples, bool replacement,
|
||||||
int64_t samples_per_buffer)
|
int64_t samples_per_buffer)
|
||||||
: Sampler(samples_per_buffer), weights_(weights), replacement_(replacement), sample_id_(0), buffer_id_(0) {
|
: Sampler(samples_per_buffer),
|
||||||
num_samples_ = num_samples; // this variable is defined in base class sampler
|
weights_(weights),
|
||||||
}
|
replacement_(replacement),
|
||||||
|
sample_id_(0),
|
||||||
|
buffer_id_(0),
|
||||||
|
user_num_samples_(num_samples) {}
|
||||||
|
|
||||||
// Initialize this Sampler.
|
// Initialize this Sampler.
|
||||||
Status WeightedRandomSampler::Init(const RandomAccessOp *op) {
|
Status WeightedRandomSampler::InitSampler() {
|
||||||
RETURN_UNEXPECTED_IF_NULL(op);
|
CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0 && user_num_samples_, "num_samples & num_rows need to be positive");
|
||||||
RETURN_IF_NOT_OK(op->GetNumRowsInDataset(&num_rows_));
|
CHECK_FAIL_RETURN_UNEXPECTED(samples_per_buffer_ > 0, "samples_per_buffer<=0\n");
|
||||||
|
|
||||||
// Initialize random generator with seed from config manager
|
// Initialize random generator with seed from config manager
|
||||||
rand_gen_.seed(GetSeed());
|
rand_gen_.seed(GetSeed());
|
||||||
|
|
||||||
samples_per_buffer_ = (samples_per_buffer_ > num_samples_) ? num_samples_ : samples_per_buffer_;
|
samples_per_buffer_ = (samples_per_buffer_ > user_num_samples_) ? user_num_samples_ : samples_per_buffer_;
|
||||||
|
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && samples_per_buffer_ > 0, "Fail to init WeightedRandomSampler");
|
|
||||||
|
|
||||||
if (!replacement_) {
|
if (!replacement_) {
|
||||||
exp_dist_ = std::make_unique<std::exponential_distribution<>>(1);
|
exp_dist_ = std::make_unique<std::exponential_distribution<>>(1);
|
||||||
|
@ -65,8 +65,8 @@ void WeightedRandomSampler::InitOnePassSampling() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Partial sort the first `numSamples` elements.
|
// Partial sort the first `numSamples` elements.
|
||||||
std::partial_sort(val_idx.begin(), val_idx.begin() + num_samples_, val_idx.end());
|
std::partial_sort(val_idx.begin(), val_idx.begin() + user_num_samples_, val_idx.end());
|
||||||
for (int64_t i = 0; i < num_samples_; i++) {
|
for (int64_t i = 0; i < user_num_samples_; i++) {
|
||||||
onepass_ids_.push_back(val_idx[i].second);
|
onepass_ids_.push_back(val_idx[i].second);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -91,11 +91,11 @@ Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buf
|
||||||
"number of samples weights is more than num of rows. Might generate id out of bound OR other errors");
|
"number of samples weights is more than num of rows. Might generate id out of bound OR other errors");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!replacement_ && (weights_.size() < static_cast<size_t>(num_samples_))) {
|
if (!replacement_ && (weights_.size() < static_cast<size_t>(user_num_samples_))) {
|
||||||
RETURN_STATUS_UNEXPECTED("Without replacement, sample weights less than numSamples");
|
RETURN_STATUS_UNEXPECTED("Without replacement, sample weights less than numSamples");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (sample_id_ == num_samples_) {
|
if (sample_id_ == user_num_samples_) {
|
||||||
(*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE);
|
(*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE);
|
||||||
} else {
|
} else {
|
||||||
(*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagNone);
|
(*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagNone);
|
||||||
|
@ -103,8 +103,8 @@ Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buf
|
||||||
|
|
||||||
int64_t last_id = sample_id_ + samples_per_buffer_;
|
int64_t last_id = sample_id_ + samples_per_buffer_;
|
||||||
// Handle returning all samples at once, and the case where the last draw is not a full batch.
|
// Handle returning all samples at once, and the case where the last draw is not a full batch.
|
||||||
if (last_id > num_samples_) {
|
if (last_id > user_num_samples_) {
|
||||||
last_id = num_samples_;
|
last_id = user_num_samples_;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Allocate tensor.
|
// Allocate tensor.
|
||||||
|
|
|
@ -43,7 +43,7 @@ class WeightedRandomSampler : public Sampler {
|
||||||
// Initialize the sampler.
|
// Initialize the sampler.
|
||||||
// @param op (Not used in this sampler)
|
// @param op (Not used in this sampler)
|
||||||
// @return Status
|
// @return Status
|
||||||
Status Init(const RandomAccessOp *op) override;
|
Status InitSampler() override;
|
||||||
|
|
||||||
// Reset the internal variable to the initial state and reshuffle the indices.
|
// Reset the internal variable to the initial state and reshuffle the indices.
|
||||||
Status Reset() override;
|
Status Reset() override;
|
||||||
|
@ -69,6 +69,9 @@ class WeightedRandomSampler : public Sampler {
|
||||||
// Random engine and device
|
// Random engine and device
|
||||||
std::mt19937 rand_gen_;
|
std::mt19937 rand_gen_;
|
||||||
|
|
||||||
|
// num_samples from user
|
||||||
|
int64_t user_num_samples_;
|
||||||
|
|
||||||
// Discrete distribution for generating weighted random numbers with replacement.
|
// Discrete distribution for generating weighted random numbers with replacement.
|
||||||
std::unique_ptr<std::discrete_distribution<int64_t>> discrete_dist_;
|
std::unique_ptr<std::discrete_distribution<int64_t>> discrete_dist_;
|
||||||
|
|
||||||
|
|
|
@ -220,7 +220,7 @@ Status VOCOp::ParseImageIds() {
|
||||||
}
|
}
|
||||||
|
|
||||||
Status VOCOp::InitSampler() {
|
Status VOCOp::InitSampler() {
|
||||||
RETURN_IF_NOT_OK(sampler_->Init(this));
|
RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1758,14 +1758,70 @@ class MindDataset(SourceDataset):
|
||||||
return num_rows
|
return num_rows
|
||||||
|
|
||||||
|
|
||||||
def ds_fn(dataset):
|
def _iter_fn(dataset, num_samples):
|
||||||
for val in dataset:
|
"""
|
||||||
# convert output tensors to ndarrays
|
Generator function wrapper for iterable dataset
|
||||||
yield tuple([np.array(x) for x in val])
|
"""
|
||||||
|
if num_samples is not None:
|
||||||
|
ds_iter = iter(dataset)
|
||||||
|
for _ in range(num_samples):
|
||||||
|
try:
|
||||||
|
val = next(ds_iter)
|
||||||
|
except StopIteration:
|
||||||
|
return
|
||||||
|
# convert output tensors to ndarrays
|
||||||
|
yield tuple([np.array(x) for x in val])
|
||||||
|
else:
|
||||||
|
for val in dataset:
|
||||||
|
# convert output tensors to ndarrays
|
||||||
|
yield tuple([np.array(x) for x in val])
|
||||||
|
|
||||||
|
|
||||||
def sampler_fn(sampler, dataset):
|
def _generator_fn(generator, num_samples):
|
||||||
for i in sampler:
|
"""
|
||||||
|
Generator function wrapper for generator function dataset
|
||||||
|
"""
|
||||||
|
if num_samples is not None:
|
||||||
|
gen_iter = generator()
|
||||||
|
for _ in range(num_samples):
|
||||||
|
try:
|
||||||
|
val = next(gen_iter)
|
||||||
|
except StopIteration:
|
||||||
|
return
|
||||||
|
yield val
|
||||||
|
else:
|
||||||
|
gen_iter = generator()
|
||||||
|
for val in gen_iter:
|
||||||
|
yield val
|
||||||
|
|
||||||
|
|
||||||
|
def _py_sampler_fn(sampler, num_samples, dataset):
|
||||||
|
"""
|
||||||
|
Generator function wrapper for mappable dataset with python sampler
|
||||||
|
"""
|
||||||
|
if num_samples is not None:
|
||||||
|
sampler_iter = iter(sampler)
|
||||||
|
for _ in range(num_samples):
|
||||||
|
try:
|
||||||
|
idx = next(sampler_iter)
|
||||||
|
except StopIteration:
|
||||||
|
return
|
||||||
|
val = dataset[idx]
|
||||||
|
# convert output tensors to ndarrays
|
||||||
|
yield tuple([np.array(x) for x in val])
|
||||||
|
else:
|
||||||
|
for i in sampler:
|
||||||
|
val = dataset[i]
|
||||||
|
# convert output tensors to ndarrays
|
||||||
|
yield tuple([np.array(x) for x in val])
|
||||||
|
|
||||||
|
|
||||||
|
def _cpp_sampler_fn(sampler, dataset):
|
||||||
|
"""
|
||||||
|
Generator function wrapper for mappable dataset with cpp sampler
|
||||||
|
"""
|
||||||
|
indices = sampler.get_indices()
|
||||||
|
for i in indices:
|
||||||
val = dataset[i]
|
val = dataset[i]
|
||||||
# convert output tensors to ndarrays
|
# convert output tensors to ndarrays
|
||||||
yield tuple([np.array(x) for x in val])
|
yield tuple([np.array(x) for x in val])
|
||||||
|
@ -1773,49 +1829,122 @@ def sampler_fn(sampler, dataset):
|
||||||
|
|
||||||
class GeneratorDataset(SourceDataset):
|
class GeneratorDataset(SourceDataset):
|
||||||
"""
|
"""
|
||||||
A source dataset that generate data from calling generator function each epoch.
|
A source dataset that generates data from Python by invoking a Python data source each epoch.
|
||||||
|
|
||||||
|
This dataset can take in a sampler. sampler and shuffle are mutually exclusive. Table
|
||||||
|
below shows what input args are allowed and their expected behavior.
|
||||||
|
|
||||||
|
.. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
|
||||||
|
:widths: 25 25 50
|
||||||
|
:header-rows: 1
|
||||||
|
|
||||||
|
* - Parameter 'sampler'
|
||||||
|
- Parameter 'shuffle'
|
||||||
|
- Expected Order Behavior
|
||||||
|
* - None
|
||||||
|
- None
|
||||||
|
- random order
|
||||||
|
* - None
|
||||||
|
- True
|
||||||
|
- random order
|
||||||
|
* - None
|
||||||
|
- False
|
||||||
|
- sequential order
|
||||||
|
* - Sampler object
|
||||||
|
- None
|
||||||
|
- order defined by sampler
|
||||||
|
* - Sampler object
|
||||||
|
- True
|
||||||
|
- not allowed
|
||||||
|
* - Sampler object
|
||||||
|
- False
|
||||||
|
- not allowed
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
generator_function (callable):
|
source (Callable/Iterable/Random Accessible):
|
||||||
A callable object that returns an Generator object that supports the iter() protocol.
|
A generator callable object, an iterable python object or a random accessible python object.
|
||||||
Generator object is required to return a tuple of numpy array as a row of the dataset on next().
|
Callable source is required to return a tuple of numpy array as a row of the dataset on source().next().
|
||||||
|
Iterable source is required to return a tuple of numpy array as a row of the dataset on iter(source).next().
|
||||||
|
Random accessible source is required to return a tuple of numpy array as a row of the dataset on
|
||||||
|
source[idx].
|
||||||
column_names (list[str]): List of column names of the dataset.
|
column_names (list[str]): List of column names of the dataset.
|
||||||
column_types (list[mindspore.dtype], optional): List of column data types of the dataset (default=None).
|
column_types (list[mindspore.dtype], optional): List of column data types of the dataset (default=None).
|
||||||
If provided, sanity check will be performed on generator output.
|
If provided, sanity check will be performed on generator output.
|
||||||
prefetch_size (int, optional): Prefetch number of records ahead of the user's request (default=None).
|
schema (Schema/String, optional): Path to the json schema file or schema object (default=None).
|
||||||
sampler (Sampler, optional): Object used to choose samples from the dataset (default=None).
|
If the schema is not provided, the meta data from column_names and column_types is considered the schema.
|
||||||
|
num_samples (int, optional): The number of samples to be included in the dataset
|
||||||
|
(default=None, all samples).
|
||||||
|
shuffle (bool, optional): Whether or not to perform shuffle on the dataset. Random accessible input is required.
|
||||||
|
(default=None, expected order behavior shown in the table).
|
||||||
|
sampler (Sampler/Iterable, optional): Object used to choose samples from the dataset. Random accessible input is
|
||||||
|
required.
|
||||||
|
(default=None, expected order behavior shown in the table).
|
||||||
|
num_shards (int, optional): Number of shards that the dataset should be divided into (default=None).
|
||||||
|
This argument should be specified only when 'num_samples' is "None". Random accessible input is required.
|
||||||
|
shard_id (int, optional): The shard ID within num_shards (default=None). This argument should be specified only
|
||||||
|
when num_shards is also specified. Random accessible input is required.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> import mindspore.dataset as ds
|
>>> import mindspore.dataengine as de
|
||||||
>>> # 1) generator function that generates multi-dimensional data
|
>>> # 1) Multidimensional generator function as callable input
|
||||||
>>> def generator_md():
|
>>> def generator_md():
|
||||||
>>> for i in range(64):
|
>>> for i in range(64):
|
||||||
>>> yield (np.array([[i, i + 1], [i + 2, i + 3]]),)
|
>>> yield (np.array([[i, i + 1], [i + 2, i + 3]]),)
|
||||||
>>> # create multi_dimension_generator_dataset with GeneratorMD() and column name "multi_dimensional_data"
|
>>> # create multi_dimension_generator_dataset with GeneratorMD and column name "multi_dimensional_data"
|
||||||
>>> multi_dimension_generator_dataset = ds.GeneratorDataset(generator_md, ["multi_dimensional_data"])
|
>>> multi_dimension_generator_dataset = de.GeneratorDataset(generator_md, ["multi_dimensional_data"])
|
||||||
>>> # 2) generator function that generates multi-columns data
|
>>> # 2) Multi-column generator function as callable input
|
||||||
>>> def generator_mc(maxid = 64):
|
>>> def generator_mc(maxid = 64):
|
||||||
>>> for i in range(maxid):
|
>>> for i in range(maxid):
|
||||||
>>> yield (np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]]))
|
>>> yield (np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]]))
|
||||||
>>> # create multi_column_generator_dataset with GeneratorMC() and column names "col1" and "col2"
|
>>> # create multi_column_generator_dataset with GeneratorMC and column names "col1" and "col2"
|
||||||
>>> multi_column_generator_dataset = ds.GeneratorDataset(generator_mc, ["col1", "col2"])
|
>>> multi_column_generator_dataset = de.GeneratorDataset(generator_mc, ["col1", "col2"])
|
||||||
|
>>> # 3) Iterable dataset as iterable input
|
||||||
|
>>> class MyIterable():
|
||||||
|
>>> def __iter__(self):
|
||||||
|
>>> return # User implementation
|
||||||
|
>>> # create iterable_generator_dataset with MyIterable object
|
||||||
|
>>> iterable_generator_dataset = de.GeneratorDataset(MyIterable(), ["col1"])
|
||||||
|
>>> # 4) Random accessible dataset as Random accessible input
|
||||||
|
>>> class MyRA():
|
||||||
|
>>> def __getitem__(self, index):
|
||||||
|
>>> return # User implementation
|
||||||
|
>>> # create ra_generator_dataset with MyRA object
|
||||||
|
>>> ra_generator_dataset = de.GeneratorDataset(MyRA(), ["col1"])
|
||||||
|
>>> # List/Dict/Tuple is also random accessible
|
||||||
|
>>> list_generator = de.GeneratorDataset([(np.array(0),), (np.array(1),), (np.array(2),)], ["col1"])
|
||||||
|
>>> # 5) Built-in Sampler
|
||||||
|
>>> my_generator = de.GeneratorDataset(my_ds, ["img", "label"], sampler=samplers.RandomSampler())
|
||||||
|
>>>
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@check_generatordataset
|
@check_generatordataset
|
||||||
def __init__(self, generator_function, column_names, column_types=None, prefetch_size=None, sampler=None):
|
def __init__(self, source, column_names, column_types=None, schema=None, num_samples=None, num_parallel_workers=1,
|
||||||
super().__init__(1)
|
shuffle=None, sampler=None, num_shards=None, shard_id=None):
|
||||||
if sampler is not None:
|
super().__init__(num_parallel_workers)
|
||||||
self.generator_function = (lambda: sampler_fn(sampler, generator_function))
|
self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
|
||||||
|
if self.sampler is not None and hasattr(source, "__getitem__"):
|
||||||
|
if isinstance(self.sampler, (samplers.SequentialSampler, samplers.DistributedSampler,
|
||||||
|
samplers.RandomSampler, samplers.SubsetRandomSampler,
|
||||||
|
samplers.WeightedRandomSampler)):
|
||||||
|
if num_samples is None:
|
||||||
|
num_samples = len(source)
|
||||||
|
sampler_instance = self.sampler.create()
|
||||||
|
sampler_instance.set_num_rows(len(source))
|
||||||
|
sampler_instance.set_num_samples(num_samples)
|
||||||
|
sampler_instance.initialize()
|
||||||
|
self.source = (lambda: _cpp_sampler_fn(sampler_instance, source))
|
||||||
|
else:
|
||||||
|
self.source = (lambda: _py_sampler_fn(self.sampler, num_samples, source))
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
# test to see if generator_function is iterable
|
iter(source)
|
||||||
iter(generator_function)
|
|
||||||
except TypeError:
|
except TypeError:
|
||||||
# generator_function was not iterable, assume it is a function
|
# Use generator function if input callable
|
||||||
self.generator_function = generator_function
|
self.source = (lambda: _generator_fn(source, num_samples))
|
||||||
else:
|
else:
|
||||||
# generator_function was iterable, build a function around it
|
# Use iterator function if input is iterable
|
||||||
self.generator_function = (lambda: ds_fn(generator_function))
|
# Random accessible input is also iterable
|
||||||
|
self.source = (lambda: _iter_fn(source, num_samples))
|
||||||
|
|
||||||
self.column_names = column_names
|
self.column_names = column_names
|
||||||
|
|
||||||
|
@ -1823,17 +1952,12 @@ class GeneratorDataset(SourceDataset):
|
||||||
self.column_types = mstypelist_to_detypelist(column_types)
|
self.column_types = mstypelist_to_detypelist(column_types)
|
||||||
else:
|
else:
|
||||||
self.column_types = column_types
|
self.column_types = column_types
|
||||||
self.distribution = ""
|
|
||||||
self.prefetch_size = prefetch_size
|
|
||||||
self.sampler = sampler
|
|
||||||
|
|
||||||
def get_args(self):
|
def get_args(self):
|
||||||
args = super().get_args()
|
args = super().get_args()
|
||||||
args["generator_function"] = self.generator_function
|
args["source"] = self.source
|
||||||
args["column_names"] = self.column_names
|
args["column_names"] = self.column_names
|
||||||
args["column_types"] = self.column_types
|
args["column_types"] = self.column_types
|
||||||
args["prefetch_size"] = self.prefetch_size
|
|
||||||
args["sampler"] = self.sampler
|
|
||||||
return args
|
return args
|
||||||
|
|
||||||
def get_dataset_size(self):
|
def get_dataset_size(self):
|
||||||
|
|
|
@ -20,7 +20,6 @@ SequentialSampler, SubsetRandomSampler, WeightedRandomSampler.
|
||||||
|
|
||||||
import mindspore._c_dataengine as cde
|
import mindspore._c_dataengine as cde
|
||||||
|
|
||||||
|
|
||||||
class DistributedSampler():
|
class DistributedSampler():
|
||||||
"""
|
"""
|
||||||
Sampler that access a shard of the dataset.
|
Sampler that access a shard of the dataset.
|
||||||
|
|
|
@ -543,28 +543,48 @@ def check_generatordataset(method):
|
||||||
def new_method(*args, **kwargs):
|
def new_method(*args, **kwargs):
|
||||||
param_dict = make_param_dict(method, args, kwargs)
|
param_dict = make_param_dict(method, args, kwargs)
|
||||||
|
|
||||||
nreq_param_int = ['prefetch_size']
|
|
||||||
nreq_param_list = ['column_names', 'column_types']
|
|
||||||
|
|
||||||
# check generator_function; required argument
|
# check source; required argument
|
||||||
generator_function = param_dict.get('generator_function')
|
source = param_dict.get('source')
|
||||||
if generator_function is None:
|
if source is None:
|
||||||
raise ValueError("generator_function is not provided.")
|
raise ValueError("source is not provided.")
|
||||||
|
if not callable(source):
|
||||||
|
try:
|
||||||
|
iter(source)
|
||||||
|
except TypeError:
|
||||||
|
raise TypeError("source should be callable, iterable or random accessible")
|
||||||
|
|
||||||
# check column_names; required argument
|
# check column_names; required argument
|
||||||
column_names = param_dict.get('column_names')
|
column_names = param_dict.get('column_names')
|
||||||
if column_names is None:
|
if column_names is None:
|
||||||
raise ValueError("column_names is not provided.")
|
raise ValueError("column_names is not provided.")
|
||||||
|
|
||||||
# check prefetch_size range
|
# check optional argument
|
||||||
prefetch_size = param_dict.get('prefetch_size')
|
nreq_param_int = ["num_samples", "num_parallel_workers", "num_shards", "shard_id"]
|
||||||
if prefetch_size is not None and (prefetch_size <= 0 or prefetch_size > 1024):
|
|
||||||
raise ValueError("prefetch_size exceeds the boundary.")
|
|
||||||
|
|
||||||
check_param_type(nreq_param_int, param_dict, int)
|
check_param_type(nreq_param_int, param_dict, int)
|
||||||
|
nreq_param_list = ["column_types"]
|
||||||
check_param_type(nreq_param_list, param_dict, list)
|
check_param_type(nreq_param_list, param_dict, list)
|
||||||
|
|
||||||
|
num_shards = param_dict.get("num_shards")
|
||||||
|
shard_id = param_dict.get("shard_id")
|
||||||
|
if (num_shards is None) != (shard_id is None):
|
||||||
|
# These two parameters appear together.
|
||||||
|
raise ValueError("num_shards and shard_id need to be passed in together")
|
||||||
|
if num_shards is not None:
|
||||||
|
if shard_id >= num_shards:
|
||||||
|
raise ValueError("shard_id should be less than num_shards")
|
||||||
|
|
||||||
|
sampler = param_dict.get("sampler")
|
||||||
|
if sampler is not None:
|
||||||
|
if isinstance(sampler, samplers.PKSampler):
|
||||||
|
raise ValueError("PKSampler is not supported by GeneratorDataset")
|
||||||
|
if not isinstance(sampler, (samplers.SequentialSampler, samplers.DistributedSampler,
|
||||||
|
samplers.RandomSampler, samplers.SubsetRandomSampler,
|
||||||
|
samplers.WeightedRandomSampler)):
|
||||||
|
try:
|
||||||
|
iter(sampler)
|
||||||
|
except TypeError:
|
||||||
|
raise TypeError("sampler should be either iterable or from dataset.samplers.py")
|
||||||
|
|
||||||
return method(*args, **kwargs)
|
return method(*args, **kwargs)
|
||||||
|
|
||||||
return new_method
|
return new_method
|
||||||
|
|
|
@ -75,7 +75,7 @@ TEST_F(MindDataTestStandAloneSampler, TestDistributedSampler) {
|
||||||
std::shared_ptr<Tensor> tensor;
|
std::shared_ptr<Tensor> tensor;
|
||||||
for (int i = 0; i < 6; i++) {
|
for (int i = 0; i < 6; i++) {
|
||||||
std::unique_ptr<Sampler> sampler = std::make_unique<DistributedSampler>(3, i % 3, (i < 3 ? false : true));
|
std::unique_ptr<Sampler> sampler = std::make_unique<DistributedSampler>(3, i % 3, (i < 3 ? false : true));
|
||||||
sampler->Init(&mock);
|
sampler->HandshakeRandomAccessOp(&mock);
|
||||||
sampler->GetNextBuffer(&db);
|
sampler->GetNextBuffer(&db);
|
||||||
db->GetTensor(&tensor, 0, 0);
|
db->GetTensor(&tensor, 0, 0);
|
||||||
MS_LOG(DEBUG) << (*tensor);
|
MS_LOG(DEBUG) << (*tensor);
|
||||||
|
@ -95,7 +95,7 @@ TEST_F(MindDataTestStandAloneSampler, TestStandAoneSequentialSampler) {
|
||||||
std::shared_ptr<Sampler> sampler = std::make_shared<SequentialSampler>(3);
|
std::shared_ptr<Sampler> sampler = std::make_shared<SequentialSampler>(3);
|
||||||
std::unique_ptr<DataBuffer> db;
|
std::unique_ptr<DataBuffer> db;
|
||||||
std::shared_ptr<Tensor> tensor;
|
std::shared_ptr<Tensor> tensor;
|
||||||
sampler->Init(&mock);
|
sampler->HandshakeRandomAccessOp(&mock);
|
||||||
sampler->GetNextBuffer(&db);
|
sampler->GetNextBuffer(&db);
|
||||||
db->GetTensor(&tensor, 0, 0);
|
db->GetTensor(&tensor, 0, 0);
|
||||||
EXPECT_TRUE((*tensor) == (*label1));
|
EXPECT_TRUE((*tensor) == (*label1));
|
||||||
|
|
|
@ -52,8 +52,8 @@ TEST_F(MindDataTestSubsetRandomSampler, TestAllAtOnce) {
|
||||||
std::unordered_set<int64_t> in_set(in.begin(), in.end());
|
std::unordered_set<int64_t> in_set(in.begin(), in.end());
|
||||||
SubsetRandomSampler sampler(in);
|
SubsetRandomSampler sampler(in);
|
||||||
|
|
||||||
DummyRandomAccessOp dummy_random_access_op(5);
|
DummyRandomAccessOp dummyRandomAccessOp(5);
|
||||||
sampler.Init(&dummy_random_access_op);
|
sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
|
||||||
|
|
||||||
std::unique_ptr<DataBuffer> db;
|
std::unique_ptr<DataBuffer> db;
|
||||||
TensorRow row;
|
TensorRow row;
|
||||||
|
@ -80,8 +80,8 @@ TEST_F(MindDataTestSubsetRandomSampler, TestGetNextBuffer) {
|
||||||
std::vector<int64_t> input(total_samples, 1);
|
std::vector<int64_t> input(total_samples, 1);
|
||||||
SubsetRandomSampler sampler(input, samples_per_buffer);
|
SubsetRandomSampler sampler(input, samples_per_buffer);
|
||||||
|
|
||||||
DummyRandomAccessOp dummy_random_access_op(total_samples);
|
DummyRandomAccessOp dummyRandomAccessOp(total_samples);
|
||||||
sampler.Init(&dummy_random_access_op);
|
sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
|
||||||
|
|
||||||
std::unique_ptr<DataBuffer> db;
|
std::unique_ptr<DataBuffer> db;
|
||||||
TensorRow row;
|
TensorRow row;
|
||||||
|
@ -111,8 +111,8 @@ TEST_F(MindDataTestSubsetRandomSampler, TestReset) {
|
||||||
std::unordered_set<int64_t> in_set(in.begin(), in.end());
|
std::unordered_set<int64_t> in_set(in.begin(), in.end());
|
||||||
SubsetRandomSampler sampler(in);
|
SubsetRandomSampler sampler(in);
|
||||||
|
|
||||||
DummyRandomAccessOp dummy_random_access_op(5);
|
DummyRandomAccessOp dummyRandomAccessOp(5);
|
||||||
sampler.Init(&dummy_random_access_op);
|
sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
|
||||||
|
|
||||||
std::unique_ptr<DataBuffer> db;
|
std::unique_ptr<DataBuffer> db;
|
||||||
TensorRow row;
|
TensorRow row;
|
||||||
|
|
|
@ -60,8 +60,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotReplacement) {
|
||||||
|
|
||||||
// create sampler with replacement = true
|
// create sampler with replacement = true
|
||||||
WeightedRandomSampler m_sampler(weights, num_samples, true);
|
WeightedRandomSampler m_sampler(weights, num_samples, true);
|
||||||
DummyRandomAccessOp dummy_random_access_op(total_samples);
|
DummyRandomAccessOp dummyRandomAccessOp(total_samples);
|
||||||
m_sampler.Init(&dummy_random_access_op);
|
m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
|
||||||
|
|
||||||
std::unique_ptr<DataBuffer> db;
|
std::unique_ptr<DataBuffer> db;
|
||||||
TensorRow row;
|
TensorRow row;
|
||||||
|
@ -90,8 +90,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotNoReplacement) {
|
||||||
|
|
||||||
// create sampler with replacement = false
|
// create sampler with replacement = false
|
||||||
WeightedRandomSampler m_sampler(weights, num_samples, false);
|
WeightedRandomSampler m_sampler(weights, num_samples, false);
|
||||||
DummyRandomAccessOp dummy_random_access_op(total_samples);
|
DummyRandomAccessOp dummyRandomAccessOp(total_samples);
|
||||||
m_sampler.Init(&dummy_random_access_op);
|
m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
|
||||||
|
|
||||||
std::unique_ptr<DataBuffer> db;
|
std::unique_ptr<DataBuffer> db;
|
||||||
TensorRow row;
|
TensorRow row;
|
||||||
|
@ -126,8 +126,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferReplacement) {
|
||||||
|
|
||||||
// create sampler with replacement = true
|
// create sampler with replacement = true
|
||||||
WeightedRandomSampler m_sampler(weights, num_samples, true, samples_per_buffer);
|
WeightedRandomSampler m_sampler(weights, num_samples, true, samples_per_buffer);
|
||||||
DummyRandomAccessOp dummy_random_access_op(total_samples);
|
DummyRandomAccessOp dummyRandomAccessOp(total_samples);
|
||||||
m_sampler.Init(&dummy_random_access_op);
|
m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
|
||||||
|
|
||||||
std::unique_ptr<DataBuffer> db;
|
std::unique_ptr<DataBuffer> db;
|
||||||
TensorRow row;
|
TensorRow row;
|
||||||
|
@ -162,8 +162,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferNoReplacement) {
|
||||||
|
|
||||||
// create sampler with replacement = false
|
// create sampler with replacement = false
|
||||||
WeightedRandomSampler m_sampler(weights, num_samples, false, samples_per_buffer);
|
WeightedRandomSampler m_sampler(weights, num_samples, false, samples_per_buffer);
|
||||||
DummyRandomAccessOp dummy_random_access_op(total_samples);
|
DummyRandomAccessOp dummyRandomAccessOp(total_samples);
|
||||||
m_sampler.Init(&dummy_random_access_op);
|
m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
|
||||||
|
|
||||||
std::unique_ptr<DataBuffer> db;
|
std::unique_ptr<DataBuffer> db;
|
||||||
TensorRow row;
|
TensorRow row;
|
||||||
|
@ -203,8 +203,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) {
|
||||||
|
|
||||||
// create sampler with replacement = true
|
// create sampler with replacement = true
|
||||||
WeightedRandomSampler m_sampler(weights, num_samples, true);
|
WeightedRandomSampler m_sampler(weights, num_samples, true);
|
||||||
DummyRandomAccessOp dummy_random_access_op(total_samples);
|
DummyRandomAccessOp dummyRandomAccessOp(total_samples);
|
||||||
m_sampler.Init(&dummy_random_access_op);
|
m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
|
||||||
|
|
||||||
std::unique_ptr<DataBuffer> db;
|
std::unique_ptr<DataBuffer> db;
|
||||||
TensorRow row;
|
TensorRow row;
|
||||||
|
@ -248,8 +248,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) {
|
||||||
|
|
||||||
// create sampler with replacement = false
|
// create sampler with replacement = false
|
||||||
WeightedRandomSampler m_sampler(weights, num_samples, false);
|
WeightedRandomSampler m_sampler(weights, num_samples, false);
|
||||||
DummyRandomAccessOp dummy_random_access_op(total_samples);
|
DummyRandomAccessOp dummyRandomAccessOp(total_samples);
|
||||||
m_sampler.Init(&dummy_random_access_op);
|
m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
|
||||||
|
|
||||||
std::unique_ptr<DataBuffer> db;
|
std::unique_ptr<DataBuffer> db;
|
||||||
TensorRow row;
|
TensorRow row;
|
||||||
|
|
|
@ -439,6 +439,74 @@ def test_case_error_4():
|
||||||
assert "Unexpected error. Result of a tensorOp doesn't match output column names" in str(info.value)
|
assert "Unexpected error. Result of a tensorOp doesn't match output column names" in str(info.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sequential_sampler():
|
||||||
|
source = [(np.array([x]),) for x in range(64)]
|
||||||
|
ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler())
|
||||||
|
i = 0
|
||||||
|
for data in ds1.create_dict_iterator(): # each data is a dictionary
|
||||||
|
golden = np.array([i])
|
||||||
|
assert np.array_equal(data["data"], golden)
|
||||||
|
i = i + 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_random_sampler():
|
||||||
|
source = [(np.array([x]),) for x in range(64)]
|
||||||
|
ds1 = ds.GeneratorDataset(source, ["data"], shuffle = True)
|
||||||
|
for data in ds1.create_dict_iterator(): # each data is a dictionary
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_distributed_sampler():
|
||||||
|
source = [(np.array([x]),) for x in range(64)]
|
||||||
|
for sid in range(8):
|
||||||
|
ds1 = ds.GeneratorDataset(source, ["data"], shuffle = False, num_shards=8, shard_id=sid)
|
||||||
|
i = sid
|
||||||
|
for data in ds1.create_dict_iterator(): # each data is a dictionary
|
||||||
|
golden = np.array([i])
|
||||||
|
assert np.array_equal(data["data"], golden)
|
||||||
|
i = i + 8
|
||||||
|
|
||||||
|
|
||||||
|
def test_num_samples():
|
||||||
|
source = [(np.array([x]),) for x in range(64)]
|
||||||
|
num_samples = 32
|
||||||
|
ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler(), num_samples = num_samples)
|
||||||
|
ds2 = ds.GeneratorDataset(source, ["data"], sampler=[i for i in range(32)], num_samples = num_samples)
|
||||||
|
ds3 = ds.GeneratorDataset(generator_1d, ["data"], num_samples = num_samples)
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
for _ in ds1.create_dict_iterator():
|
||||||
|
count = count + 1
|
||||||
|
assert count == num_samples
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
for _ in ds2.create_dict_iterator():
|
||||||
|
count = count + 1
|
||||||
|
assert count == num_samples
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
for _ in ds3.create_dict_iterator():
|
||||||
|
count = count + 1
|
||||||
|
assert count == num_samples
|
||||||
|
|
||||||
|
|
||||||
|
def test_num_samples_underflow():
|
||||||
|
source = [(np.array([x]),) for x in range(64)]
|
||||||
|
num_samples = 256
|
||||||
|
ds2 = ds.GeneratorDataset(source, ["data"], sampler=[i for i in range(64)], num_samples = num_samples)
|
||||||
|
ds3 = ds.GeneratorDataset(generator_1d, ["data"], num_samples = num_samples)
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
for _ in ds2.create_dict_iterator():
|
||||||
|
count = count + 1
|
||||||
|
assert count == 64
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
for _ in ds3.create_dict_iterator():
|
||||||
|
count = count + 1
|
||||||
|
assert count == 64
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_case_0()
|
test_case_0()
|
||||||
test_case_1()
|
test_case_1()
|
||||||
|
@ -458,3 +526,6 @@ if __name__ == "__main__":
|
||||||
test_case_error_2()
|
test_case_error_2()
|
||||||
test_case_error_3()
|
test_case_error_3()
|
||||||
test_case_error_4()
|
test_case_error_4()
|
||||||
|
test_sequential_sampler()
|
||||||
|
test_distributed_sampler()
|
||||||
|
test_random_sampler()
|
||||||
|
|
|
@ -87,7 +87,28 @@ def test_random_sampler_multi_iter(print_res=False):
|
||||||
test_config(replacement=True, num_samples=5, num_repeats=5, validate=[0, 1, 2, 3, 4, 5])
|
test_config(replacement=True, num_samples=5, num_repeats=5, validate=[0, 1, 2, 3, 4, 5])
|
||||||
|
|
||||||
|
|
||||||
|
def test_sampler_py_api():
|
||||||
|
sampler = ds.SequentialSampler().create()
|
||||||
|
sampler.set_num_rows(128)
|
||||||
|
sampler.set_num_samples(64)
|
||||||
|
sampler.initialize()
|
||||||
|
sampler.get_indices()
|
||||||
|
|
||||||
|
sampler = ds.RandomSampler().create()
|
||||||
|
sampler.set_num_rows(128)
|
||||||
|
sampler.set_num_samples(64)
|
||||||
|
sampler.initialize()
|
||||||
|
sampler.get_indices()
|
||||||
|
|
||||||
|
sampler = ds.DistributedSampler(8, 4).create()
|
||||||
|
sampler.set_num_rows(128)
|
||||||
|
sampler.set_num_samples(64)
|
||||||
|
sampler.initialize()
|
||||||
|
sampler.get_indices()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_sequential_sampler(True)
|
test_sequential_sampler(True)
|
||||||
test_random_sampler(True)
|
test_random_sampler(True)
|
||||||
test_random_sampler_multi_iter(True)
|
test_random_sampler_multi_iter(True)
|
||||||
|
test_sampler_py_api()
|
||||||
|
|
Loading…
Reference in New Issue