forked from mindspore-Ecosystem/mindspore
!4383 Add distributedSampler to concatOP
Merge pull request !4383 from genglishuai/genglshOK
This commit is contained in:
commit
7a10e07fc6
|
@ -48,7 +48,7 @@ PYBIND_REGISTER(Sampler, 0, ([](const py::module *m) {
|
|||
PYBIND_REGISTER(DistributedSampler, 1, ([](const py::module *m) {
|
||||
(void)py::class_<DistributedSampler, Sampler, std::shared_ptr<DistributedSampler>>(
|
||||
*m, "DistributedSampler")
|
||||
.def(py::init<int64_t, int64_t, int64_t, bool, uint32_t>());
|
||||
.def(py::init<int64_t, int64_t, int64_t, bool, uint32_t, int64_t>());
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(PKSampler, 1, ([](const py::module *m) {
|
||||
|
|
|
@ -41,7 +41,7 @@ PYBIND_REGISTER(ShardDistributedSample, 1, ([](const py::module *m) {
|
|||
(void)py::class_<mindrecord::ShardDistributedSample, mindrecord::ShardSample,
|
||||
std::shared_ptr<mindrecord::ShardDistributedSample>>(*m,
|
||||
"MindrecordDistributedSampler")
|
||||
.def(py::init<int64_t, int64_t, bool, uint32_t, int64_t>());
|
||||
.def(py::init<int64_t, int64_t, bool, uint32_t, int64_t, int64_t>());
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(
|
||||
|
|
|
@ -1081,6 +1081,25 @@ Status DEPipeline::ParseZipOp(const py::dict &args, std::shared_ptr<DatasetOp> *
|
|||
Status DEPipeline::ParseConcatOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
|
||||
std::shared_ptr<DatasetOp> *bottom) {
|
||||
std::shared_ptr<ConcatOp::Builder> builder = std::make_shared<ConcatOp::Builder>();
|
||||
for (auto arg : args) {
|
||||
std::string key = py::str(arg.first);
|
||||
py::handle value = arg.second;
|
||||
if (!value.is_none()) {
|
||||
if (key == "sampler") {
|
||||
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
|
||||
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
|
||||
(void)builder->SetSampler(std::move(sampler));
|
||||
}
|
||||
if (key == "children_flag_and_nums") {
|
||||
auto childFlag = py::reinterpret_borrow<py::list>(value).cast<std::vector<std::pair<int, int>>>();
|
||||
(void)builder->SetChildrenFlagAndNums(childFlag);
|
||||
}
|
||||
if (key == "children_start_end_index") {
|
||||
auto childIndex = py::reinterpret_borrow<py::list>(value).cast<std::vector<std::pair<int, int>>>();
|
||||
(void)builder->SetChildrenStartEndIndex(childIndex);
|
||||
}
|
||||
}
|
||||
}
|
||||
std::shared_ptr<ConcatOp> op;
|
||||
RETURN_IF_NOT_OK(builder->Build(&op));
|
||||
*top = op;
|
||||
|
|
|
@ -29,15 +29,29 @@ namespace dataset {
|
|||
// Builder constructor: starts with no sampler and takes the connector size from the global configuration.
ConcatOp::Builder::Builder() : builder_sampler_(nullptr) {
  const std::shared_ptr<ConfigManager> config = GlobalContext::config_manager();
  builder_op_connector_size_ = config->op_connector_size();
}
|
||||
|
||||
// The builder "build" method creates the final object.
|
||||
Status ConcatOp::Builder::Build(std::shared_ptr<ConcatOp> *ptr) {
|
||||
*ptr = std::make_shared<ConcatOp>(builder_op_connector_size_);
|
||||
if (builder_sampler_ == nullptr) {
|
||||
builder_sampler_ = std::make_shared<DistributedSampler>(0, 1, 0, false);
|
||||
}
|
||||
*ptr = std::make_shared<ConcatOp>(builder_op_connector_size_, builder_sampler_, children_flag_and_nums_,
|
||||
children_start_end_index_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Constructor of the ConcatOp.
|
||||
ConcatOp::ConcatOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler,
|
||||
std::vector<std::pair<int, int>> children_flag_and_nums,
|
||||
std::vector<std::pair<int, int>> children_start_end_index)
|
||||
: PipelineOp(op_connector_size),
|
||||
children_num_(0),
|
||||
sampler_(sampler),
|
||||
children_flag_and_nums_(children_flag_and_nums),
|
||||
children_start_end_index_(children_start_end_index) {}
|
||||
|
||||
ConcatOp::ConcatOp(int32_t op_connector_size) : PipelineOp(op_connector_size), children_num_(0) {}
|
||||
|
||||
// A function that prints info about the Operator
|
||||
|
@ -57,11 +71,20 @@ void ConcatOp::Print(std::ostream &out, bool show_all) const {
|
|||
|
||||
// Main entry point for Concat
|
||||
Status ConcatOp::operator()() {
|
||||
// The children_num_ parameter needs to be put here
|
||||
children_num_ = static_cast<int32_t>(child_.size());
|
||||
TaskManager::FindMe()->Post();
|
||||
std::unique_ptr<DataBuffer> buf;
|
||||
int eof_count = 0;
|
||||
int sample_number = 0;
|
||||
bool is_not_mappable = true;
|
||||
int num_shard = 1;
|
||||
int shard_index = 0;
|
||||
std::shared_ptr<DistributedSampler> distribute_sampler = std::dynamic_pointer_cast<DistributedSampler>(sampler_);
|
||||
if (distribute_sampler != nullptr) {
|
||||
num_shard = distribute_sampler->GetDeviceNum();
|
||||
shard_index = distribute_sampler->GetDeviceID();
|
||||
}
|
||||
|
||||
while (eof_count == 0) {
|
||||
for (int i = 0; i < children_num_; i++) {
|
||||
// 1. Read the first buffer
|
||||
|
@ -75,11 +98,39 @@ Status ConcatOp::operator()() {
|
|||
RETURN_IF_NOT_OK(Verify(i, buf));
|
||||
}
|
||||
// 3. Put the data into output_connector
|
||||
if (!children_flag_and_nums_.empty()) is_not_mappable = children_flag_and_nums_[i].first;
|
||||
while (!buf->eoe() && !buf->eof()) {
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buf)));
|
||||
// if the dataset is not mappable or is a generator dataset whose source yields (cannot get the number of samples in
|
||||
// python layer), we use filtering to get data
|
||||
if (sample_number % num_shard == shard_index && (is_not_mappable || !children_flag_and_nums_[i].second)) {
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buf)));
|
||||
} else if (!is_not_mappable && children_flag_and_nums_[i].second) { // if dataset is mappable or generator
|
||||
// dataset which source is not yield
|
||||
// get the start and end subscripts of valid values
|
||||
int fv = children_start_end_index_[i].first, sv = children_start_end_index_[i].second;
|
||||
|
||||
// determine whether the data allocated to the current shard id is false data
|
||||
if ((fv == -1 && sv == -1) || (fv < sv && shard_index >= fv && shard_index < sv) ||
|
||||
(fv > sv && (shard_index >= fv || shard_index < sv))) {
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buf)));
|
||||
}
|
||||
}
|
||||
|
||||
// if the dataset is not mappable or is a generator dataset whose source yields, sample_number += 1
|
||||
if (is_not_mappable || !children_flag_and_nums_[i].second) {
|
||||
sample_number++;
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf));
|
||||
}
|
||||
|
||||
// if the dataset is mappable, we don't use filtering to pick data.
|
||||
// so sample_number plus the length of the entire dataset
|
||||
if (!is_not_mappable && children_flag_and_nums_[i].second) {
|
||||
sample_number += children_flag_and_nums_[i].second;
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Add eoe buffer after get buffer from all child
|
||||
if (eof_count == 0) {
|
||||
auto eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
|
||||
|
|
|
@ -20,7 +20,9 @@
|
|||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include "minddata/dataset/engine/datasetops/pipeline_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
@ -42,15 +44,35 @@ class ConcatOp : public PipelineOp {
|
|||
// The builder "build" method creates the final object.
|
||||
// @return shared_ptr to the new ConcatOp object
|
||||
Status Build(std::shared_ptr<ConcatOp> *);
|
||||
// Setter: hands the given sampler over to the builder.
// @param sampler - sampler the built op will receive
// @return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
  builder_sampler_.swap(sampler);
  return *this;
}
|
||||
|
||||
// Setter: stores the per-child (flag, number) pairs on the builder.
// @param children_flag_and_nums - one pair per child
// @return Builder setter method returns reference to the builder.
Builder &SetChildrenFlagAndNums(std::vector<std::pair<int, int>> children_flag_and_nums) {
  children_flag_and_nums_.swap(children_flag_and_nums);
  return *this;
}
|
||||
|
||||
// Setter: stores the per-child (start, end) index pairs on the builder.
// @param children_start_end_index - one pair per child
// @return Builder setter method returns reference to the builder.
Builder &SetChildrenStartEndIndex(std::vector<std::pair<int, int>> children_start_end_index) {
  children_start_end_index_.swap(children_start_end_index);
  return *this;
}
|
||||
|
||||
private:
|
||||
int32_t builder_op_connector_size_;
|
||||
std::shared_ptr<Sampler> builder_sampler_;
|
||||
std::vector<std::pair<int, int>> children_flag_and_nums_;
|
||||
std::vector<std::pair<int, int>> children_start_end_index_;
|
||||
};
|
||||
|
||||
// Constructor of the ConcatOp.
|
||||
// @note The builder class should be used to call it
|
||||
// @param op_connector_size - connector size
|
||||
explicit ConcatOp(int32_t op_connector_size);
|
||||
explicit ConcatOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler,
|
||||
std::vector<std::pair<int, int>> children_flag_and_nums,
|
||||
std::vector<std::pair<int, int>> children_start_end_index);
|
||||
|
||||
// Destructor
|
||||
~ConcatOp() = default;
|
||||
|
@ -90,6 +112,9 @@ class ConcatOp : public PipelineOp {
|
|||
std::unordered_map<std::string, int32_t> column_name_id_; // Mapping between col index and col name
|
||||
std::vector<DataType> data_type_;
|
||||
std::vector<dsize_t> data_rank_;
|
||||
std::shared_ptr<Sampler> sampler_;
|
||||
std::vector<std::pair<int, int>> children_flag_and_nums_;
|
||||
std::vector<std::pair<int, int>> children_start_end_index_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -294,8 +294,9 @@ Status DeviceQueueOp::SendDataToCPU() {
|
|||
RETURN_IF_NOT_OK(child_iterator->FetchNextTensorRow(&curr_row));
|
||||
|
||||
if (!curr_row.empty()) {
|
||||
MS_LOG(DEBUG) << "Feature size is " << curr_row[0]->SizeInBytes() << ".";
|
||||
MS_LOG(DEBUG) << "Label size is " << curr_row[1]->SizeInBytes() << ".";
|
||||
for (auto &tensor : curr_row) {
|
||||
MS_LOG(DEBUG) << "Feature size is " << tensor->SizeInBytes() << ".";
|
||||
}
|
||||
total_batch++;
|
||||
if (stop_send_) break;
|
||||
}
|
||||
|
|
|
@ -24,14 +24,16 @@
|
|||
namespace mindspore {
|
||||
namespace dataset {
|
||||
DistributedSampler::DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle,
|
||||
uint32_t seed, bool even_dist)
|
||||
uint32_t seed, int64_t offset, bool even_dist)
|
||||
: Sampler(num_samples, std::numeric_limits<int64_t>::max()),
|
||||
cnt_(0),
|
||||
seed_(seed == std::numeric_limits<uint32_t>::max() ? GetSeed() : seed),
|
||||
device_id_(dev_id),
|
||||
num_devices_(num_dev),
|
||||
shuffle_(shuffle),
|
||||
even_dist_(even_dist) {}
|
||||
even_dist_(even_dist),
|
||||
offset_(offset),
|
||||
non_empty_(true) {}
|
||||
|
||||
Status DistributedSampler::InitSampler() {
|
||||
// Special value of 0 for num_samples means that the user wants to sample the entire set of data.
|
||||
|
@ -44,14 +46,16 @@ Status DistributedSampler::InitSampler() {
|
|||
CHECK_FAIL_RETURN_UNEXPECTED(device_id_ < num_devices_ && device_id_ >= 0 && num_rows_ > 0 && num_samples_ > 0,
|
||||
"fail to init DistributedSampler");
|
||||
rnd_.seed(seed_++);
|
||||
if (even_dist_) {
|
||||
samples_per_buffer_ = (num_rows_ + num_devices_ - 1) / num_devices_; // equals to ceil(num_rows/num_devices)
|
||||
|
||||
if (offset_ != -1 || !even_dist_) {
|
||||
if (offset_ == -1) offset_ = 0;
|
||||
samples_per_buffer_ = (num_rows_ + offset_) / num_devices_;
|
||||
int remainder = (num_rows_ + offset_) % num_devices_;
|
||||
if (device_id_ < remainder) samples_per_buffer_++;
|
||||
if (device_id_ < offset_) samples_per_buffer_--;
|
||||
} else {
|
||||
int64_t mod = num_rows_ % num_devices_;
|
||||
samples_per_buffer_ = num_rows_ / num_devices_;
|
||||
if (mod > device_id_) {
|
||||
samples_per_buffer_++;
|
||||
}
|
||||
offset_ = 0;
|
||||
samples_per_buffer_ = (num_rows_ + num_devices_ - 1) / num_devices_; // equals to ceil(num_rows/num_devices)
|
||||
}
|
||||
samples_per_buffer_ = num_samples_ < samples_per_buffer_ ? num_samples_ : samples_per_buffer_;
|
||||
if (shuffle_ == true) {
|
||||
|
@ -61,14 +65,29 @@ Status DistributedSampler::InitSampler() {
|
|||
}
|
||||
std::shuffle(shuffle_vec_.begin(), shuffle_vec_.end(), rnd_);
|
||||
}
|
||||
if (!samples_per_buffer_) non_empty_ = false;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DistributedSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
|
||||
if (cnt_ > samples_per_buffer_) {
|
||||
RETURN_STATUS_UNEXPECTED("Distributed Sampler Error");
|
||||
} else if (cnt_ == samples_per_buffer_) {
|
||||
} else if (cnt_ == samples_per_buffer_ && (non_empty_ || !even_dist_)) {
|
||||
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
|
||||
} else if (!samples_per_buffer_ && !non_empty_) {
|
||||
// If the buffer is empty, we add samples with subscript 0 in the current dataset.
|
||||
// This step is to make up for the solution that the code default buffer is not empty before.
|
||||
// We will remove this value in the concat phase
|
||||
non_empty_ = true;
|
||||
(*out_buffer) = std::make_unique<DataBuffer>(cnt_, DataBuffer::kDeBFlagNone);
|
||||
std::shared_ptr<Tensor> sample_ids;
|
||||
RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ids, 1));
|
||||
auto id_ptr = sample_ids->begin<int64_t>();
|
||||
// add index 0
|
||||
*id_ptr = 0;
|
||||
TensorRow row(1, sample_ids);
|
||||
(*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row));
|
||||
} else {
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_));
|
||||
|
@ -78,8 +97,18 @@ Status DistributedSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer
|
|||
std::shared_ptr<Tensor> sample_ids;
|
||||
RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ids, samples_per_buffer_));
|
||||
auto id_ptr = sample_ids->begin<int64_t>();
|
||||
bool flag_add_1 = false;
|
||||
while (cnt_ < samples_per_buffer_ && id_ptr != sample_ids->end<int64_t>()) {
|
||||
int64_t sampled_id = (num_devices_ * cnt_ + device_id_) % num_rows_;
|
||||
int64_t middle_value = num_devices_ * cnt_ + device_id_ - offset_;
|
||||
// if index < 0, we move back one place
|
||||
if (middle_value < 0) {
|
||||
samples_per_buffer_++;
|
||||
cnt_++;
|
||||
flag_add_1 = true;
|
||||
middle_value = num_devices_ * cnt_ + device_id_ - offset_;
|
||||
}
|
||||
int64_t sampled_id = middle_value % num_rows_;
|
||||
|
||||
if (shuffle_) {
|
||||
sampled_id = shuffle_vec_[static_cast<size_t>(sampled_id)];
|
||||
}
|
||||
|
@ -92,6 +121,12 @@ Status DistributedSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer
|
|||
id_ptr++;
|
||||
cnt_++;
|
||||
}
|
||||
|
||||
// If 1 was added before, we will cut off 1 here
|
||||
if (flag_add_1) {
|
||||
samples_per_buffer_--;
|
||||
cnt_--;
|
||||
}
|
||||
TensorRow row(1, sample_ids);
|
||||
(*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row));
|
||||
}
|
||||
|
|
|
@ -34,11 +34,13 @@ class DistributedSampler : public Sampler {
|
|||
/// \param[in] shuffle Option to shuffle
|
||||
/// \param seed Seed parameter to shuffle, default to max unsigned int (different seed in sampler will
|
||||
/// result in different samples being picked
|
||||
/// \param[in] offset The starting position to which the elements in the dataset are sent. The application
|
||||
/// scenario of this parameter is when a DistributedSampler is set on a concat dataset.
|
||||
/// \param even_dist The option to indicate whether or not each shard returns the same number of rows.
|
||||
/// This option is not exposed in the python API. Current behavior is that the remainder will always
|
||||
/// be handled by the first n shards, n being the corresponding device id.
|
||||
DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle,
|
||||
uint32_t seed = std::numeric_limits<uint32_t>::max(), bool even_dist = true);
|
||||
uint32_t seed = std::numeric_limits<uint32_t>::max(), int64_t offset = -1, bool even_dist = true);
|
||||
|
||||
/// \brief default destructor
|
||||
~DistributedSampler() = default;
|
||||
|
@ -55,6 +57,10 @@ class DistributedSampler : public Sampler {
|
|||
/// \return Status code
|
||||
Status ResetSampler() override;
|
||||
|
||||
int64_t GetDeviceID() { return device_id_; }
|
||||
|
||||
int64_t GetDeviceNum() { return num_devices_; }
|
||||
|
||||
void Print(std::ostream &out, bool show_all) const override;
|
||||
|
||||
private:
|
||||
|
@ -66,6 +72,8 @@ class DistributedSampler : public Sampler {
|
|||
std::mt19937 rnd_;
|
||||
std::vector<int64_t> shuffle_vec_;
|
||||
bool even_dist_;
|
||||
int64_t offset_;
|
||||
bool non_empty_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -30,9 +30,10 @@ namespace mindrecord {
|
|||
class ShardDistributedSample : public ShardSample {
|
||||
public:
|
||||
ShardDistributedSample(int num_shards, int shard_id, int no_of_padded_samples, bool shuffle, uint32_t seed,
|
||||
int no_of_samples = 0);
|
||||
int no_of_samples = 0, int offset = -1);
|
||||
|
||||
ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed, int no_of_samples = 0);
|
||||
ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed, int no_of_samples = 0,
|
||||
int offset = -1);
|
||||
|
||||
void SetNumPaddedSamples(int no_of_padded_samples) { no_of_padded_samples_ = no_of_padded_samples; }
|
||||
|
||||
|
|
|
@ -32,7 +32,7 @@ class ShardSample : public ShardOperator {
|
|||
|
||||
ShardSample(int num, int den);
|
||||
|
||||
ShardSample(int num, int den, int par, int no_of_samples = 0);
|
||||
ShardSample(int num, int den, int par, int no_of_samples = 0, int offset = -1);
|
||||
|
||||
ShardSample(const std::vector<int64_t> &indices, uint32_t seed);
|
||||
|
||||
|
@ -50,10 +50,12 @@ class ShardSample : public ShardOperator {
|
|||
int partition_id_;
|
||||
int no_of_samples_;
|
||||
std::shared_ptr<ShardShuffle> shuffle_op_;
|
||||
std::vector<int64_t> nums_per_shard_;
|
||||
|
||||
private:
|
||||
std::vector<int64_t> indices_;
|
||||
SamplerType sampler_type_;
|
||||
int offset_;
|
||||
};
|
||||
} // namespace mindrecord
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -23,8 +23,8 @@ using mindspore::MsLogLevel::ERROR;
|
|||
namespace mindspore {
|
||||
namespace mindrecord {
|
||||
ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, int no_of_padded_samples, bool shuffle,
|
||||
uint32_t seed, int no_of_samples)
|
||||
: ShardSample(1, num_shards, shard_id, no_of_samples),
|
||||
uint32_t seed, int no_of_samples, int offset)
|
||||
: ShardSample(1, num_shards, shard_id, no_of_samples, offset),
|
||||
shuffle_(shuffle),
|
||||
no_of_padded_samples_(no_of_padded_samples),
|
||||
first_epoch_(true) {
|
||||
|
@ -32,8 +32,8 @@ ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, int
|
|||
}
|
||||
|
||||
ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed,
|
||||
int no_of_samples)
|
||||
: ShardDistributedSample(num_shards, shard_id, 0, shuffle, seed, no_of_samples) {}
|
||||
int no_of_samples, int offset)
|
||||
: ShardDistributedSample(num_shards, shard_id, 0, shuffle, seed, no_of_samples, offset) {}
|
||||
|
||||
int64_t ShardDistributedSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
|
||||
if (no_of_padded_samples_ <= 0) {
|
||||
|
|
|
@ -28,7 +28,8 @@ ShardSample::ShardSample(int n)
|
|||
partition_id_(0),
|
||||
no_of_samples_(n),
|
||||
indices_({}),
|
||||
sampler_type_(kCustomTopNSampler) {}
|
||||
sampler_type_(kCustomTopNSampler),
|
||||
offset_(-1) {}
|
||||
|
||||
ShardSample::ShardSample(int num, int den)
|
||||
: numerator_(num),
|
||||
|
@ -36,15 +37,17 @@ ShardSample::ShardSample(int num, int den)
|
|||
partition_id_(0),
|
||||
no_of_samples_(0),
|
||||
indices_({}),
|
||||
sampler_type_(kCustomTopPercentSampler) {}
|
||||
sampler_type_(kCustomTopPercentSampler),
|
||||
offset_(-1) {}
|
||||
|
||||
ShardSample::ShardSample(int num, int den, int par, int no_of_samples)
|
||||
ShardSample::ShardSample(int num, int den, int par, int no_of_samples, int offset)
|
||||
: numerator_(num),
|
||||
denominator_(den),
|
||||
partition_id_(par),
|
||||
no_of_samples_(no_of_samples),
|
||||
indices_({}),
|
||||
sampler_type_(kCustomTopPercentSampler) {}
|
||||
sampler_type_(kCustomTopPercentSampler),
|
||||
offset_(offset) {}
|
||||
|
||||
ShardSample::ShardSample(const std::vector<int64_t> &indices, uint32_t seed)
|
||||
: numerator_(0),
|
||||
|
@ -75,6 +78,19 @@ int64_t ShardSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
|
|||
}
|
||||
|
||||
MSRStatus ShardSample::Execute(ShardTask &tasks) {
|
||||
if (offset_ != -1) {
|
||||
int64_t old_v = 0;
|
||||
int num_rows_ = static_cast<int>(tasks.Size());
|
||||
for (int x = 0; x < denominator_; x++) {
|
||||
int samples_per_buffer_ = (num_rows_ + offset_) / denominator_;
|
||||
int remainder = (num_rows_ + offset_) % denominator_;
|
||||
if (x < remainder) samples_per_buffer_++;
|
||||
if (x < offset_) samples_per_buffer_--;
|
||||
old_v += samples_per_buffer_;
|
||||
// nums_per_shard_ is used to save the current shard's ending index
|
||||
nums_per_shard_.push_back(old_v);
|
||||
}
|
||||
}
|
||||
int no_of_categories = static_cast<int>(tasks.categories);
|
||||
int total_no = static_cast<int>(tasks.Size()); // make sure task_size
|
||||
|
||||
|
@ -100,7 +116,6 @@ MSRStatus ShardSample::Execute(ShardTask &tasks) {
|
|||
return FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
if (tasks.permutation_.empty()) {
|
||||
ShardTask new_tasks;
|
||||
total_no = static_cast<int>(tasks.Size());
|
||||
|
@ -111,10 +126,20 @@ MSRStatus ShardSample::Execute(ShardTask &tasks) {
|
|||
}
|
||||
} else {
|
||||
int count = 0;
|
||||
for (int i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) {
|
||||
if (no_of_samples_ != 0 && count == no_of_samples_) break;
|
||||
new_tasks.InsertTask(tasks.GetTaskByID(i % total_no)); // rounding up. if overflow, go back to start
|
||||
count++;
|
||||
if (nums_per_shard_.empty()) {
|
||||
for (size_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) {
|
||||
if (no_of_samples_ != 0 && count == no_of_samples_) break;
|
||||
new_tasks.InsertTask(tasks.GetTaskByID(i % total_no)); // rounding up. if overflow, go back to start
|
||||
count++;
|
||||
}
|
||||
} else {
|
||||
// Get samples within a specific range
|
||||
size_t i = partition_id_ - 1 >= 0 ? nums_per_shard_[partition_id_ - 1] : 0;
|
||||
for (; i < nums_per_shard_[partition_id_]; i++) {
|
||||
if (no_of_samples_ != 0 && count == no_of_samples_) break;
|
||||
new_tasks.InsertTask(tasks.GetTaskByID(i % total_no));
|
||||
count++;
|
||||
}
|
||||
}
|
||||
}
|
||||
std::swap(tasks, new_tasks);
|
||||
|
|
|
@ -20,15 +20,15 @@ can also create samplers with this module to sample data.
|
|||
|
||||
from .core import config
|
||||
from .engine.datasets import TFRecordDataset, ImageFolderDatasetV2, MnistDataset, MindDataset, NumpySlicesDataset, \
|
||||
GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CocoDataset, CelebADataset,\
|
||||
TextFileDataset, CLUEDataset, CSVDataset, Schema, Shuffle, zip, RandomDataset
|
||||
GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CocoDataset, CelebADataset, \
|
||||
TextFileDataset, CLUEDataset, CSVDataset, Schema, Shuffle, zip, RandomDataset, PaddedDataset
|
||||
from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \
|
||||
WeightedRandomSampler, Sampler
|
||||
from .engine.cache_client import DatasetCache
|
||||
from .engine.serializer_deserializer import serialize, deserialize, show
|
||||
from .engine.graphdata import GraphData
|
||||
|
||||
__all__ = ["config", "ImageFolderDatasetV2", "MnistDataset",
|
||||
__all__ = ["config", "ImageFolderDatasetV2", "MnistDataset", "PaddedDataset",
|
||||
"MindDataset", "GeneratorDataset", "TFRecordDataset",
|
||||
"ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset", "NumpySlicesDataset", "VOCDataset",
|
||||
"CocoDataset", "TextFileDataset", "CLUEDataset", "CSVDataset", "Schema", "DistributedSampler", "PKSampler",
|
||||
|
|
|
@ -44,7 +44,8 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
|
|||
check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
|
||||
check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \
|
||||
check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
|
||||
check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset
|
||||
check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset,\
|
||||
check_paddeddataset
|
||||
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
|
||||
from ..text.utils import DE_C_INTER_SENTENCEPIECE_MODE
|
||||
|
||||
|
@ -2305,10 +2306,35 @@ class ConcatDataset(DatasetOp):
|
|||
if not isinstance(dataset, Dataset):
|
||||
raise TypeError("The parameter %s of concat has type error!" % (dataset))
|
||||
self.datasets = datasets
|
||||
self._sampler = None
|
||||
for data in datasets:
|
||||
self.children.append(data)
|
||||
data.parent.append(self)
|
||||
|
||||
self.children_sizes_ = [c.get_dataset_size() for c in self.children]
|
||||
"""
|
||||
_children_flag_and_nums: A list of pair<int ,int>.The first element of pair is flag that characterizes
|
||||
whether the data set is mappable. The second element of pair is length of the dataset
|
||||
"""
|
||||
self._children_flag_and_nums = []
|
||||
"""
|
||||
_children_start_end_index_: A list of pair<int ,int>.The elements of pair are used to characterize
|
||||
the valid position of the dataset corresponding to the subscript when sampling
|
||||
"""
|
||||
self._children_start_end_index_ = []
|
||||
for index, child in enumerate(self.children):
|
||||
tem_list = [-1, -1]
|
||||
self._children_start_end_index_.append(tem_list)
|
||||
datasetLen = self.children_sizes_[index]
|
||||
if isinstance(child, GeneratorDataset) and not hasattr(child.source, "__getitem__"):
|
||||
datasetLen = 0
|
||||
self.children_sizes_[index] = 0
|
||||
|
||||
if isinstance(child, MappableDataset):
|
||||
self._children_flag_and_nums.append((0, datasetLen))
|
||||
else:
|
||||
self._children_flag_and_nums.append((1, datasetLen))
|
||||
|
||||
def get_dataset_size(self):
|
||||
"""
|
||||
Get the number of batches in an epoch.
|
||||
|
@ -2321,6 +2347,67 @@ class ConcatDataset(DatasetOp):
|
|||
self.dataset_size = sum(children_sizes)
|
||||
return self.dataset_size
|
||||
|
||||
def use_sampler(self, sampler):
|
||||
"""
|
||||
Set the distributedSampler to concat dataset
|
||||
|
||||
Args:
|
||||
sampler (Sampler): the sampler to use for the current dataset. Current support: DistributedSampler.
|
||||
|
||||
Raises:
|
||||
TypeError: If the sampler is not an instance of DistributedSampler
|
||||
ValueError: If the parameter shuffle of sampler is True
|
||||
ValueError: If the parameter NumSamples of sampler is not None.
|
||||
ValueError: If num_shards <=0.
|
||||
"""
|
||||
if not isinstance(sampler, samplers.DistributedSampler):
|
||||
raise TypeError("The parameter %s of concat should be DistributedSampler!" % (sampler))
|
||||
|
||||
if sampler.is_shuffled():
|
||||
raise ValueError("The parameter shuffle of DistributedSampler is not support to be true!")
|
||||
|
||||
if sampler.num_shards <= 0:
|
||||
raise ValueError("The parameter num_shards of concat should be positive int!")
|
||||
|
||||
if sampler.get_num_samples() is not None:
|
||||
raise ValueError("The parameter NumSamples of DistributedSampler is not support to be set!")
|
||||
|
||||
self._sampler = _select_sampler(None, sampler, None, None, None)
|
||||
cumulative_samples_nums = 0
|
||||
for index, child in enumerate(self.children):
|
||||
|
||||
if isinstance(child, BatchDataset):
|
||||
raise TypeError("The parameter %s of concat should't be BatchDataset!" % (child))
|
||||
|
||||
if not self._children_flag_and_nums[index][0] and self._children_flag_and_nums[index][1]:
|
||||
|
||||
tem_value = cumulative_samples_nums + self._children_flag_and_nums[index][1]
|
||||
|
||||
if not self._children_flag_and_nums[index][1] >= sampler.num_shards:
|
||||
if tem_value < sampler.num_shards:
|
||||
self._children_start_end_index_[index][0] = cumulative_samples_nums
|
||||
self._children_start_end_index_[index][1] = tem_value
|
||||
else:
|
||||
self._children_start_end_index_[index][0] = cumulative_samples_nums
|
||||
self._children_start_end_index_[index][1] = tem_value % sampler.num_shards
|
||||
|
||||
|
||||
tem_sampler = copy.deepcopy(sampler)
|
||||
tem_sampler.set_offset(cumulative_samples_nums)
|
||||
child.sampler = tem_sampler
|
||||
|
||||
cumulative_samples_nums += self.children_sizes_[index]
|
||||
cumulative_samples_nums %= sampler.num_shards
|
||||
|
||||
def get_args(self):
|
||||
args = super().get_args()
|
||||
|
||||
if self._sampler is not None:
|
||||
args["sampler"] = self._sampler
|
||||
args["children_flag_and_nums"] = self._children_flag_and_nums
|
||||
args["children_start_end_index"] = self._children_start_end_index_
|
||||
return args
|
||||
|
||||
|
||||
class RenameDataset(DatasetOp):
|
||||
"""
|
||||
|
@ -3307,7 +3394,6 @@ class GeneratorDataset(MappableDataset):
|
|||
new_op.column_names = copy.deepcopy(self.column_names, memodict)
|
||||
new_op.num_samples = copy.deepcopy(self.num_samples, memodict)
|
||||
new_op.dataset_size = self.dataset_size
|
||||
|
||||
new_op.sampler = copy.deepcopy(self.sampler)
|
||||
if new_op.sampler is not None and hasattr(self.source, "__getitem__"):
|
||||
if isinstance(new_op.sampler, (samplers.SequentialSampler, samplers.DistributedSampler,
|
||||
|
@ -5276,6 +5362,53 @@ class NumpySlicesDataset(GeneratorDataset):
|
|||
num_parallel_workers=num_parallel_workers, shuffle=shuffle, sampler=sampler,
|
||||
num_shards=num_shards, shard_id=shard_id)
|
||||
|
||||
class _PaddedDataset:
|
||||
"""
|
||||
Mainly for combining false samples provided by users into a dataset.
|
||||
|
||||
Args:
|
||||
padded_samples (list(dict)): the data provided by user to added to initial Dataset
|
||||
"""
|
||||
def __init__(self, padded_samples):
|
||||
self.column_names = list(padded_samples[0].keys())
|
||||
self.padded_samples = padded_samples
|
||||
|
||||
def __getitem__(self, item):
|
||||
return (self.padded_samples[item][key] for key in self.column_names)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.padded_samples)
|
||||
|
||||
class PaddedDataset(GeneratorDataset):
    """
    Create a dataset from fake (padding) samples supplied by the user. It is mainly meant to be
    added to the original dataset and assigned to the corresponding shard.

    Args:
        padded_samples (list(dict)): the samples provided by user

    Raises:
        TypeError: If padded_samples is not an instance of list.
        TypeError: If the element of padded_samples is not an instance of dict.
        ValueError: If the padded_samples is empty.

    Examples:
        >>> import numpy as np
        >>> import mindspore.dataset as ds
        >>> data1 = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(2, np.uint8)}]
        >>> ds1 = ds.PaddedDataset(data1)
    """
    @check_paddeddataset
    def __init__(self, padded_samples):
        # Wrap the raw samples in an indexable source and hand it to GeneratorDataset unsharded and unshuffled.
        source = _PaddedDataset(padded_samples)
        super().__init__(source, column_names=source.column_names,
                         num_shards=None,
                         shard_id=None, shuffle=False)
        self._dataset_size = len(source.padded_samples)
        self.padded_samples = padded_samples

    def get_dataset_size(self):
        # The size recorded at construction time: the number of padded samples.
        return self._dataset_size
|
||||
|
||||
|
||||
class BuildVocabDataset(DatasetOp):
|
||||
"""
|
||||
|
|
|
@ -223,7 +223,7 @@ class DistributedSampler(BuiltinSampler):
|
|||
shard_id (int): Shard ID of the current shard within num_shards.
|
||||
shuffle (bool, optional): If true, the indices are shuffled (default=True).
|
||||
num_samples (int, optional): The number of samples to draw (default=None, all elements).
|
||||
|
||||
offset (int, optional): Offset from shard when the elements of the dataset are allocated (default=-1).
|
||||
Examples:
|
||||
>>> import mindspore.dataset as ds
|
||||
>>>
|
||||
|
@ -239,7 +239,7 @@ class DistributedSampler(BuiltinSampler):
|
|||
ValueError: If shuffle is not a boolean value.
|
||||
"""
|
||||
|
||||
def __init__(self, num_shards, shard_id, shuffle=True, num_samples=None):
|
||||
def __init__(self, num_shards, shard_id, shuffle=True, num_samples=None, offset=-1):
|
||||
if num_shards <= 0:
|
||||
raise ValueError("num_shards should be a positive integer value, but got num_shards={}".format(num_shards))
|
||||
|
||||
|
@ -258,13 +258,15 @@ class DistributedSampler(BuiltinSampler):
|
|||
self.shard_id = shard_id
|
||||
self.shuffle = shuffle
|
||||
self.seed = 0
|
||||
self.offset = offset
|
||||
super().__init__(num_samples)
|
||||
|
||||
def create(self):
    """
    Build the C++ DistributedSampler that backs this Python sampler.

    Returns:
        The cde.DistributedSampler object, with the child sampler from create_child() attached.
    """
    # None is forwarded to the C++ layer as 0.
    num_samples = self.num_samples if self.num_samples is not None else 0
    # each time user calls create_dict_iterator() (to do repeat) sampler would get a different seed to shuffle
    self.seed += 1
    # The binding takes six arguments, the last one being the shard offset. The earlier
    # five-argument call was a leftover whose result was overwritten immediately.
    c_sampler = cde.DistributedSampler(num_samples, self.num_shards, self.shard_id,
                                       self.shuffle, self.seed, self.offset)
    c_child_sampler = self.create_child()
    c_sampler.add_child(c_child_sampler)
    return c_sampler
|
||||
|
@ -272,7 +274,7 @@ class DistributedSampler(BuiltinSampler):
|
|||
def create_for_minddataset(self):
    """
    Build the C++ MindrecordDistributedSampler that backs this sampler for MindDataset.

    Unlike create(), this method does not change self.seed.

    Returns:
        The cde.MindrecordDistributedSampler object, with the child sampler from
        create_child_for_minddataset() attached.
    """
    # None is forwarded to the C++ layer as 0.
    num_samples = self.num_samples if self.num_samples is not None else 0
    # Argument order: num_shards, shard_id, shuffle, seed, num_samples, offset.
    c_sampler = cde.MindrecordDistributedSampler(self.num_shards, self.shard_id, self.shuffle,
                                                 self.seed, num_samples, self.offset)
    c_child_sampler = self.create_child_for_minddataset()
    c_sampler.add_child(c_child_sampler)
    return c_sampler
|
||||
|
@ -289,6 +291,10 @@ class DistributedSampler(BuiltinSampler):
|
|||
|
||||
return self.child_sampler.is_sharded()
|
||||
|
||||
def set_offset(self, offset):
    """
    Store a new shard offset on this sampler.

    Args:
        offset (int): Value assigned to self.offset.

    Returns:
        This sampler, so that calls can be chained.
    """
    self.offset = offset
    return self
|
||||
|
||||
|
||||
class PKSampler(BuiltinSampler):
|
||||
"""
|
||||
|
|
|
@ -1156,3 +1156,20 @@ def check_numpyslicesdataset(method):
|
|||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_paddeddataset(method):
    """A wrapper that wraps a parameter checker to the original Dataset(PaddedDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, user_params = parse_user_args(method, *args, **kwargs)
        samples = user_params.get("padded_samples")

        # Emptiness is tested before the type checks, so any falsy argument
        # (None, [], ...) is reported as a ValueError.
        if not samples:
            raise ValueError("Argument padded_samples cannot be empty")

        type_check(samples, (list,), "padded_samples")
        # Only the first element is inspected here.
        type_check(samples[0], (dict,), "padded_element")

        return method(self, *args, **kwargs)

    return new_method
|
||||
|
|
|
@ -48,7 +48,7 @@ TEST_F(MindDataTestDistributedSampler, TestTwoShardsOne) {
|
|||
uint64_t num_samples = 7;
|
||||
|
||||
// create sampler with replacement = true
|
||||
DistributedSampler m_sampler(num_samples, 2, 0, false, 0, false);
|
||||
DistributedSampler m_sampler(num_samples, 2, 0, false, 0, -1, false);
|
||||
DummyRandomAccessOp dummyRandomAccessOp(num_samples);
|
||||
m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
|
||||
|
||||
|
@ -74,7 +74,7 @@ TEST_F(MindDataTestDistributedSampler, TestTwoShardsTwo) {
|
|||
uint64_t num_samples = 7;
|
||||
|
||||
// create sampler with replacement = true
|
||||
DistributedSampler m_sampler(num_samples, 2, 1, false, 0, false);
|
||||
DistributedSampler m_sampler(num_samples, 2, 1, false, 0, -1, false);
|
||||
DummyRandomAccessOp dummyRandomAccessOp(num_samples);
|
||||
m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
|
||||
|
||||
|
@ -100,7 +100,7 @@ TEST_F(MindDataTestDistributedSampler, TestThreeShards) {
|
|||
uint64_t num_samples = 2;
|
||||
|
||||
// create sampler with replacement = true
|
||||
DistributedSampler m_sampler(num_samples, 3, 2, false, 0, false);
|
||||
DistributedSampler m_sampler(num_samples, 3, 2, false, 0, -1, false);
|
||||
DummyRandomAccessOp dummyRandomAccessOp(num_samples);
|
||||
m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
|
||||
|
||||
|
|
|
@ -0,0 +1,364 @@
|
|||
import os
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.dataset as ds
|
||||
from mindspore.mindrecord import FileWriter
|
||||
FILES_NUM = 4
|
||||
CV_FILE_NAME = "../data/mindrecord/imagenet.mindrecord"
|
||||
CV_DIR_NAME = "../data/mindrecord/testImageNetData"
|
||||
|
||||
def generator_5():
    """Yield one-column rows holding the integers 0 through 4."""
    yield from ((np.array([value]),) for value in range(5))
|
||||
|
||||
def generator_8():
    """Yield one-column rows holding the integers 5 through 7."""
    yield from ((np.array([value]),) for value in range(5, 8))
|
||||
|
||||
def generator_10():
    """Yield one-column rows holding the integers 0 through 9."""
    yield from ((np.array([value]),) for value in range(10))
|
||||
|
||||
def generator_20():
    """Yield one-column rows holding the integers 10 through 19."""
    yield from ((np.array([value]),) for value in range(10, 20))
|
||||
|
||||
def generator_30():
    """Yield one-column rows holding the integers 20 through 29."""
    yield from ((np.array([value]),) for value in range(20, 30))
|
||||
|
||||
|
||||
def test_TFRecord_Padded():
    """
    Concatenate a TFRecord dataset with five padded samples, read it shard by shard
    with a DistributedSampler over 4 shards, and compare the image lengths seen per shard.
    """
    data_files = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
    schema_file = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
    expected = [[159109, 2], [192607, 3], [179251, 4], [1, 5]]
    actual = []
    shard_num = 4
    for shard_id in range(shard_num):
        data = ds.TFRecordDataset(data_files, schema_file, columns_list=["image"],
                                  shuffle=False, shard_equal_rows=True)

        # Padded images of length 1..5, so each one is recognisable by its size.
        padded_samples = [{'image': np.zeros(size, np.uint8)} for size in range(1, 6)]

        padded_ds = ds.PaddedDataset(padded_samples)
        concat_ds = data + padded_ds
        sampler = ds.DistributedSampler(num_shards=shard_num, shard_id=shard_id, shuffle=False, num_samples=None)
        concat_ds.use_sampler(sampler)
        actual.append([len(row['image']) for row in concat_ds.create_dict_iterator()])
    assert actual == expected
|
||||
|
||||
def test_GeneratorDataSet_Padded():
    """
    Concatenate two generator datasets (0..9 followed by 10..19) and check that each of
    the 10 shards receives the rows [i, 10 + i].
    """
    expected = [[i, 10 + i] for i in range(10)]

    actual = []
    data1 = ds.GeneratorDataset(generator_20, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])
    data3 = data2 + data1
    shard_num = 10
    for shard_id in range(shard_num):
        sampler = ds.DistributedSampler(num_shards=shard_num, shard_id=shard_id, shuffle=False, num_samples=None)
        data3.use_sampler(sampler)
        actual.append([row['col1'][0] for row in data3.create_dict_iterator()])

    assert actual == expected
|
||||
|
||||
def test_Reapeat_afterPadded():
    """
    Apply repeat(2) after sharding two concatenated PaddedDatasets and check that shard 0
    of 2 yields the image lengths [1, 3, 5, 7] once per repetition.
    """
    expected_once = [1, 3, 5, 7]

    # Image lengths 1..5 in the first dataset and 6..8 in the second.
    first_samples = [{'image': np.zeros(size, np.uint8)} for size in range(1, 6)]
    second_samples = [{'image': np.zeros(size, np.uint8)} for size in range(6, 9)]

    ds1 = ds.PaddedDataset(first_samples)
    ds2 = ds.PaddedDataset(second_samples)
    ds3 = ds1 + ds2

    sampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False, num_samples=None)
    ds3.use_sampler(sampler)
    repeat_num = 2
    ds3 = ds3.repeat(repeat_num)
    actual = [len(row['image']) for row in ds3.create_dict_iterator()]

    assert actual == expected_once * repeat_num
|
||||
|
||||
def test_bath_afterPadded():
    """
    Apply batch(2) after sharding two concatenated PaddedDatasets (5 + 3 samples) and check
    that shard 0 of 2 produces two batches.
    """
    # Every sample has the same shape (length 1).
    first_samples = [{'image': np.zeros(1, np.uint8)} for _ in range(5)]
    second_samples = [{'image': np.zeros(1, np.uint8)} for _ in range(3)]

    ds1 = ds.PaddedDataset(first_samples)
    ds2 = ds.PaddedDataset(second_samples)
    ds3 = ds1 + ds2

    sampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False, num_samples=None)
    ds3.use_sampler(sampler)

    ds4 = ds3.batch(2)
    batch_count = sum(1 for _ in ds4)
    assert batch_count == 2
|
||||
|
||||
def test_Unevenly_distributed():
    """
    Shard 8 concatenated padded samples over 3 shards and check the image lengths
    received by each shard: [1, 4, 7], [2, 5, 8] and [3, 6].
    """
    expected = [[1, 4, 7], [2, 5, 8], [3, 6]]
    actual = []

    # Image lengths 1..5 in the first dataset and 6..8 in the second.
    first_samples = [{'image': np.zeros(size, np.uint8)} for size in range(1, 6)]
    second_samples = [{'image': np.zeros(size, np.uint8)} for size in range(6, 9)]

    # This sampler (built with offset=1) is replaced inside the loop before it is ever used.
    testsampler = ds.DistributedSampler(num_shards=4, shard_id=0, shuffle=False, num_samples=None, offset=1)

    ds1 = ds.PaddedDataset(first_samples)
    ds2 = ds.PaddedDataset(second_samples)
    ds3 = ds1 + ds2
    numShard = 3
    for shard_id in range(numShard):
        testsampler = ds.DistributedSampler(num_shards=numShard, shard_id=shard_id, shuffle=False, num_samples=None)
        ds3.use_sampler(testsampler)
        actual.append([len(row['image']) for row in ds3.create_dict_iterator()])
    assert actual == expected
|
||||
|
||||
def test_three_datasets_connected():
    """
    Concatenate three generator datasets (0..9, 10..19, 20..29) and check that each of
    the 10 shards receives the rows [i, 10 + i, 20 + i].
    """
    expected = [[i, 10 + i, 20 + i] for i in range(10)]

    actual = []
    data1 = ds.GeneratorDataset(generator_10, ["col1"])
    data2 = ds.GeneratorDataset(generator_20, ["col1"])
    data3 = ds.GeneratorDataset(generator_30, ["col1"])
    data4 = data1 + data2 + data3
    shard_num = 10
    for shard_id in range(shard_num):
        sampler = ds.DistributedSampler(num_shards=shard_num, shard_id=shard_id, shuffle=False, num_samples=None)
        data4.use_sampler(sampler)
        actual.append([row['col1'][0] for row in data4.create_dict_iterator()])

    assert actual == expected
|
||||
|
||||
def test_raise_error():
    """
    Check the errors raised when a sampler is attached to the concatenation of a batched
    PaddedDataset and a plain PaddedDataset.

    Expected outcomes, one per `pytest.raises` block:
        - DistributedSampler without shuffle: TypeError
        - SequentialSampler: TypeError
        - DistributedSampler with shuffle=True: ValueError
        - DistributedSampler with num_samples=5: ValueError
    """
    data1 = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(2, np.uint8)},
             {'image': np.zeros(3, np.uint8)}, {'image': np.zeros(4, np.uint8)},
             {'image': np.zeros(5, np.uint8)}]
    data2 = [{'image': np.zeros(6, np.uint8)}, {'image': np.zeros(7, np.uint8)},
             {'image': np.zeros(8, np.uint8)}]

    ds1 = ds.PaddedDataset(data1)
    ds4 = ds1.batch(2)
    ds2 = ds.PaddedDataset(data2)
    ds3 = ds4 + ds2

    # pytest.raises itself verifies the exception type. The former
    # `assert excinfo.type == 'TypeError'` lines compared an exception class with a
    # string (always False) and sat after the raising call, so they never ran.
    with pytest.raises(TypeError):
        testsampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False, num_samples=None)
        ds3.use_sampler(testsampler)

    with pytest.raises(TypeError):
        otherSampler = ds.SequentialSampler()
        ds3.use_sampler(otherSampler)

    with pytest.raises(ValueError):
        testsampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=True, num_samples=None)
        ds3.use_sampler(testsampler)

    with pytest.raises(ValueError):
        testsampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False, num_samples=5)
        ds3.use_sampler(testsampler)
|
||||
|
||||
def test_imagefolden_padded():
    """
    Concatenate an ImageFolderDatasetV2 with six padded samples, read shard 4 of 5, and
    check the row count and the image lengths of the last two rows.
    """
    image_dir = "../data/dataset/testPK/data"
    data = ds.ImageFolderDatasetV2(image_dir)

    # Image lengths 1..6 with labels alternating 0, 1.
    padded_samples = [{'image': np.zeros(size, np.uint8), 'label': np.array(label, np.int32)}
                      for size, label in zip(range(1, 7), [0, 1, 0, 1, 0, 1])]

    data2 = ds.PaddedDataset(padded_samples)
    data3 = data + data2
    sampler = ds.DistributedSampler(num_shards=5, shard_id=4, shuffle=False, num_samples=None)
    data3.use_sampler(sampler)
    assert sum([1 for _ in data3]) == 10

    image_lengths = [len(row['image']) for row in data3.create_dict_iterator()]
    assert image_lengths[8] == 1
    assert image_lengths[9] == 6
|
||||
|
||||
def test_more_shard_padded():
    """
    Shard 8 concatenated rows over 9 shards, first with generator datasets and then with
    PaddedDatasets, and check that shards 0..7 get one row each while shard 8 gets none.
    """
    # Part 1: generator datasets; only the number of rows per shard is compared.
    expected_counts = [1] * 8 + [0]

    gen_a = ds.GeneratorDataset(generator_5, ["col1"])
    gen_b = ds.GeneratorDataset(generator_8, ["col1"])
    data3 = gen_a + gen_b
    rows_per_shard = []
    numShard = 9
    for shard_id in range(numShard):
        sampler = ds.DistributedSampler(num_shards=numShard, shard_id=shard_id, shuffle=False, num_samples=None)
        data3.use_sampler(sampler)
        rows_per_shard.append([row['col1'] for row in data3.create_dict_iterator()])

    assert [len(rows) for rows in rows_per_shard] == expected_counts

    # Part 2: padded datasets; the image length identifies which sample each shard received.
    lengths_per_shard = []
    expected_lengths = [[i + 1] for i in range(8)] + [[]]

    first_samples = [{'image': np.zeros(size, np.uint8)} for size in range(1, 6)]
    second_samples = [{'image': np.zeros(size, np.uint8)} for size in range(6, 9)]

    ds1 = ds.PaddedDataset(first_samples)
    ds2 = ds.PaddedDataset(second_samples)
    ds3 = ds1 + ds2

    for shard_id in range(numShard):
        sampler = ds.DistributedSampler(num_shards=numShard, shard_id=shard_id, shuffle=False, num_samples=None)
        ds3.use_sampler(sampler)
        lengths_per_shard.append([len(row['image']) for row in ds3.create_dict_iterator()])

    assert lengths_per_shard == expected_lengths
|
||||
|
||||
def get_data(dir_name):
    """
    usage: get data from imagenet dataset

    params:
    dir_name: directory containing folder images and annotation information

    Returns a list of dicts with the keys "id" (index of the annotation line), "file_name",
    "data" (raw image bytes) and "label" (int). Annotation lines whose image file cannot be
    found are skipped. Raises IOError when dir_name is not a directory.
    """
    if not os.path.isdir(dir_name):
        raise IOError("Directory {} not exists".format(dir_name))

    image_root = os.path.join(dir_name, "images")
    annotation_path = os.path.join(dir_name, "annotation.txt")
    with open(annotation_path, "r") as annotation_reader:
        annotation_lines = annotation_reader.readlines()

    records = []
    for index, entry in enumerate(annotation_lines):
        try:
            # Each annotation line has the form "<file name>,<label>".
            filename, label = entry.split(",")
            label = label.strip("\n")
            with open(os.path.join(image_root, filename), "rb") as image_reader:
                img = image_reader.read()
            records.append({"id": index,
                            "file_name": filename,
                            "data": img,
                            "label": int(label)})
        except FileNotFoundError:
            # Missing image: drop this annotation and keep going.
            continue
    return records
|
||||
|
||||
@pytest.fixture(name="remove_mindrecord_file")
def add_and_remove_cv_file():
    """
    Fixture that writes the MindRecord test files before a test and removes them afterwards.

    Stale files are deleted first, then FILES_NUM MindRecord files are written from the
    data returned by get_data(CV_DIR_NAME). The generated files and their ".db" companions
    are removed when the fixture finishes, whether or not an error occurred.
    """
    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
             for x in range(FILES_NUM)]

    def _remove_generated_files():
        # Only delete what actually exists, so cleanup can never raise
        # FileNotFoundError and hide the error that triggered it.
        for path in paths:
            for target in (path, "{}.db".format(path)):
                if os.path.exists(target):
                    os.remove(target)

    try:
        _remove_generated_files()
        writer = FileWriter(CV_FILE_NAME, FILES_NUM)
        data = get_data(CV_DIR_NAME)
        cv_schema_json = {"id": {"type": "int32"},
                          "file_name": {"type": "string"},
                          "label": {"type": "int32"},
                          "data": {"type": "bytes"}}
        writer.add_schema(cv_schema_json, "img_schema")
        writer.add_index(["file_name", "label"])
        writer.write_raw_data(data)
        writer.commit()
        yield "yield_cv_data"
    finally:
        # Runs on success, on failure (the original exception keeps propagating),
        # and when the generator is closed early.
        _remove_generated_files()
|
||||
|
||||
def test_Mindrecord_Padded(remove_mindrecord_file):
    """
    Concatenate a MindDataset with four padded file names, read it over 8 shards, and
    compare the numeric part of the file names received by each shard.

    The `remove_mindrecord_file` fixture argument is only requested for its side effect
    of creating and removing the MindRecord files.
    """
    result_list = []
    verify_list = [[1, 2], [3, 4], [5, 11], [6, 12], [7, 13], [8, 14], [9], [10]]
    num_readers = 4
    data_set = ds.MindDataset(CV_FILE_NAME + "0", ['file_name'], num_readers, shuffle=False)
    data1 = [{'file_name': np.array(b'image_00011.jpg', dtype='|S15')},
             {'file_name': np.array(b'image_00012.jpg', dtype='|S15')},
             {'file_name': np.array(b'image_00013.jpg', dtype='|S15')},
             {'file_name': np.array(b'image_00014.jpg', dtype='|S15')}]
    ds1 = ds.PaddedDataset(data1)
    ds2 = data_set + ds1
    shard_num = 8
    for i in range(shard_num):
        testsampler = ds.DistributedSampler(num_shards=shard_num, shard_id=i, shuffle=False, num_samples=None)
        ds2.use_sampler(testsampler)
        tem_list = []
        for ele in ds2.create_dict_iterator():
            # tobytes() replaces the deprecated ndarray.tostring() (removed in NumPy 2.0).
            file_name = ele['file_name'].tobytes().decode()
            # "image_00011.jpg" -> "image_00011" -> "00011". Explicit parsing avoids
            # str.lstrip/rstrip, which strip character sets rather than a prefix/suffix.
            stem = os.path.splitext(file_name)[0]
            tem_list.append(int(stem.rsplit('_', 1)[-1]))
        result_list.append(tem_list)
    assert result_list == verify_list
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Run through pytest so that fixtures are resolved. Calling
    # test_Mindrecord_Padded(add_and_remove_cv_file) directly only passed the fixture
    # object as an argument and never ran it, so the MindRecord files the test reads
    # were not created in script mode.
    raise SystemExit(pytest.main([__file__]))
|
Loading…
Reference in New Issue