forked from mindspore-Ecosystem/mindspore
add split in minddataset
This commit is contained in:
parent
6089d58d8d
commit
d4f8f57c7e
|
@ -391,35 +391,27 @@ Status DEPipeline::ParseShuffleOp(const py::dict &args, std::shared_ptr<DatasetO
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DEPipeline::CheckMindRecordPartitionInfo(const py::dict &args, std::vector<int> *in_partitions) {
|
||||
if (args["partitions"].is_none()) {
|
||||
std::string err_msg = "Error: partitions is not set (None)";
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
|
||||
py::list list = py::reinterpret_borrow<py::list>(args["partitions"]);
|
||||
for (auto l : list) {
|
||||
if (!l.is_none()) {
|
||||
in_partitions->push_back(ToInt(l));
|
||||
Status DEPipeline::BuildMindrecordSamplerChain(const py::handle &handle,
|
||||
std::vector<std::shared_ptr<mindrecord::ShardOperator>> *operators,
|
||||
int num_padded) {
|
||||
auto sampler = py::reinterpret_borrow<py::object>(handle);
|
||||
auto create = sampler.attr("create_for_minddataset");
|
||||
auto op = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
|
||||
std::stack<std::shared_ptr<mindrecord::ShardOperator>> stack_ops;
|
||||
while (op != nullptr) {
|
||||
auto sampler_op = std::dynamic_pointer_cast<mindrecord::ShardDistributedSample>(op);
|
||||
if (sampler_op && num_padded > 0) {
|
||||
sampler_op->SetNumPaddedSamples(num_padded);
|
||||
stack_ops.push(sampler_op);
|
||||
} else {
|
||||
stack_ops.push(op);
|
||||
}
|
||||
op = op->GetChildOp();
|
||||
}
|
||||
|
||||
if (in_partitions->size() != 2) {
|
||||
std::string err_msg = "Error: partitions is invalid or not set.";
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
while (!stack_ops.empty()) {
|
||||
operators->push_back(stack_ops.top());
|
||||
stack_ops.pop();
|
||||
}
|
||||
|
||||
constexpr int kMaxPartitions = 1024;
|
||||
if (in_partitions->at(0) <= 0 || in_partitions->at(0) > kMaxPartitions) {
|
||||
std::string err_msg = "Error: partitions is invalid or not set.";
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
|
||||
if (in_partitions->at(1) < 0 || in_partitions->at(1) >= in_partitions->at(0)) {
|
||||
std::string err_msg = "Error: partitions is invalid or not set.";
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -460,34 +452,16 @@ Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr<Datas
|
|||
(void)builder->SetNumMindRecordWorkers(ToInt(value));
|
||||
} else if (key == "block_reader" && ToBool(value) == true) {
|
||||
(void)builder->SetBlockReader();
|
||||
} else if (key == "shuffle_option" && ToBool(value) == true) {
|
||||
if (!args["partitions"].is_none()) continue;
|
||||
uint32_t seed = GetSeed();
|
||||
operators.push_back(std::make_shared<mindrecord::ShardShuffle>(seed));
|
||||
} else if (key == "sampler") {
|
||||
auto sampler = py::reinterpret_borrow<py::object>(value);
|
||||
auto create = sampler.attr("_create_for_minddataset");
|
||||
auto op = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
|
||||
operators.push_back(op);
|
||||
int num_padded = 0;
|
||||
if (!args["num_padded"].is_none()) {
|
||||
num_padded = ToInt(args["num_padded"]);
|
||||
}
|
||||
RETURN_IF_NOT_OK(BuildMindrecordSamplerChain(value, &operators, num_padded));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int> in_partitions;
|
||||
if (!args["partitions"].is_none()) {
|
||||
auto ret = CheckMindRecordPartitionInfo(args, &in_partitions);
|
||||
if (Status::OK() != ret) {
|
||||
return ret;
|
||||
}
|
||||
auto shuffle = ToBool(args["shuffle_option"]);
|
||||
int num_padded = 0;
|
||||
if (!args["num_padded"].is_none()) {
|
||||
num_padded = ToInt(args["num_padded"]);
|
||||
}
|
||||
operators.push_back(
|
||||
std::make_shared<mindrecord::ShardDistributedSample>(in_partitions[0], in_partitions[1], num_padded, shuffle, 0));
|
||||
}
|
||||
|
||||
if (!operators.empty()) {
|
||||
(void)builder->SetOperators(operators);
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <stack>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
@ -108,10 +109,12 @@ class DEPipeline {
|
|||
|
||||
Status ParseShuffleOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
|
||||
|
||||
Status CheckMindRecordPartitionInfo(const py::dict &args, std::vector<int> *ptr);
|
||||
|
||||
Status ParseMindRecordOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
|
||||
|
||||
Status BuildMindrecordSamplerChain(const py::handle &handle,
|
||||
std::vector<std::shared_ptr<mindrecord::ShardOperator>> *operators,
|
||||
int num_padded);
|
||||
|
||||
Status ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
|
||||
|
||||
Status ParseFilterOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
|
||||
|
|
|
@ -71,6 +71,7 @@
|
|||
#include "mindrecord/include/shard_pk_sample.h"
|
||||
#include "mindrecord/include/shard_distributed_sample.h"
|
||||
#include "mindrecord/include/shard_sample.h"
|
||||
#include "mindrecord/include/shard_sequential_sample.h"
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
#include "pybind11/stl_bind.h"
|
||||
|
@ -165,8 +166,8 @@ void bindDatasetOps(py::module *m) {
|
|||
const int64_t num_padded) {
|
||||
int64_t count = 0;
|
||||
std::shared_ptr<mindrecord::ShardOperator> op;
|
||||
if (py::hasattr(sampler, "_create_for_minddataset")) {
|
||||
auto create = sampler.attr("_create_for_minddataset");
|
||||
if (py::hasattr(sampler, "create_for_minddataset")) {
|
||||
auto create = sampler.attr("create_for_minddataset");
|
||||
op = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
|
||||
}
|
||||
THROW_IF_ERROR(MindRecordOp::CountTotalRows(paths, load_dataset, op, &count, num_padded));
|
||||
|
@ -486,7 +487,9 @@ void bindSamplerOps(py::module *m) {
|
|||
.def("add_child",
|
||||
[](std::shared_ptr<Sampler> self, std::shared_ptr<Sampler> child) { THROW_IF_ERROR(self->AddChild(child)); });
|
||||
|
||||
(void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(*m, "ShardOperator");
|
||||
(void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(*m, "ShardOperator")
|
||||
.def("add_child", [](std::shared_ptr<mindrecord::ShardOperator> self,
|
||||
std::shared_ptr<mindrecord::ShardOperator> child) { self->SetChildOp(child); });
|
||||
|
||||
(void)py::class_<DistributedSampler, Sampler, std::shared_ptr<DistributedSampler>>(*m, "DistributedSampler")
|
||||
.def(py::init<int64_t, int64_t, int64_t, bool, uint32_t>());
|
||||
|
@ -518,6 +521,22 @@ void bindSamplerOps(py::module *m) {
|
|||
}
|
||||
}));
|
||||
|
||||
(void)py::class_<mindrecord::ShardDistributedSample, mindrecord::ShardSample,
|
||||
std::shared_ptr<mindrecord::ShardDistributedSample>>(*m, "MindrecordDistributedSampler")
|
||||
.def(py::init<int64_t, int64_t, bool, uint32_t>());
|
||||
|
||||
(void)py::class_<mindrecord::ShardShuffle, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardShuffle>>(
|
||||
*m, "MindrecordRandomSampler")
|
||||
.def(py::init([](int64_t num_samples, bool replacement, bool reshuffle_each_epoch) {
|
||||
return std::make_shared<mindrecord::ShardShuffle>(GetSeed(), num_samples, replacement, reshuffle_each_epoch);
|
||||
}));
|
||||
|
||||
(void)py::class_<mindrecord::ShardSequentialSample, mindrecord::ShardSample,
|
||||
std::shared_ptr<mindrecord::ShardSequentialSample>>(*m, "MindrecordSequentialSampler")
|
||||
.def(py::init([](int num_samples, int start_index) {
|
||||
return std::make_shared<mindrecord::ShardSequentialSample>(num_samples, start_index);
|
||||
}));
|
||||
|
||||
(void)py::class_<WeightedRandomSampler, Sampler, std::shared_ptr<WeightedRandomSampler>>(*m, "WeightedRandomSampler")
|
||||
.def(py::init<int64_t, std::vector<double>, bool>());
|
||||
|
||||
|
|
|
@ -31,6 +31,10 @@ class ShardDistributedSample : public ShardSample {
|
|||
public:
|
||||
ShardDistributedSample(int num_shards, int shard_id, int no_of_padded_samples, bool shuffle, uint32_t seed);
|
||||
|
||||
ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed);
|
||||
|
||||
void SetNumPaddedSamples(int no_of_padded_samples) { no_of_padded_samples_ = no_of_padded_samples; }
|
||||
|
||||
~ShardDistributedSample() override{};
|
||||
|
||||
MSRStatus PreExecute(ShardTask &tasks) override;
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#ifndef MINDRECORD_INCLUDE_SHARD_OPERATOR_H_
|
||||
#define MINDRECORD_INCLUDE_SHARD_OPERATOR_H_
|
||||
|
||||
#include <memory>
|
||||
#include "mindrecord/include/shard_task.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -37,6 +38,14 @@ class ShardOperator {
|
|||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
virtual bool HasChildOp() { return child_op_ != nullptr; }
|
||||
|
||||
virtual MSRStatus SetChildOp(std::shared_ptr<ShardOperator> child_op) {
|
||||
if (child_op != nullptr) child_op_ = child_op;
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
virtual std::shared_ptr<ShardOperator> GetChildOp() { return child_op_; }
|
||||
|
||||
virtual MSRStatus PreExecute(ShardTask &tasks) { return SUCCESS; }
|
||||
|
||||
|
@ -44,7 +53,10 @@ class ShardOperator {
|
|||
|
||||
virtual MSRStatus SufExecute(ShardTask &tasks) { return SUCCESS; }
|
||||
|
||||
virtual int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) { return -1; }
|
||||
virtual int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) { return 0; }
|
||||
|
||||
private:
|
||||
std::shared_ptr<ShardOperator> child_op_ = nullptr;
|
||||
};
|
||||
} // namespace mindrecord
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -34,6 +34,7 @@
|
|||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <set>
|
||||
#include <stack>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <tuple>
|
||||
|
@ -44,6 +45,7 @@
|
|||
#include "mindrecord/include/common/shard_utils.h"
|
||||
#include "mindrecord/include/shard_category.h"
|
||||
#include "mindrecord/include/shard_column.h"
|
||||
#include "mindrecord/include/shard_distributed_sample.h"
|
||||
#include "mindrecord/include/shard_error.h"
|
||||
#include "mindrecord/include/shard_index_generator.h"
|
||||
#include "mindrecord/include/shard_operator.h"
|
||||
|
|
|
@ -48,10 +48,10 @@ class ShardSample : public ShardOperator {
|
|||
int numerator_;
|
||||
int denominator_;
|
||||
int partition_id_;
|
||||
int no_of_samples_;
|
||||
std::shared_ptr<ShardShuffle> shuffle_op_;
|
||||
|
||||
private:
|
||||
int no_of_samples_;
|
||||
std::vector<int64_t> indices_;
|
||||
SamplerType sampler_type_;
|
||||
};
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDRECORD_INCLUDE_SHARD_SEQUENTIAL_SAMPLE_H_
|
||||
#define MINDRECORD_INCLUDE_SHARD_SEQUENTIAL_SAMPLE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "mindrecord/include/shard_sample.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace mindrecord {
|
||||
class ShardSequentialSample : public ShardSample {
|
||||
public:
|
||||
ShardSequentialSample(int n, int offset);
|
||||
|
||||
ShardSequentialSample(float per, float per_offset);
|
||||
|
||||
~ShardSequentialSample() override{};
|
||||
|
||||
MSRStatus Execute(ShardTask &tasks) override;
|
||||
|
||||
int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override;
|
||||
|
||||
private:
|
||||
int offset_;
|
||||
float per_;
|
||||
float per_offset_;
|
||||
};
|
||||
} // namespace mindrecord
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDRECORD_INCLUDE_SHARD_SEQUENTIAL_SAMPLE_H_
|
|
@ -26,12 +26,20 @@ class ShardShuffle : public ShardOperator {
|
|||
public:
|
||||
explicit ShardShuffle(uint32_t seed = 0, ShuffleType shuffle_type = kShuffleCategory);
|
||||
|
||||
ShardShuffle(uint32_t seed, int64_t no_of_samples, bool replacement, bool reshuffle_each_epoch,
|
||||
ShuffleType shuffle_type = kShuffleSample);
|
||||
|
||||
~ShardShuffle() override{};
|
||||
|
||||
MSRStatus Execute(ShardTask &tasks) override;
|
||||
|
||||
int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override;
|
||||
|
||||
private:
|
||||
uint32_t shuffle_seed_;
|
||||
int64_t no_of_samples_;
|
||||
bool replacement_;
|
||||
bool reshuffle_each_epoch_;
|
||||
ShuffleType shuffle_type_;
|
||||
};
|
||||
} // namespace mindrecord
|
||||
|
|
|
@ -792,24 +792,51 @@ int64_t ShardReader::GetNumClasses(const std::string &category_field) {
|
|||
}
|
||||
|
||||
MSRStatus ShardReader::CountTotalRows(const std::vector<std::string> &file_paths, bool load_dataset,
|
||||
const std::shared_ptr<ShardOperator> &op, int64_t *count, const int num_padded) {
|
||||
const std::shared_ptr<ShardOperator> &ops, int64_t *count, const int num_padded) {
|
||||
if (SUCCESS != Init(file_paths, load_dataset)) {
|
||||
return FAILED;
|
||||
}
|
||||
int64_t num_samples = num_rows_;
|
||||
if (std::dynamic_pointer_cast<ShardCategory>(op)) {
|
||||
auto category_op = std::dynamic_pointer_cast<ShardCategory>(op);
|
||||
std::string category_field = category_op->GetCategoryField();
|
||||
auto num_classes = GetNumClasses(category_field);
|
||||
num_samples = category_op->GetNumSamples(num_rows_, num_classes);
|
||||
} else if (std::dynamic_pointer_cast<ShardSample>(op)) {
|
||||
num_samples = op->GetNumSamples(num_rows_, 0);
|
||||
if (-1 == num_samples) {
|
||||
MS_LOG(ERROR) << "Dataset size plus number of padded samples is not divisible by number of shards.";
|
||||
return FAILED;
|
||||
bool root = true;
|
||||
std::stack<std::shared_ptr<ShardOperator>> stack_ops;
|
||||
std::shared_ptr<ShardOperator> op(ops);
|
||||
while (op != nullptr) {
|
||||
stack_ops.push(op);
|
||||
op = op->GetChildOp();
|
||||
}
|
||||
while (!stack_ops.empty()) {
|
||||
op = stack_ops.top();
|
||||
stack_ops.pop();
|
||||
if (std::dynamic_pointer_cast<ShardShuffle>(op)) {
|
||||
num_samples = op->GetNumSamples(num_samples, 0);
|
||||
if (num_padded > 0 && root == true) {
|
||||
num_samples += num_padded;
|
||||
MS_LOG(DEBUG) << "Padding samples work on shuffle sampler.";
|
||||
root = false;
|
||||
}
|
||||
} else if (std::dynamic_pointer_cast<ShardCategory>(op)) {
|
||||
auto category_op = std::dynamic_pointer_cast<ShardCategory>(op);
|
||||
std::string category_field = category_op->GetCategoryField();
|
||||
auto num_classes = GetNumClasses(category_field);
|
||||
num_samples = category_op->GetNumSamples(num_samples, num_classes);
|
||||
} else if (std::dynamic_pointer_cast<ShardSample>(op)) {
|
||||
if (std::dynamic_pointer_cast<ShardDistributedSample>(op)) {
|
||||
auto sampler_op = std::dynamic_pointer_cast<ShardDistributedSample>(op);
|
||||
if (root == true) {
|
||||
sampler_op->SetNumPaddedSamples(num_padded);
|
||||
num_samples = op->GetNumSamples(num_samples, 0);
|
||||
if (-1 == num_samples) {
|
||||
MS_LOG(ERROR) << "Dataset size plus number of padded samples is not divisible by number of shards.";
|
||||
return FAILED;
|
||||
}
|
||||
root = false;
|
||||
}
|
||||
} else {
|
||||
num_samples = op->GetNumSamples(num_samples, 0);
|
||||
}
|
||||
} else {
|
||||
if (num_padded > 0) num_samples += num_padded;
|
||||
}
|
||||
} else {
|
||||
if (num_padded > 0) num_samples += num_padded;
|
||||
}
|
||||
*count = num_samples;
|
||||
return SUCCESS;
|
||||
|
@ -1385,12 +1412,16 @@ void ShardReader::Reset() {
|
|||
}
|
||||
|
||||
void ShardReader::ShuffleTask() {
|
||||
if (block_reader_) return;
|
||||
// exist shuffle and distributed sampler in ops, skip shuffle
|
||||
bool has_sharding = false;
|
||||
for (const auto &op : operators_) {
|
||||
if (block_reader_) {
|
||||
continue;
|
||||
if (std::dynamic_pointer_cast<ShardDistributedSample>(op)) {
|
||||
has_sharding = true;
|
||||
}
|
||||
|
||||
if (std::dynamic_pointer_cast<ShardShuffle>(op)) {
|
||||
}
|
||||
for (const auto &op : operators_) {
|
||||
if (std::dynamic_pointer_cast<ShardShuffle>(op) && has_sharding == false) {
|
||||
if (SUCCESS != (*op)(tasks_)) {
|
||||
MS_LOG(WARNING) << "Reshuffle reader tasks failed.";
|
||||
}
|
||||
|
|
|
@ -31,6 +31,9 @@ ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, int
|
|||
shuffle_op_ = std::make_shared<ShardShuffle>(seed, kShuffleSample);
|
||||
}
|
||||
|
||||
ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed)
|
||||
: ShardDistributedSample(num_shards, shard_id, 0, shuffle, seed) {}
|
||||
|
||||
int64_t ShardDistributedSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
|
||||
if (no_of_padded_samples_ <= 0) {
|
||||
if (dataset_size % denominator_ == 0) {
|
||||
|
|
|
@ -0,0 +1,74 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "mindrecord/include/shard_sequential_sample.h"
|
||||
|
||||
using mindspore::LogStream;
|
||||
using mindspore::ExceptionType::NoExceptionType;
|
||||
using mindspore::MsLogLevel::ERROR;
|
||||
|
||||
namespace mindspore {
|
||||
namespace mindrecord {
|
||||
ShardSequentialSample::ShardSequentialSample(int n, int offset)
|
||||
: ShardSample(n), offset_(offset), per_(0.0f), per_offset_(0.0f) {}
|
||||
|
||||
ShardSequentialSample::ShardSequentialSample(float per, float per_offset)
|
||||
: ShardSample(0), offset_(0), per_(per), per_offset_(per_offset) {}
|
||||
|
||||
int64_t ShardSequentialSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
|
||||
if (no_of_samples_ == 0 && (per_ >= -kEpsilon && per_ <= kEpsilon)) {
|
||||
return dataset_size;
|
||||
}
|
||||
if (per_ > kEpsilon && per_ <= 1.0f) {
|
||||
return dataset_size * kEpsilon;
|
||||
}
|
||||
return no_of_samples_;
|
||||
}
|
||||
|
||||
MSRStatus ShardSequentialSample::Execute(ShardTask &tasks) {
|
||||
int total_no = static_cast<int>(tasks.Size());
|
||||
int taking;
|
||||
if (no_of_samples_ == 0 && (per_ >= -kEpsilon && per_ <= kEpsilon)) {
|
||||
taking = total_no;
|
||||
} else if (per_ > kEpsilon && per_ <= 1.0f) {
|
||||
taking = total_no * kEpsilon;
|
||||
} else {
|
||||
taking = no_of_samples_;
|
||||
}
|
||||
|
||||
if (tasks.permutation_.empty()) {
|
||||
ShardTask new_tasks;
|
||||
total_no = static_cast<int>(tasks.Size());
|
||||
for (int i = offset_; i < taking + offset_; ++i) {
|
||||
new_tasks.InsertTask(tasks.GetTaskByID(i % total_no));
|
||||
}
|
||||
std::swap(tasks, new_tasks);
|
||||
} else { // shuffled
|
||||
ShardTask new_tasks;
|
||||
if (taking > static_cast<int>(tasks.permutation_.size())) {
|
||||
return FAILED;
|
||||
}
|
||||
total_no = static_cast<int>(tasks.permutation_.size());
|
||||
for (size_t i = offset_; i < taking + offset_; ++i) {
|
||||
new_tasks.InsertTask(tasks.GetTaskByID(tasks.permutation_[i % total_no]));
|
||||
}
|
||||
std::swap(tasks, new_tasks);
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
} // namespace mindrecord
|
||||
} // namespace mindspore
|
|
@ -21,17 +21,52 @@
|
|||
namespace mindspore {
|
||||
namespace mindrecord {
|
||||
ShardShuffle::ShardShuffle(uint32_t seed, ShuffleType shuffle_type)
|
||||
: shuffle_seed_(seed), shuffle_type_(shuffle_type) {}
|
||||
: shuffle_seed_(seed),
|
||||
no_of_samples_(0),
|
||||
replacement_(false),
|
||||
reshuffle_each_epoch_(true),
|
||||
shuffle_type_(shuffle_type) {}
|
||||
|
||||
ShardShuffle::ShardShuffle(uint32_t seed, int64_t no_of_samples, bool replacement, bool reshuffle_each_epoch,
|
||||
ShuffleType shuffle_type)
|
||||
: shuffle_seed_(seed),
|
||||
no_of_samples_(no_of_samples),
|
||||
replacement_(replacement),
|
||||
reshuffle_each_epoch_(reshuffle_each_epoch),
|
||||
shuffle_type_(shuffle_type) {}
|
||||
|
||||
int64_t ShardShuffle::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
|
||||
if (replacement_) {
|
||||
return no_of_samples_ == 0 ? dataset_size : no_of_samples_;
|
||||
}
|
||||
return dataset_size;
|
||||
}
|
||||
|
||||
MSRStatus ShardShuffle::Execute(ShardTask &tasks) {
|
||||
if (tasks.categories < 1) {
|
||||
return FAILED;
|
||||
}
|
||||
if (shuffle_type_ == kShuffleSample) {
|
||||
if (shuffle_type_ == kShuffleSample) { // shuffle each sample
|
||||
if (tasks.permutation_.empty() == true) {
|
||||
tasks.MakePerm();
|
||||
}
|
||||
std::shuffle(tasks.permutation_.begin(), tasks.permutation_.end(), std::default_random_engine(shuffle_seed_));
|
||||
if (replacement_ == true) {
|
||||
ShardTask new_tasks;
|
||||
if (no_of_samples_ == 0) {
|
||||
no_of_samples_ = static_cast<int>(tasks.Size());
|
||||
}
|
||||
if (no_of_samples_ <= 0) {
|
||||
MS_LOG(ERROR) << "no_of_samples need to be positive.";
|
||||
return FAILED;
|
||||
}
|
||||
new_tasks.task_list_.reserve(no_of_samples_);
|
||||
for (uint32_t i = 0; i < no_of_samples_; ++i) {
|
||||
new_tasks.InsertTask(tasks.GetRandomTask());
|
||||
}
|
||||
std::swap(tasks, new_tasks);
|
||||
} else {
|
||||
std::shuffle(tasks.permutation_.begin(), tasks.permutation_.end(), std::default_random_engine(shuffle_seed_));
|
||||
}
|
||||
} else { // shuffle unit like: (a1, b1, c1),(a2, b2, c2),..., (an, bn, cn)
|
||||
uint32_t individual_size = tasks.Size() / tasks.categories;
|
||||
std::vector<std::vector<int>> new_permutations(tasks.categories, std::vector<int>(individual_size));
|
||||
|
@ -46,7 +81,7 @@ MSRStatus ShardShuffle::Execute(ShardTask &tasks) {
|
|||
}
|
||||
}
|
||||
}
|
||||
shuffle_seed_++;
|
||||
if (reshuffle_each_epoch_) shuffle_seed_++;
|
||||
return SUCCESS;
|
||||
}
|
||||
} // namespace mindrecord
|
||||
|
|
|
@ -72,6 +72,7 @@ std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTa
|
|||
std::uniform_int_distribution<> dis(0, task_list_.size() - 1);
|
||||
return task_list_[dis(gen)];
|
||||
}
|
||||
|
||||
ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements) {
|
||||
ShardTask res;
|
||||
if (category_tasks.empty()) return res;
|
||||
|
|
|
@ -1015,10 +1015,8 @@ class Dataset:
|
|||
|
||||
def get_distribution(output_dataset):
|
||||
dev_id = 0
|
||||
if isinstance(output_dataset, (MindDataset)):
|
||||
return output_dataset.distribution, dev_id
|
||||
if isinstance(output_dataset, (Cifar10Dataset, Cifar100Dataset, GeneratorDataset, ImageFolderDatasetV2,
|
||||
ManifestDataset, MnistDataset, VOCDataset, CelebADataset)):
|
||||
ManifestDataset, MnistDataset, VOCDataset, CelebADataset, MindDataset)):
|
||||
sampler = output_dataset.sampler
|
||||
if isinstance(sampler, samplers.DistributedSampler):
|
||||
dev_id = sampler.shard_id
|
||||
|
@ -2670,7 +2668,7 @@ class MnistDataset(MappableDataset):
|
|||
return self.sampler.is_sharded()
|
||||
|
||||
|
||||
class MindDataset(SourceDataset):
|
||||
class MindDataset(MappableDataset):
|
||||
"""
|
||||
A source dataset that reads from shard files and database.
|
||||
|
||||
|
@ -2687,11 +2685,13 @@ class MindDataset(SourceDataset):
|
|||
sampler (Sampler, optional): Object used to choose samples from the
|
||||
dataset (default=None, sampler is exclusive
|
||||
with shuffle and block_reader). Support list: SubsetRandomSampler,
|
||||
PkSampler.
|
||||
PkSampler, RandomSampler, SequentialSampler, DistributedSampler.
|
||||
padded_sample (dict, optional): Samples will be appended to dataset, which
|
||||
keys are the same as column_list.
|
||||
num_padded (int, optional): Number of padding samples.Dataset size
|
||||
plus num_padded should be divisible by num_shards.
|
||||
num_samples (int, optional): The number of samples to be included in the dataset
|
||||
(default=None, all samples).
|
||||
|
||||
Raises:
|
||||
ValueError: If num_shards is specified but shard_id is None.
|
||||
|
@ -2703,7 +2703,7 @@ class MindDataset(SourceDataset):
|
|||
def __init__(self, dataset_file, columns_list=None, num_parallel_workers=None,
|
||||
shuffle=None, num_shards=None, shard_id=None,
|
||||
block_reader=False, sampler=None, padded_sample=None,
|
||||
num_padded=None):
|
||||
num_padded=None, num_samples=None):
|
||||
super().__init__(num_parallel_workers)
|
||||
if isinstance(dataset_file, list):
|
||||
self.load_dataset = False
|
||||
|
@ -2712,15 +2712,10 @@ class MindDataset(SourceDataset):
|
|||
self.dataset_file = dataset_file
|
||||
self.columns_list = columns_list
|
||||
self.shuffle_option = shuffle
|
||||
self.distribution = ""
|
||||
self.sampler = sampler
|
||||
self.num_shards = num_shards
|
||||
self.shard_id = shard_id
|
||||
|
||||
if num_shards is None or shard_id is None:
|
||||
self.partitions = None
|
||||
else:
|
||||
self.partitions = [num_shards, shard_id]
|
||||
|
||||
if block_reader is True and self.partitions is not None:
|
||||
if block_reader is True and num_shards is not None:
|
||||
raise ValueError("block reader not allowed true when use partitions")
|
||||
|
||||
if block_reader is True and shuffle is True:
|
||||
|
@ -2730,25 +2725,21 @@ class MindDataset(SourceDataset):
|
|||
logger.warning("WARN: global shuffle is not used.")
|
||||
|
||||
if sampler is not None:
|
||||
if isinstance(sampler, samplers.SubsetRandomSampler) is False and \
|
||||
isinstance(sampler, samplers.PKSampler) is False:
|
||||
if isinstance(sampler, (samplers.SubsetRandomSampler, samplers.PKSampler,
|
||||
samplers.DistributedSampler, samplers.RandomSampler,
|
||||
samplers.SequentialSampler)) is False:
|
||||
raise ValueError("the sampler is not supported yet.")
|
||||
|
||||
self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
|
||||
self.num_samples = num_samples
|
||||
|
||||
# sampler exclusive
|
||||
if block_reader is True and sampler is not None:
|
||||
raise ValueError("block reader not allowed true when use sampler")
|
||||
|
||||
if shuffle is not None and sampler is not None:
|
||||
raise ValueError("shuffle not allowed when use sampler")
|
||||
|
||||
if block_reader is False and sampler is None:
|
||||
self.shuffle_option = not bool(shuffle is False)
|
||||
|
||||
if num_padded is None:
|
||||
num_padded = 0
|
||||
|
||||
self.num_shards = num_shards
|
||||
self.shard_id = shard_id
|
||||
self.block_reader = block_reader
|
||||
self.padded_sample = padded_sample
|
||||
self.num_padded = num_padded
|
||||
|
@ -2766,10 +2757,8 @@ class MindDataset(SourceDataset):
|
|||
args["load_dataset"] = self.load_dataset
|
||||
args["columns_list"] = self.columns_list
|
||||
args["shuffle_option"] = self.shuffle_option
|
||||
args["partitions"] = self.partitions
|
||||
args["num_samples"] = self.num_samples
|
||||
args["block_reader"] = self.block_reader
|
||||
args["num_shards"] = self.num_shards
|
||||
args["shard_id"] = self.shard_id
|
||||
args["num_padded"] = self.num_padded
|
||||
args["padded_sample"] = padded_sample
|
||||
args["sampler"] = self.sampler
|
||||
|
@ -2788,14 +2777,6 @@ class MindDataset(SourceDataset):
|
|||
else:
|
||||
dataset_file = self.dataset_file
|
||||
num_rows = MindRecordOp.get_num_rows(dataset_file, self.load_dataset, self.sampler, self.num_padded)
|
||||
if self.partitions is not None and self.partitions[0] > 0:
|
||||
if num_rows % self.partitions[0] == 0:
|
||||
num_rows = num_rows // self.partitions[0]
|
||||
else:
|
||||
if self.num_padded > 0:
|
||||
raise RuntimeError(
|
||||
"Dataset size plus number of padded samples is not divisible by number of shards.")
|
||||
num_rows = num_rows // self.partitions[0] + 1
|
||||
return num_rows
|
||||
return self._dataset_size
|
||||
|
||||
|
|
|
@ -141,7 +141,12 @@ class BuiltinSampler:
|
|||
c_child_sampler = None
|
||||
if self.child_sampler is not None:
|
||||
c_child_sampler = self.child_sampler.create()
|
||||
return c_child_sampler
|
||||
|
||||
def create_child_for_minddataset(self):
|
||||
c_child_sampler = None
|
||||
if self.child_sampler is not None:
|
||||
c_child_sampler = self.child_sampler.create_for_minddataset()
|
||||
return c_child_sampler
|
||||
|
||||
def is_shuffled(self):
|
||||
|
@ -262,6 +267,12 @@ class DistributedSampler(BuiltinSampler):
|
|||
c_sampler.add_child(c_child_sampler)
|
||||
return c_sampler
|
||||
|
||||
def create_for_minddataset(self):
|
||||
c_sampler = cde.MindrecordDistributedSampler(self.num_shards, self.shard_id, self.shuffle, self.seed)
|
||||
c_child_sampler = self.create_child_for_minddataset()
|
||||
c_sampler.add_child(c_child_sampler)
|
||||
return c_sampler
|
||||
|
||||
def is_shuffled(self):
|
||||
if self.child_sampler is None:
|
||||
return self.shuffle
|
||||
|
@ -318,7 +329,7 @@ class PKSampler(BuiltinSampler):
|
|||
|
||||
self.num_val = num_val
|
||||
self.shuffle = shuffle
|
||||
self.class_column = class_column # work for minddataset
|
||||
self.class_column = class_column # work for minddataset
|
||||
super().__init__(num_samples)
|
||||
|
||||
def create(self):
|
||||
|
@ -340,12 +351,14 @@ class PKSampler(BuiltinSampler):
|
|||
|
||||
return self.child_sampler.is_sharded()
|
||||
|
||||
def _create_for_minddataset(self):
|
||||
def create_for_minddataset(self):
|
||||
if not self.class_column or not isinstance(self.class_column, str):
|
||||
raise ValueError("class_column should be a not empty string value, \
|
||||
but got class_column={}".format(class_column))
|
||||
return cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle)
|
||||
|
||||
c_sampler = cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle)
|
||||
c_child_sampler = self.create_child_for_minddataset()
|
||||
c_sampler.add_child(c_child_sampler)
|
||||
return c_sampler
|
||||
|
||||
class RandomSampler(BuiltinSampler):
|
||||
"""
|
||||
|
@ -390,6 +403,13 @@ class RandomSampler(BuiltinSampler):
|
|||
c_sampler.add_child(c_child_sampler)
|
||||
return c_sampler
|
||||
|
||||
def create_for_minddataset(self):
|
||||
num_samples = self.num_samples if self.num_samples is not None else 0
|
||||
c_sampler = cde.MindrecordRandomSampler(num_samples, self.replacement, self.reshuffle_each_epoch)
|
||||
c_child_sampler = self.create_child_for_minddataset()
|
||||
c_sampler.add_child(c_child_sampler)
|
||||
return c_sampler
|
||||
|
||||
def is_shuffled(self):
|
||||
return True
|
||||
|
||||
|
@ -440,6 +460,14 @@ class SequentialSampler(BuiltinSampler):
|
|||
c_sampler.add_child(c_child_sampler)
|
||||
return c_sampler
|
||||
|
||||
def create_for_minddataset(self):
|
||||
start_index = self.start_index if self.start_index is not None else 0
|
||||
num_samples = self.num_samples if self.num_samples is not None else 0
|
||||
c_sampler = cde.MindrecordSequentialSampler(num_samples, start_index)
|
||||
c_child_sampler = self.create_child_for_minddataset()
|
||||
c_sampler.add_child(c_child_sampler)
|
||||
return c_sampler
|
||||
|
||||
def is_shuffled(self):
|
||||
if self.child_sampler is None:
|
||||
return False
|
||||
|
@ -501,8 +529,11 @@ class SubsetRandomSampler(BuiltinSampler):
|
|||
|
||||
return self.child_sampler.is_sharded()
|
||||
|
||||
def _create_for_minddataset(self):
|
||||
return cde.MindrecordSubsetRandomSampler(self.indices)
|
||||
def create_for_minddataset(self):
|
||||
c_sampler = cde.MindrecordSubsetRandomSampler(self.indices)
|
||||
c_child_sampler = self.create_child_for_minddataset()
|
||||
c_sampler.add_child(c_child_sampler)
|
||||
return c_sampler
|
||||
|
||||
def get_num_samples(self):
|
||||
num_samples = super().get_num_samples()
|
||||
|
|
|
@ -17,6 +17,7 @@ This is the test module for mindrecord
|
|||
"""
|
||||
import os
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
import mindspore.dataset as ds
|
||||
from mindspore import log as logger
|
||||
|
@ -64,10 +65,12 @@ def test_cv_minddataset_pk_sample_no_column(add_and_remove_cv_file):
|
|||
assert data_set.get_dataset_size() == 6
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
|
||||
logger.info(
|
||||
"-------------- cv reader basic: {} ------------------------".format(num_iter))
|
||||
logger.info("-------------- item[file_name]: \
|
||||
{}------------------------".format(to_str(item["file_name"])))
|
||||
logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
logger.info(
|
||||
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
num_iter += 1
|
||||
|
||||
|
||||
|
@ -82,12 +85,14 @@ def test_cv_minddataset_pk_sample_basic(add_and_remove_cv_file):
|
|||
assert data_set.get_dataset_size() == 6
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
|
||||
logger.info(
|
||||
"-------------- cv reader basic: {} ------------------------".format(num_iter))
|
||||
logger.info("-------------- item[data]: \
|
||||
{}------------------------".format(item["data"][:10]))
|
||||
logger.info("-------------- item[file_name]: \
|
||||
{}------------------------".format(to_str(item["file_name"])))
|
||||
logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
logger.info(
|
||||
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
num_iter += 1
|
||||
|
||||
|
||||
|
@ -102,10 +107,12 @@ def test_cv_minddataset_pk_sample_shuffle(add_and_remove_cv_file):
|
|||
assert data_set.get_dataset_size() == 9
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
|
||||
logger.info(
|
||||
"-------------- cv reader basic: {} ------------------------".format(num_iter))
|
||||
logger.info("-------------- item[file_name]: \
|
||||
{}------------------------".format(to_str(item["file_name"])))
|
||||
logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
logger.info(
|
||||
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
num_iter += 1
|
||||
|
||||
|
||||
|
@ -119,10 +126,12 @@ def test_cv_minddataset_pk_sample_out_of_range(add_and_remove_cv_file):
|
|||
assert data_set.get_dataset_size() == 15
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
|
||||
logger.info(
|
||||
"-------------- cv reader basic: {} ------------------------".format(num_iter))
|
||||
logger.info("-------------- item[file_name]: \
|
||||
{}------------------------".format(to_str(item["file_name"])))
|
||||
logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
logger.info(
|
||||
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
num_iter += 1
|
||||
|
||||
|
||||
|
@ -219,7 +228,6 @@ def test_cv_minddataset_subset_random_sample_out_of_range(add_and_remove_cv_file
|
|||
|
||||
|
||||
def test_cv_minddataset_subset_random_sample_negative(add_and_remove_cv_file):
|
||||
"""tutorial for cv minderdataset."""
|
||||
columns_list = ["data", "file_name", "label"]
|
||||
num_readers = 4
|
||||
indices = [1, 2, 4, -1, -2]
|
||||
|
@ -241,6 +249,344 @@ def test_cv_minddataset_subset_random_sample_negative(add_and_remove_cv_file):
|
|||
assert num_iter == 5
|
||||
|
||||
|
||||
def test_cv_minddataset_random_sampler_basic(add_and_remove_cv_file):
|
||||
data = get_data(CV_DIR_NAME, True)
|
||||
columns_list = ["data", "file_name", "label"]
|
||||
num_readers = 4
|
||||
sampler = ds.RandomSampler()
|
||||
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
|
||||
sampler=sampler)
|
||||
assert data_set.get_dataset_size() == 10
|
||||
num_iter = 0
|
||||
new_dataset = []
|
||||
for item in data_set.create_dict_iterator():
|
||||
logger.info(
|
||||
"-------------- cv reader basic: {} ------------------------".format(num_iter))
|
||||
logger.info(
|
||||
"-------------- item[data]: {} -----------------------------".format(item["data"]))
|
||||
logger.info(
|
||||
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
||||
logger.info(
|
||||
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
num_iter += 1
|
||||
new_dataset.append(item['file_name'])
|
||||
assert num_iter == 10
|
||||
assert new_dataset != [x['file_name'] for x in data]
|
||||
|
||||
def test_cv_minddataset_random_sampler_repeat(add_and_remove_cv_file):
|
||||
columns_list = ["data", "file_name", "label"]
|
||||
num_readers = 4
|
||||
sampler = ds.RandomSampler()
|
||||
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
|
||||
sampler=sampler)
|
||||
assert data_set.get_dataset_size() == 10
|
||||
ds1 = data_set.repeat(3)
|
||||
num_iter = 0
|
||||
epoch1_dataset = []
|
||||
epoch2_dataset = []
|
||||
epoch3_dataset = []
|
||||
for item in ds1.create_dict_iterator():
|
||||
logger.info(
|
||||
"-------------- cv reader basic: {} ------------------------".format(num_iter))
|
||||
logger.info(
|
||||
"-------------- item[data]: {} -----------------------------".format(item["data"]))
|
||||
logger.info(
|
||||
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
||||
logger.info(
|
||||
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
num_iter += 1
|
||||
if num_iter <= 10:
|
||||
epoch1_dataset.append(item['file_name'])
|
||||
elif num_iter <= 20:
|
||||
epoch2_dataset.append(item['file_name'])
|
||||
else:
|
||||
epoch3_dataset.append(item['file_name'])
|
||||
assert num_iter == 30
|
||||
assert epoch1_dataset not in (epoch2_dataset, epoch3_dataset)
|
||||
assert epoch2_dataset not in (epoch1_dataset, epoch3_dataset)
|
||||
assert epoch3_dataset not in (epoch1_dataset, epoch2_dataset)
|
||||
|
||||
def test_cv_minddataset_random_sampler_replacement(add_and_remove_cv_file):
|
||||
columns_list = ["data", "file_name", "label"]
|
||||
num_readers = 4
|
||||
sampler = ds.RandomSampler(replacement=True, num_samples=5)
|
||||
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
|
||||
sampler=sampler)
|
||||
assert data_set.get_dataset_size() == 5
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
logger.info(
|
||||
"-------------- cv reader basic: {} ------------------------".format(num_iter))
|
||||
logger.info(
|
||||
"-------------- item[data]: {} -----------------------------".format(item["data"]))
|
||||
logger.info(
|
||||
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
||||
logger.info(
|
||||
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
num_iter += 1
|
||||
assert num_iter == 5
|
||||
|
||||
|
||||
def test_cv_minddataset_sequential_sampler_basic(add_and_remove_cv_file):
|
||||
data = get_data(CV_DIR_NAME, True)
|
||||
columns_list = ["data", "file_name", "label"]
|
||||
num_readers = 4
|
||||
sampler = ds.SequentialSampler(1, 4)
|
||||
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
|
||||
sampler=sampler)
|
||||
assert data_set.get_dataset_size() == 4
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
logger.info(
|
||||
"-------------- cv reader basic: {} ------------------------".format(num_iter))
|
||||
logger.info(
|
||||
"-------------- item[data]: {} -----------------------------".format(item["data"]))
|
||||
logger.info(
|
||||
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
||||
logger.info(
|
||||
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
assert item['file_name'] == np.array(
|
||||
data[num_iter+1]['file_name'], dtype='S')
|
||||
num_iter += 1
|
||||
assert num_iter == 4
|
||||
|
||||
|
||||
def test_cv_minddataset_sequential_sampler_exceed_size(add_and_remove_cv_file):
|
||||
data = get_data(CV_DIR_NAME, True)
|
||||
columns_list = ["data", "file_name", "label"]
|
||||
num_readers = 4
|
||||
sampler = ds.SequentialSampler(2, 10)
|
||||
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
|
||||
sampler=sampler)
|
||||
dataset_size = data_set.get_dataset_size()
|
||||
assert dataset_size == 10
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
logger.info(
|
||||
"-------------- cv reader basic: {} ------------------------".format(num_iter))
|
||||
logger.info(
|
||||
"-------------- item[data]: {} -----------------------------".format(item["data"]))
|
||||
logger.info(
|
||||
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
||||
logger.info(
|
||||
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
assert item['file_name'] == np.array(
|
||||
data[(num_iter + 2) % dataset_size]['file_name'], dtype='S')
|
||||
num_iter += 1
|
||||
assert num_iter == 10
|
||||
|
||||
|
||||
def test_cv_minddataset_split_basic(add_and_remove_cv_file):
|
||||
data = get_data(CV_DIR_NAME, True)
|
||||
columns_list = ["data", "file_name", "label"]
|
||||
num_readers = 4
|
||||
d = ds.MindDataset(CV_FILE_NAME + "0", columns_list,
|
||||
num_readers, shuffle=False)
|
||||
d1, d2 = d.split([8, 2], randomize=False)
|
||||
assert d.get_dataset_size() == 10
|
||||
assert d1.get_dataset_size() == 8
|
||||
assert d2.get_dataset_size() == 2
|
||||
num_iter = 0
|
||||
for item in d1.create_dict_iterator():
|
||||
logger.info(
|
||||
"-------------- item[data]: {} -----------------------------".format(item["data"]))
|
||||
logger.info(
|
||||
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
||||
logger.info(
|
||||
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
assert item['file_name'] == np.array(data[num_iter]['file_name'],
|
||||
dtype='S')
|
||||
num_iter += 1
|
||||
assert num_iter == 8
|
||||
num_iter = 0
|
||||
for item in d2.create_dict_iterator():
|
||||
logger.info(
|
||||
"-------------- item[data]: {} -----------------------------".format(item["data"]))
|
||||
logger.info(
|
||||
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
||||
logger.info(
|
||||
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
assert item['file_name'] == np.array(data[num_iter + 8]['file_name'],
|
||||
dtype='S')
|
||||
num_iter += 1
|
||||
assert num_iter == 2
|
||||
|
||||
|
||||
def test_cv_minddataset_split_exact_percent(add_and_remove_cv_file):
|
||||
data = get_data(CV_DIR_NAME, True)
|
||||
columns_list = ["data", "file_name", "label"]
|
||||
num_readers = 4
|
||||
d = ds.MindDataset(CV_FILE_NAME + "0", columns_list,
|
||||
num_readers, shuffle=False)
|
||||
d1, d2 = d.split([0.8, 0.2], randomize=False)
|
||||
assert d.get_dataset_size() == 10
|
||||
assert d1.get_dataset_size() == 8
|
||||
assert d2.get_dataset_size() == 2
|
||||
num_iter = 0
|
||||
for item in d1.create_dict_iterator():
|
||||
logger.info(
|
||||
"-------------- item[data]: {} -----------------------------".format(item["data"]))
|
||||
logger.info(
|
||||
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
||||
logger.info(
|
||||
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
assert item['file_name'] == np.array(
|
||||
data[num_iter]['file_name'], dtype='S')
|
||||
num_iter += 1
|
||||
assert num_iter == 8
|
||||
num_iter = 0
|
||||
for item in d2.create_dict_iterator():
|
||||
logger.info(
|
||||
"-------------- item[data]: {} -----------------------------".format(item["data"]))
|
||||
logger.info(
|
||||
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
||||
logger.info(
|
||||
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
assert item['file_name'] == np.array(data[num_iter + 8]['file_name'],
|
||||
dtype='S')
|
||||
num_iter += 1
|
||||
assert num_iter == 2
|
||||
|
||||
|
||||
def test_cv_minddataset_split_fuzzy_percent(add_and_remove_cv_file):
|
||||
data = get_data(CV_DIR_NAME, True)
|
||||
columns_list = ["data", "file_name", "label"]
|
||||
num_readers = 4
|
||||
d = ds.MindDataset(CV_FILE_NAME + "0", columns_list,
|
||||
num_readers, shuffle=False)
|
||||
d1, d2 = d.split([0.41, 0.59], randomize=False)
|
||||
assert d.get_dataset_size() == 10
|
||||
assert d1.get_dataset_size() == 4
|
||||
assert d2.get_dataset_size() == 6
|
||||
num_iter = 0
|
||||
for item in d1.create_dict_iterator():
|
||||
logger.info(
|
||||
"-------------- item[data]: {} -----------------------------".format(item["data"]))
|
||||
logger.info(
|
||||
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
||||
logger.info(
|
||||
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
assert item['file_name'] == np.array(
|
||||
data[num_iter]['file_name'], dtype='S')
|
||||
num_iter += 1
|
||||
assert num_iter == 4
|
||||
num_iter = 0
|
||||
for item in d2.create_dict_iterator():
|
||||
logger.info(
|
||||
"-------------- item[data]: {} -----------------------------".format(item["data"]))
|
||||
logger.info(
|
||||
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
||||
logger.info(
|
||||
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
assert item['file_name'] == np.array(data[num_iter + 4]['file_name'],
|
||||
dtype='S')
|
||||
num_iter += 1
|
||||
assert num_iter == 6
|
||||
|
||||
|
||||
def test_cv_minddataset_split_deterministic(add_and_remove_cv_file):
|
||||
columns_list = ["data", "file_name", "label"]
|
||||
num_readers = 4
|
||||
d = ds.MindDataset(CV_FILE_NAME + "0", columns_list,
|
||||
num_readers, shuffle=False)
|
||||
# should set seed to avoid data overlap
|
||||
ds.config.set_seed(111)
|
||||
d1, d2 = d.split([0.8, 0.2])
|
||||
assert d.get_dataset_size() == 10
|
||||
assert d1.get_dataset_size() == 8
|
||||
assert d2.get_dataset_size() == 2
|
||||
|
||||
d1_dataset = []
|
||||
d2_dataset = []
|
||||
num_iter = 0
|
||||
for item in d1.create_dict_iterator():
|
||||
logger.info(
|
||||
"-------------- item[data]: {} -----------------------------".format(item["data"]))
|
||||
logger.info(
|
||||
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
||||
logger.info(
|
||||
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
d1_dataset.append(item['file_name'])
|
||||
num_iter += 1
|
||||
assert num_iter == 8
|
||||
num_iter = 0
|
||||
for item in d2.create_dict_iterator():
|
||||
logger.info(
|
||||
"-------------- item[data]: {} -----------------------------".format(item["data"]))
|
||||
logger.info(
|
||||
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
||||
logger.info(
|
||||
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
d2_dataset.append(item['file_name'])
|
||||
num_iter += 1
|
||||
assert num_iter == 2
|
||||
inter_dataset = [x for x in d1_dataset if x in d2_dataset]
|
||||
assert inter_dataset == [] # intersection of d1 and d2
|
||||
|
||||
|
||||
def test_cv_minddataset_split_sharding(add_and_remove_cv_file):
|
||||
data = get_data(CV_DIR_NAME, True)
|
||||
columns_list = ["data", "file_name", "label"]
|
||||
num_readers = 4
|
||||
d = ds.MindDataset(CV_FILE_NAME + "0", columns_list,
|
||||
num_readers, shuffle=False)
|
||||
# should set seed to avoid data overlap
|
||||
ds.config.set_seed(111)
|
||||
d1, d2 = d.split([0.8, 0.2])
|
||||
assert d.get_dataset_size() == 10
|
||||
assert d1.get_dataset_size() == 8
|
||||
assert d2.get_dataset_size() == 2
|
||||
distributed_sampler = ds.DistributedSampler(2, 0)
|
||||
d1.use_sampler(distributed_sampler)
|
||||
assert d1.get_dataset_size() == 4
|
||||
|
||||
num_iter = 0
|
||||
d1_shard1 = []
|
||||
for item in d1.create_dict_iterator():
|
||||
logger.info(
|
||||
"-------------- item[data]: {} -----------------------------".format(item["data"]))
|
||||
logger.info(
|
||||
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
||||
logger.info(
|
||||
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
num_iter += 1
|
||||
d1_shard1.append(item['file_name'])
|
||||
assert num_iter == 4
|
||||
assert d1_shard1 != [x['file_name'] for x in data[0:4]]
|
||||
|
||||
distributed_sampler = ds.DistributedSampler(2, 1)
|
||||
d1.use_sampler(distributed_sampler)
|
||||
assert d1.get_dataset_size() == 4
|
||||
|
||||
d1s = d1.repeat(3)
|
||||
epoch1_dataset = []
|
||||
epoch2_dataset = []
|
||||
epoch3_dataset = []
|
||||
num_iter = 0
|
||||
for item in d1s.create_dict_iterator():
|
||||
logger.info(
|
||||
"-------------- item[data]: {} -----------------------------".format(item["data"]))
|
||||
logger.info(
|
||||
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
||||
logger.info(
|
||||
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
num_iter += 1
|
||||
if num_iter <= 4:
|
||||
epoch1_dataset.append(item['file_name'])
|
||||
elif num_iter <= 8:
|
||||
epoch2_dataset.append(item['file_name'])
|
||||
else:
|
||||
epoch3_dataset.append(item['file_name'])
|
||||
assert len(epoch1_dataset) == 4
|
||||
assert len(epoch2_dataset) == 4
|
||||
assert len(epoch3_dataset) == 4
|
||||
inter_dataset = [x for x in d1_shard1 if x in epoch1_dataset]
|
||||
assert inter_dataset == [] # intersection of d1's shard1 and d1's shard2
|
||||
assert epoch1_dataset not in (epoch2_dataset, epoch3_dataset)
|
||||
assert epoch2_dataset not in (epoch1_dataset, epoch3_dataset)
|
||||
assert epoch3_dataset not in (epoch1_dataset, epoch2_dataset)
|
||||
|
||||
|
||||
def get_data(dir_name, sampler=False):
|
||||
"""
|
||||
usage: get data from imagenet dataset
|
||||
|
|
Loading…
Reference in New Issue