add split in minddataset

liyong 2020-06-08 15:35:52 +08:00
parent 6089d58d8d
commit d4f8f57c7e
17 changed files with 699 additions and 127 deletions
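This commit makes MindDataset a MappableDataset so it can be split and re-sampled. A minimal usage sketch, assuming a MindRecord file like the CV_FILE_NAME used in the tests further down (the file and column names are placeholders, not part of this commit):

import mindspore.dataset as ds
columns_list = ["data", "file_name", "label"]
d = ds.MindDataset(CV_FILE_NAME + "0", columns_list, 4, shuffle=False)
d1, d2 = d.split([8, 2], randomize=False)     # split() is now supported on MindDataset
d1.use_sampler(ds.DistributedSampler(2, 0))   # re-shard one piece afterwards, as in the tests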

View File

@ -391,35 +391,27 @@ Status DEPipeline::ParseShuffleOp(const py::dict &args, std::shared_ptr<DatasetO
return Status::OK();
}
Status DEPipeline::CheckMindRecordPartitionInfo(const py::dict &args, std::vector<int> *in_partitions) {
if (args["partitions"].is_none()) {
std::string err_msg = "Error: partitions is not set (None)";
RETURN_STATUS_UNEXPECTED(err_msg);
}
py::list list = py::reinterpret_borrow<py::list>(args["partitions"]);
for (auto l : list) {
if (!l.is_none()) {
in_partitions->push_back(ToInt(l));
Status DEPipeline::BuildMindrecordSamplerChain(const py::handle &handle,
std::vector<std::shared_ptr<mindrecord::ShardOperator>> *operators,
int num_padded) {
auto sampler = py::reinterpret_borrow<py::object>(handle);
auto create = sampler.attr("create_for_minddataset");
auto op = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
std::stack<std::shared_ptr<mindrecord::ShardOperator>> stack_ops;
while (op != nullptr) {
auto sampler_op = std::dynamic_pointer_cast<mindrecord::ShardDistributedSample>(op);
if (sampler_op && num_padded > 0) {
sampler_op->SetNumPaddedSamples(num_padded);
stack_ops.push(sampler_op);
} else {
stack_ops.push(op);
}
op = op->GetChildOp();
}
if (in_partitions->size() != 2) {
std::string err_msg = "Error: partitions is invalid or not set.";
RETURN_STATUS_UNEXPECTED(err_msg);
while (!stack_ops.empty()) {
operators->push_back(stack_ops.top());
stack_ops.pop();
}
constexpr int kMaxPartitions = 1024;
if (in_partitions->at(0) <= 0 || in_partitions->at(0) > kMaxPartitions) {
std::string err_msg = "Error: partitions is invalid or not set.";
RETURN_STATUS_UNEXPECTED(err_msg);
}
if (in_partitions->at(1) < 0 || in_partitions->at(1) >= in_partitions->at(0)) {
std::string err_msg = "Error: partitions is invalid or not set.";
RETURN_STATUS_UNEXPECTED(err_msg);
}
return Status::OK();
}
@ -460,34 +452,16 @@ Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr<Datas
(void)builder->SetNumMindRecordWorkers(ToInt(value));
} else if (key == "block_reader" && ToBool(value) == true) {
(void)builder->SetBlockReader();
} else if (key == "shuffle_option" && ToBool(value) == true) {
if (!args["partitions"].is_none()) continue;
uint32_t seed = GetSeed();
operators.push_back(std::make_shared<mindrecord::ShardShuffle>(seed));
} else if (key == "sampler") {
auto sampler = py::reinterpret_borrow<py::object>(value);
auto create = sampler.attr("_create_for_minddataset");
auto op = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
operators.push_back(op);
int num_padded = 0;
if (!args["num_padded"].is_none()) {
num_padded = ToInt(args["num_padded"]);
}
RETURN_IF_NOT_OK(BuildMindrecordSamplerChain(value, &operators, num_padded));
}
}
}
std::vector<int> in_partitions;
if (!args["partitions"].is_none()) {
auto ret = CheckMindRecordPartitionInfo(args, &in_partitions);
if (Status::OK() != ret) {
return ret;
}
auto shuffle = ToBool(args["shuffle_option"]);
int num_padded = 0;
if (!args["num_padded"].is_none()) {
num_padded = ToInt(args["num_padded"]);
}
operators.push_back(
std::make_shared<mindrecord::ShardDistributedSample>(in_partitions[0], in_partitions[1], num_padded, shuffle, 0));
}
if (!operators.empty()) {
(void)builder->SetOperators(operators);
}
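BuildMindrecordSamplerChain follows GetChildOp() from the root sampler down, pushes each op onto a stack, and pops them into operators, so the innermost child executes first. A toy sketch of that ordering (plain Python, illustrative only; the strings stand in for shard operators):

chain = ["distributed", "sequential"]   # root -> child, as returned by GetChildOp()
stack = []
for op in chain:
    stack.append(op)
operators = [stack.pop() for _ in range(len(stack))]
# operators == ["sequential", "distributed"]: the child runs before its parent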

View File

@ -18,6 +18,7 @@
#include <iostream>
#include <memory>
#include <stack>
#include <string>
#include <unordered_map>
#include <utility>
@ -108,10 +109,12 @@ class DEPipeline {
Status ParseShuffleOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
Status CheckMindRecordPartitionInfo(const py::dict &args, std::vector<int> *ptr);
Status ParseMindRecordOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
Status BuildMindrecordSamplerChain(const py::handle &handle,
std::vector<std::shared_ptr<mindrecord::ShardOperator>> *operators,
int num_padded);
Status ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
Status ParseFilterOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);

View File

@ -71,6 +71,7 @@
#include "mindrecord/include/shard_pk_sample.h"
#include "mindrecord/include/shard_distributed_sample.h"
#include "mindrecord/include/shard_sample.h"
#include "mindrecord/include/shard_sequential_sample.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "pybind11/stl_bind.h"
@ -165,8 +166,8 @@ void bindDatasetOps(py::module *m) {
const int64_t num_padded) {
int64_t count = 0;
std::shared_ptr<mindrecord::ShardOperator> op;
if (py::hasattr(sampler, "_create_for_minddataset")) {
auto create = sampler.attr("_create_for_minddataset");
if (py::hasattr(sampler, "create_for_minddataset")) {
auto create = sampler.attr("create_for_minddataset");
op = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
}
THROW_IF_ERROR(MindRecordOp::CountTotalRows(paths, load_dataset, op, &count, num_padded));
@ -486,7 +487,9 @@ void bindSamplerOps(py::module *m) {
.def("add_child",
[](std::shared_ptr<Sampler> self, std::shared_ptr<Sampler> child) { THROW_IF_ERROR(self->AddChild(child)); });
(void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(*m, "ShardOperator");
(void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(*m, "ShardOperator")
.def("add_child", [](std::shared_ptr<mindrecord::ShardOperator> self,
std::shared_ptr<mindrecord::ShardOperator> child) { self->SetChildOp(child); });
(void)py::class_<DistributedSampler, Sampler, std::shared_ptr<DistributedSampler>>(*m, "DistributedSampler")
.def(py::init<int64_t, int64_t, int64_t, bool, uint32_t>());
@ -518,6 +521,22 @@ void bindSamplerOps(py::module *m) {
}
}));
(void)py::class_<mindrecord::ShardDistributedSample, mindrecord::ShardSample,
std::shared_ptr<mindrecord::ShardDistributedSample>>(*m, "MindrecordDistributedSampler")
.def(py::init<int64_t, int64_t, bool, uint32_t>());
(void)py::class_<mindrecord::ShardShuffle, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardShuffle>>(
*m, "MindrecordRandomSampler")
.def(py::init([](int64_t num_samples, bool replacement, bool reshuffle_each_epoch) {
return std::make_shared<mindrecord::ShardShuffle>(GetSeed(), num_samples, replacement, reshuffle_each_epoch);
}));
(void)py::class_<mindrecord::ShardSequentialSample, mindrecord::ShardSample,
std::shared_ptr<mindrecord::ShardSequentialSample>>(*m, "MindrecordSequentialSampler")
.def(py::init([](int num_samples, int start_index) {
return std::make_shared<mindrecord::ShardSequentialSample>(num_samples, start_index);
}));
(void)py::class_<WeightedRandomSampler, Sampler, std::shared_ptr<WeightedRandomSampler>>(*m, "WeightedRandomSampler")
.def(py::init<int64_t, std::vector<double>, bool>());

View File

@ -31,6 +31,10 @@ class ShardDistributedSample : public ShardSample {
public:
ShardDistributedSample(int num_shards, int shard_id, int no_of_padded_samples, bool shuffle, uint32_t seed);
ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed);
void SetNumPaddedSamples(int no_of_padded_samples) { no_of_padded_samples_ = no_of_padded_samples; }
~ShardDistributedSample() override{};
MSRStatus PreExecute(ShardTask &tasks) override;

View File

@ -17,6 +17,7 @@
#ifndef MINDRECORD_INCLUDE_SHARD_OPERATOR_H_
#define MINDRECORD_INCLUDE_SHARD_OPERATOR_H_
#include <memory>
#include "mindrecord/include/shard_task.h"
namespace mindspore {
@ -37,6 +38,14 @@ class ShardOperator {
}
return SUCCESS;
}
virtual bool HasChildOp() { return child_op_ != nullptr; }
virtual MSRStatus SetChildOp(std::shared_ptr<ShardOperator> child_op) {
if (child_op != nullptr) child_op_ = child_op;
return SUCCESS;
}
virtual std::shared_ptr<ShardOperator> GetChildOp() { return child_op_; }
virtual MSRStatus PreExecute(ShardTask &tasks) { return SUCCESS; }
@ -44,7 +53,10 @@ class ShardOperator {
virtual MSRStatus SufExecute(ShardTask &tasks) { return SUCCESS; }
virtual int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) { return -1; }
virtual int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) { return 0; }
private:
std::shared_ptr<ShardOperator> child_op_ = nullptr;
};
} // namespace mindrecord
} // namespace mindspore

View File

@ -34,6 +34,7 @@
#include <memory>
#include <mutex>
#include <set>
#include <stack>
#include <string>
#include <thread>
#include <tuple>
@ -44,6 +45,7 @@
#include "mindrecord/include/common/shard_utils.h"
#include "mindrecord/include/shard_category.h"
#include "mindrecord/include/shard_column.h"
#include "mindrecord/include/shard_distributed_sample.h"
#include "mindrecord/include/shard_error.h"
#include "mindrecord/include/shard_index_generator.h"
#include "mindrecord/include/shard_operator.h"

View File

@ -48,10 +48,10 @@ class ShardSample : public ShardOperator {
int numerator_;
int denominator_;
int partition_id_;
int no_of_samples_;
std::shared_ptr<ShardShuffle> shuffle_op_;
private:
int no_of_samples_;
std::vector<int64_t> indices_;
SamplerType sampler_type_;
};

View File

@ -0,0 +1,48 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDRECORD_INCLUDE_SHARD_SEQUENTIAL_SAMPLE_H_
#define MINDRECORD_INCLUDE_SHARD_SEQUENTIAL_SAMPLE_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "mindrecord/include/shard_sample.h"
namespace mindspore {
namespace mindrecord {
class ShardSequentialSample : public ShardSample {
public:
ShardSequentialSample(int n, int offset);
ShardSequentialSample(float per, float per_offset);
~ShardSequentialSample() override{};
MSRStatus Execute(ShardTask &tasks) override;
int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override;
private:
int offset_;
float per_;
float per_offset_;
};
} // namespace mindrecord
} // namespace mindspore
#endif // MINDRECORD_INCLUDE_SHARD_SEQUENTIAL_SAMPLE_H_

View File

@ -26,12 +26,20 @@ class ShardShuffle : public ShardOperator {
public:
explicit ShardShuffle(uint32_t seed = 0, ShuffleType shuffle_type = kShuffleCategory);
ShardShuffle(uint32_t seed, int64_t no_of_samples, bool replacement, bool reshuffle_each_epoch,
ShuffleType shuffle_type = kShuffleSample);
~ShardShuffle() override{};
MSRStatus Execute(ShardTask &tasks) override;
int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override;
private:
uint32_t shuffle_seed_;
int64_t no_of_samples_;
bool replacement_;
bool reshuffle_each_epoch_;
ShuffleType shuffle_type_;
};
} // namespace mindrecord

View File

@ -792,24 +792,51 @@ int64_t ShardReader::GetNumClasses(const std::string &category_field) {
}
MSRStatus ShardReader::CountTotalRows(const std::vector<std::string> &file_paths, bool load_dataset,
const std::shared_ptr<ShardOperator> &op, int64_t *count, const int num_padded) {
const std::shared_ptr<ShardOperator> &ops, int64_t *count, const int num_padded) {
if (SUCCESS != Init(file_paths, load_dataset)) {
return FAILED;
}
int64_t num_samples = num_rows_;
if (std::dynamic_pointer_cast<ShardCategory>(op)) {
auto category_op = std::dynamic_pointer_cast<ShardCategory>(op);
std::string category_field = category_op->GetCategoryField();
auto num_classes = GetNumClasses(category_field);
num_samples = category_op->GetNumSamples(num_rows_, num_classes);
} else if (std::dynamic_pointer_cast<ShardSample>(op)) {
num_samples = op->GetNumSamples(num_rows_, 0);
if (-1 == num_samples) {
MS_LOG(ERROR) << "Dataset size plus number of padded samples is not divisible by number of shards.";
return FAILED;
bool root = true;
std::stack<std::shared_ptr<ShardOperator>> stack_ops;
std::shared_ptr<ShardOperator> op(ops);
while (op != nullptr) {
stack_ops.push(op);
op = op->GetChildOp();
}
while (!stack_ops.empty()) {
op = stack_ops.top();
stack_ops.pop();
if (std::dynamic_pointer_cast<ShardShuffle>(op)) {
num_samples = op->GetNumSamples(num_samples, 0);
if (num_padded > 0 && root == true) {
num_samples += num_padded;
MS_LOG(DEBUG) << "Padding samples work on shuffle sampler.";
root = false;
}
} else if (std::dynamic_pointer_cast<ShardCategory>(op)) {
auto category_op = std::dynamic_pointer_cast<ShardCategory>(op);
std::string category_field = category_op->GetCategoryField();
auto num_classes = GetNumClasses(category_field);
num_samples = category_op->GetNumSamples(num_samples, num_classes);
} else if (std::dynamic_pointer_cast<ShardSample>(op)) {
if (std::dynamic_pointer_cast<ShardDistributedSample>(op)) {
auto sampler_op = std::dynamic_pointer_cast<ShardDistributedSample>(op);
if (root == true) {
sampler_op->SetNumPaddedSamples(num_padded);
num_samples = op->GetNumSamples(num_samples, 0);
if (-1 == num_samples) {
MS_LOG(ERROR) << "Dataset size plus number of padded samples is not divisible by number of shards.";
return FAILED;
}
root = false;
}
} else {
num_samples = op->GetNumSamples(num_samples, 0);
}
} else {
if (num_padded > 0) num_samples += num_padded;
}
} else {
if (num_padded > 0) num_samples += num_padded;
}
*count = num_samples;
return SUCCESS;
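CountTotalRows now walks the sampler chain from the leaf up and applies num_padded only once, at the root distributed or shuffle sampler. A rough worked example of the padding arithmetic, assuming 10 rows, num_padded=2 and a distributed sampler over 2 shards (plain arithmetic, not code from this commit):

num_rows, num_padded, num_shards = 10, 2, 2
# the padded total must divide evenly, otherwise CountTotalRows fails with
# "Dataset size plus number of padded samples is not divisible by number of shards."
assert (num_rows + num_padded) % num_shards == 0
per_shard = (num_rows + num_padded) // num_shards   # 6 rows reported per shard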
@ -1385,12 +1412,16 @@ void ShardReader::Reset() {
}
void ShardReader::ShuffleTask() {
if (block_reader_) return;
// if both a shuffle op and a distributed sampler exist in ops, skip the shuffle
bool has_sharding = false;
for (const auto &op : operators_) {
if (block_reader_) {
continue;
if (std::dynamic_pointer_cast<ShardDistributedSample>(op)) {
has_sharding = true;
}
if (std::dynamic_pointer_cast<ShardShuffle>(op)) {
}
for (const auto &op : operators_) {
if (std::dynamic_pointer_cast<ShardShuffle>(op) && has_sharding == false) {
if (SUCCESS != (*op)(tasks_)) {
MS_LOG(WARNING) << "Reshuffle reader tasks failed.";
}

View File

@ -31,6 +31,9 @@ ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, int
shuffle_op_ = std::make_shared<ShardShuffle>(seed, kShuffleSample);
}
ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed)
: ShardDistributedSample(num_shards, shard_id, 0, shuffle, seed) {}
int64_t ShardDistributedSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
if (no_of_padded_samples_ <= 0) {
if (dataset_size % denominator_ == 0) {

View File

@ -0,0 +1,74 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mindrecord/include/shard_sequential_sample.h"
using mindspore::LogStream;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::MsLogLevel::ERROR;
namespace mindspore {
namespace mindrecord {
ShardSequentialSample::ShardSequentialSample(int n, int offset)
: ShardSample(n), offset_(offset), per_(0.0f), per_offset_(0.0f) {}
ShardSequentialSample::ShardSequentialSample(float per, float per_offset)
: ShardSample(0), offset_(0), per_(per), per_offset_(per_offset) {}
int64_t ShardSequentialSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
if (no_of_samples_ == 0 && (per_ >= -kEpsilon && per_ <= kEpsilon)) {
return dataset_size;
}
if (per_ > kEpsilon && per_ <= 1.0f) {
return dataset_size * per_;
}
return no_of_samples_;
}
MSRStatus ShardSequentialSample::Execute(ShardTask &tasks) {
int total_no = static_cast<int>(tasks.Size());
int taking;
if (no_of_samples_ == 0 && (per_ >= -kEpsilon && per_ <= kEpsilon)) {
taking = total_no;
} else if (per_ > kEpsilon && per_ <= 1.0f) {
taking = total_no * per_;
} else {
taking = no_of_samples_;
}
if (tasks.permutation_.empty()) {
ShardTask new_tasks;
total_no = static_cast<int>(tasks.Size());
for (int i = offset_; i < taking + offset_; ++i) {
new_tasks.InsertTask(tasks.GetTaskByID(i % total_no));
}
std::swap(tasks, new_tasks);
} else { // shuffled
ShardTask new_tasks;
if (taking > static_cast<int>(tasks.permutation_.size())) {
return FAILED;
}
total_no = static_cast<int>(tasks.permutation_.size());
for (size_t i = offset_; i < taking + offset_; ++i) {
new_tasks.InsertTask(tasks.GetTaskByID(tasks.permutation_[i % total_no]));
}
std::swap(tasks, new_tasks);
}
return SUCCESS;
}
} // namespace mindrecord
} // namespace mindspore
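For the count-based constructor, Execute takes tasks offset_ .. offset_ + taking - 1 modulo the dataset size, so a start index near the end wraps around. A small sketch of the index arithmetic (plain Python, illustrative only):

offset, taking, total = 2, 10, 10                 # SequentialSampler(2, 10) on a 10-row dataset
indices = [i % total for i in range(offset, offset + taking)]
# indices == [2, 3, 4, 5, 6, 7, 8, 9, 0, 1], matching test_cv_minddataset_sequential_sampler_exceed_size below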

View File

@ -21,17 +21,52 @@
namespace mindspore {
namespace mindrecord {
ShardShuffle::ShardShuffle(uint32_t seed, ShuffleType shuffle_type)
: shuffle_seed_(seed), shuffle_type_(shuffle_type) {}
: shuffle_seed_(seed),
no_of_samples_(0),
replacement_(false),
reshuffle_each_epoch_(true),
shuffle_type_(shuffle_type) {}
ShardShuffle::ShardShuffle(uint32_t seed, int64_t no_of_samples, bool replacement, bool reshuffle_each_epoch,
ShuffleType shuffle_type)
: shuffle_seed_(seed),
no_of_samples_(no_of_samples),
replacement_(replacement),
reshuffle_each_epoch_(reshuffle_each_epoch),
shuffle_type_(shuffle_type) {}
int64_t ShardShuffle::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
if (replacement_) {
return no_of_samples_ == 0 ? dataset_size : no_of_samples_;
}
return dataset_size;
}
MSRStatus ShardShuffle::Execute(ShardTask &tasks) {
if (tasks.categories < 1) {
return FAILED;
}
if (shuffle_type_ == kShuffleSample) {
if (shuffle_type_ == kShuffleSample) { // shuffle each sample
if (tasks.permutation_.empty() == true) {
tasks.MakePerm();
}
std::shuffle(tasks.permutation_.begin(), tasks.permutation_.end(), std::default_random_engine(shuffle_seed_));
if (replacement_ == true) {
ShardTask new_tasks;
if (no_of_samples_ == 0) {
no_of_samples_ = static_cast<int>(tasks.Size());
}
if (no_of_samples_ <= 0) {
MS_LOG(ERROR) << "no_of_samples need to be positive.";
return FAILED;
}
new_tasks.task_list_.reserve(no_of_samples_);
for (uint32_t i = 0; i < no_of_samples_; ++i) {
new_tasks.InsertTask(tasks.GetRandomTask());
}
std::swap(tasks, new_tasks);
} else {
std::shuffle(tasks.permutation_.begin(), tasks.permutation_.end(), std::default_random_engine(shuffle_seed_));
}
} else { // shuffle unit like: (a1, b1, c1),(a2, b2, c2),..., (an, bn, cn)
uint32_t individual_size = tasks.Size() / tasks.categories;
std::vector<std::vector<int>> new_permutations(tasks.categories, std::vector<int>(individual_size));
@ -46,7 +81,7 @@ MSRStatus ShardShuffle::Execute(ShardTask &tasks) {
}
}
}
shuffle_seed_++;
if (reshuffle_each_epoch_) shuffle_seed_++;
return SUCCESS;
}
} // namespace mindrecord
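With replacement, Execute draws no_of_samples_ tasks via GetRandomTask (defaulting to the dataset size when it is 0); without replacement it only permutes, so the size is unchanged. A sketch of the sizing rule, mirroring GetNumSamples above (illustrative only):

def shard_shuffle_num_samples(dataset_size, no_of_samples, replacement):
    # mirrors ShardShuffle::GetNumSamples above
    if replacement:
        return dataset_size if no_of_samples == 0 else no_of_samples
    return dataset_size

shard_shuffle_num_samples(10, 5, True)    # -> 5, as in test_cv_minddataset_random_sampler_replacement below
shard_shuffle_num_samples(10, 0, False)   # -> 10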

View File

@ -72,6 +72,7 @@ std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTa
std::uniform_int_distribution<> dis(0, task_list_.size() - 1);
return task_list_[dis(gen)];
}
ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements) {
ShardTask res;
if (category_tasks.empty()) return res;

View File

@ -1015,10 +1015,8 @@ class Dataset:
def get_distribution(output_dataset):
dev_id = 0
if isinstance(output_dataset, (MindDataset)):
return output_dataset.distribution, dev_id
if isinstance(output_dataset, (Cifar10Dataset, Cifar100Dataset, GeneratorDataset, ImageFolderDatasetV2,
ManifestDataset, MnistDataset, VOCDataset, CelebADataset)):
ManifestDataset, MnistDataset, VOCDataset, CelebADataset, MindDataset)):
sampler = output_dataset.sampler
if isinstance(sampler, samplers.DistributedSampler):
dev_id = sampler.shard_id
@ -2670,7 +2668,7 @@ class MnistDataset(MappableDataset):
return self.sampler.is_sharded()
class MindDataset(SourceDataset):
class MindDataset(MappableDataset):
"""
A source dataset that reads from shard files and database.
@ -2687,11 +2685,13 @@ class MindDataset(SourceDataset):
sampler (Sampler, optional): Object used to choose samples from the
dataset (default=None, sampler is exclusive
with shuffle and block_reader). Support list: SubsetRandomSampler,
PkSampler.
PKSampler, RandomSampler, SequentialSampler, DistributedSampler.
padded_sample (dict, optional): Samples will be appended to the dataset, whose
keys are the same as column_list.
num_padded (int, optional): Number of padding samples. Dataset size
plus num_padded should be divisible by num_shards.
num_samples (int, optional): The number of samples to be included in the dataset
(default=None, all samples).
Raises:
ValueError: If num_shards is specified but shard_id is None.
@ -2703,7 +2703,7 @@ class MindDataset(SourceDataset):
def __init__(self, dataset_file, columns_list=None, num_parallel_workers=None,
shuffle=None, num_shards=None, shard_id=None,
block_reader=False, sampler=None, padded_sample=None,
num_padded=None):
num_padded=None, num_samples=None):
super().__init__(num_parallel_workers)
if isinstance(dataset_file, list):
self.load_dataset = False
@ -2712,15 +2712,10 @@ class MindDataset(SourceDataset):
self.dataset_file = dataset_file
self.columns_list = columns_list
self.shuffle_option = shuffle
self.distribution = ""
self.sampler = sampler
self.num_shards = num_shards
self.shard_id = shard_id
if num_shards is None or shard_id is None:
self.partitions = None
else:
self.partitions = [num_shards, shard_id]
if block_reader is True and self.partitions is not None:
if block_reader is True and num_shards is not None:
raise ValueError("block reader not allowed true when use partitions")
if block_reader is True and shuffle is True:
@ -2730,25 +2725,21 @@ class MindDataset(SourceDataset):
logger.warning("WARN: global shuffle is not used.")
if sampler is not None:
if isinstance(sampler, samplers.SubsetRandomSampler) is False and \
isinstance(sampler, samplers.PKSampler) is False:
if isinstance(sampler, (samplers.SubsetRandomSampler, samplers.PKSampler,
samplers.DistributedSampler, samplers.RandomSampler,
samplers.SequentialSampler)) is False:
raise ValueError("the sampler is not supported yet.")
self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
self.num_samples = num_samples
# sampler exclusive
if block_reader is True and sampler is not None:
raise ValueError("block reader not allowed true when use sampler")
if shuffle is not None and sampler is not None:
raise ValueError("shuffle not allowed when use sampler")
if block_reader is False and sampler is None:
self.shuffle_option = not bool(shuffle is False)
if num_padded is None:
num_padded = 0
self.num_shards = num_shards
self.shard_id = shard_id
self.block_reader = block_reader
self.padded_sample = padded_sample
self.num_padded = num_padded
@ -2766,10 +2757,8 @@ class MindDataset(SourceDataset):
args["load_dataset"] = self.load_dataset
args["columns_list"] = self.columns_list
args["shuffle_option"] = self.shuffle_option
args["partitions"] = self.partitions
args["num_samples"] = self.num_samples
args["block_reader"] = self.block_reader
args["num_shards"] = self.num_shards
args["shard_id"] = self.shard_id
args["num_padded"] = self.num_padded
args["padded_sample"] = padded_sample
args["sampler"] = self.sampler
@ -2788,14 +2777,6 @@ class MindDataset(SourceDataset):
else:
dataset_file = self.dataset_file
num_rows = MindRecordOp.get_num_rows(dataset_file, self.load_dataset, self.sampler, self.num_padded)
if self.partitions is not None and self.partitions[0] > 0:
if num_rows % self.partitions[0] == 0:
num_rows = num_rows // self.partitions[0]
else:
if self.num_padded > 0:
raise RuntimeError(
"Dataset size plus number of padded samples is not divisible by number of shards.")
num_rows = num_rows // self.partitions[0] + 1
return num_rows
return self._dataset_size
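The partition and num_padded bookkeeping now lives in the sampler chain: get_dataset_size() passes self.sampler and self.num_padded to MindRecordOp.get_num_rows, so sharding and padding are already reflected in the count. A hedged usage sketch of the new arguments (the padded values are placeholders and must match the dataset schema):

padded = {"data": b"\x00", "file_name": "padded.jpg", "label": -1}
data_set = ds.MindDataset(CV_FILE_NAME + "0", ["data", "file_name", "label"], 4,
                          num_shards=2, shard_id=0,
                          padded_sample=padded, num_padded=2)
# (10 real rows + 2 padded) / 2 shards -> get_dataset_size() == 6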

View File

@ -141,7 +141,12 @@ class BuiltinSampler:
c_child_sampler = None
if self.child_sampler is not None:
c_child_sampler = self.child_sampler.create()
return c_child_sampler
def create_child_for_minddataset(self):
c_child_sampler = None
if self.child_sampler is not None:
c_child_sampler = self.child_sampler.create_for_minddataset()
return c_child_sampler
def is_shuffled(self):
@ -262,6 +267,12 @@ class DistributedSampler(BuiltinSampler):
c_sampler.add_child(c_child_sampler)
return c_sampler
def create_for_minddataset(self):
c_sampler = cde.MindrecordDistributedSampler(self.num_shards, self.shard_id, self.shuffle, self.seed)
c_child_sampler = self.create_child_for_minddataset()
c_sampler.add_child(c_child_sampler)
return c_sampler
def is_shuffled(self):
if self.child_sampler is None:
return self.shuffle
@ -318,7 +329,7 @@ class PKSampler(BuiltinSampler):
self.num_val = num_val
self.shuffle = shuffle
self.class_column = class_column # work for minddataset
self.class_column = class_column # work for minddataset
super().__init__(num_samples)
def create(self):
@ -340,12 +351,14 @@ class PKSampler(BuiltinSampler):
return self.child_sampler.is_sharded()
def _create_for_minddataset(self):
def create_for_minddataset(self):
if not self.class_column or not isinstance(self.class_column, str):
raise ValueError("class_column should be a not empty string value, \
but got class_column={}".format(class_column))
return cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle)
c_sampler = cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle)
c_child_sampler = self.create_child_for_minddataset()
c_sampler.add_child(c_child_sampler)
return c_sampler
class RandomSampler(BuiltinSampler):
"""
@ -390,6 +403,13 @@ class RandomSampler(BuiltinSampler):
c_sampler.add_child(c_child_sampler)
return c_sampler
def create_for_minddataset(self):
num_samples = self.num_samples if self.num_samples is not None else 0
c_sampler = cde.MindrecordRandomSampler(num_samples, self.replacement, self.reshuffle_each_epoch)
c_child_sampler = self.create_child_for_minddataset()
c_sampler.add_child(c_child_sampler)
return c_sampler
def is_shuffled(self):
return True
@ -440,6 +460,14 @@ class SequentialSampler(BuiltinSampler):
c_sampler.add_child(c_child_sampler)
return c_sampler
def create_for_minddataset(self):
start_index = self.start_index if self.start_index is not None else 0
num_samples = self.num_samples if self.num_samples is not None else 0
c_sampler = cde.MindrecordSequentialSampler(num_samples, start_index)
c_child_sampler = self.create_child_for_minddataset()
c_sampler.add_child(c_child_sampler)
return c_sampler
def is_shuffled(self):
if self.child_sampler is None:
return False
@ -501,8 +529,11 @@ class SubsetRandomSampler(BuiltinSampler):
return self.child_sampler.is_sharded()
def _create_for_minddataset(self):
return cde.MindrecordSubsetRandomSampler(self.indices)
def create_for_minddataset(self):
c_sampler = cde.MindrecordSubsetRandomSampler(self.indices)
c_child_sampler = self.create_child_for_minddataset()
c_sampler.add_child(c_child_sampler)
return c_sampler
def get_num_samples(self):
num_samples = super().get_num_samples()
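Each built-in sampler now exposes create_for_minddataset(), and create_child_for_minddataset() converts self.child_sampler so the chain survives the Python-to-C++ hand-off (the C++ side rebuilds it by walking GetChildOp()). A rough sketch of the mapping, assuming the child is attached through the sampler's child_sampler attribute as used above (illustrative only, not an API added by this commit):

parent = ds.DistributedSampler(2, 0)      # maps to cde.MindrecordDistributedSampler
child = ds.SequentialSampler(0, 4)        # maps to cde.MindrecordSequentialSampler
parent.child_sampler = child              # normally set through the sampler chaining API
c_op = parent.create_for_minddataset()    # child op attached via add_child on the C++ binding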

View File

@ -17,6 +17,7 @@ This is the test module for mindrecord
"""
import os
import pytest
import numpy as np
import mindspore.dataset as ds
from mindspore import log as logger
@ -64,10 +65,12 @@ def test_cv_minddataset_pk_sample_no_column(add_and_remove_cv_file):
assert data_set.get_dataset_size() == 6
num_iter = 0
for item in data_set.create_dict_iterator():
logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
logger.info(
"-------------- cv reader basic: {} ------------------------".format(num_iter))
logger.info("-------------- item[file_name]: \
{}------------------------".format(to_str(item["file_name"])))
logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"]))
num_iter += 1
@ -82,12 +85,14 @@ def test_cv_minddataset_pk_sample_basic(add_and_remove_cv_file):
assert data_set.get_dataset_size() == 6
num_iter = 0
for item in data_set.create_dict_iterator():
logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
logger.info(
"-------------- cv reader basic: {} ------------------------".format(num_iter))
logger.info("-------------- item[data]: \
{}------------------------".format(item["data"][:10]))
logger.info("-------------- item[file_name]: \
{}------------------------".format(to_str(item["file_name"])))
logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"]))
num_iter += 1
@ -102,10 +107,12 @@ def test_cv_minddataset_pk_sample_shuffle(add_and_remove_cv_file):
assert data_set.get_dataset_size() == 9
num_iter = 0
for item in data_set.create_dict_iterator():
logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
logger.info(
"-------------- cv reader basic: {} ------------------------".format(num_iter))
logger.info("-------------- item[file_name]: \
{}------------------------".format(to_str(item["file_name"])))
logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"]))
num_iter += 1
@ -119,10 +126,12 @@ def test_cv_minddataset_pk_sample_out_of_range(add_and_remove_cv_file):
assert data_set.get_dataset_size() == 15
num_iter = 0
for item in data_set.create_dict_iterator():
logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
logger.info(
"-------------- cv reader basic: {} ------------------------".format(num_iter))
logger.info("-------------- item[file_name]: \
{}------------------------".format(to_str(item["file_name"])))
logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"]))
num_iter += 1
@ -219,7 +228,6 @@ def test_cv_minddataset_subset_random_sample_out_of_range(add_and_remove_cv_file
def test_cv_minddataset_subset_random_sample_negative(add_and_remove_cv_file):
"""tutorial for cv minderdataset."""
columns_list = ["data", "file_name", "label"]
num_readers = 4
indices = [1, 2, 4, -1, -2]
@ -241,6 +249,344 @@ def test_cv_minddataset_subset_random_sample_negative(add_and_remove_cv_file):
assert num_iter == 5
def test_cv_minddataset_random_sampler_basic(add_and_remove_cv_file):
data = get_data(CV_DIR_NAME, True)
columns_list = ["data", "file_name", "label"]
num_readers = 4
sampler = ds.RandomSampler()
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
sampler=sampler)
assert data_set.get_dataset_size() == 10
num_iter = 0
new_dataset = []
for item in data_set.create_dict_iterator():
logger.info(
"-------------- cv reader basic: {} ------------------------".format(num_iter))
logger.info(
"-------------- item[data]: {} -----------------------------".format(item["data"]))
logger.info(
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"]))
num_iter += 1
new_dataset.append(item['file_name'])
assert num_iter == 10
assert new_dataset != [x['file_name'] for x in data]
def test_cv_minddataset_random_sampler_repeat(add_and_remove_cv_file):
columns_list = ["data", "file_name", "label"]
num_readers = 4
sampler = ds.RandomSampler()
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
sampler=sampler)
assert data_set.get_dataset_size() == 10
ds1 = data_set.repeat(3)
num_iter = 0
epoch1_dataset = []
epoch2_dataset = []
epoch3_dataset = []
for item in ds1.create_dict_iterator():
logger.info(
"-------------- cv reader basic: {} ------------------------".format(num_iter))
logger.info(
"-------------- item[data]: {} -----------------------------".format(item["data"]))
logger.info(
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"]))
num_iter += 1
if num_iter <= 10:
epoch1_dataset.append(item['file_name'])
elif num_iter <= 20:
epoch2_dataset.append(item['file_name'])
else:
epoch3_dataset.append(item['file_name'])
assert num_iter == 30
assert epoch1_dataset not in (epoch2_dataset, epoch3_dataset)
assert epoch2_dataset not in (epoch1_dataset, epoch3_dataset)
assert epoch3_dataset not in (epoch1_dataset, epoch2_dataset)
def test_cv_minddataset_random_sampler_replacement(add_and_remove_cv_file):
columns_list = ["data", "file_name", "label"]
num_readers = 4
sampler = ds.RandomSampler(replacement=True, num_samples=5)
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
sampler=sampler)
assert data_set.get_dataset_size() == 5
num_iter = 0
for item in data_set.create_dict_iterator():
logger.info(
"-------------- cv reader basic: {} ------------------------".format(num_iter))
logger.info(
"-------------- item[data]: {} -----------------------------".format(item["data"]))
logger.info(
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"]))
num_iter += 1
assert num_iter == 5
def test_cv_minddataset_sequential_sampler_basic(add_and_remove_cv_file):
data = get_data(CV_DIR_NAME, True)
columns_list = ["data", "file_name", "label"]
num_readers = 4
sampler = ds.SequentialSampler(1, 4)
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
sampler=sampler)
assert data_set.get_dataset_size() == 4
num_iter = 0
for item in data_set.create_dict_iterator():
logger.info(
"-------------- cv reader basic: {} ------------------------".format(num_iter))
logger.info(
"-------------- item[data]: {} -----------------------------".format(item["data"]))
logger.info(
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"]))
assert item['file_name'] == np.array(
data[num_iter+1]['file_name'], dtype='S')
num_iter += 1
assert num_iter == 4
def test_cv_minddataset_sequential_sampler_exceed_size(add_and_remove_cv_file):
data = get_data(CV_DIR_NAME, True)
columns_list = ["data", "file_name", "label"]
num_readers = 4
sampler = ds.SequentialSampler(2, 10)
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
sampler=sampler)
dataset_size = data_set.get_dataset_size()
assert dataset_size == 10
num_iter = 0
for item in data_set.create_dict_iterator():
logger.info(
"-------------- cv reader basic: {} ------------------------".format(num_iter))
logger.info(
"-------------- item[data]: {} -----------------------------".format(item["data"]))
logger.info(
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"]))
assert item['file_name'] == np.array(
data[(num_iter + 2) % dataset_size]['file_name'], dtype='S')
num_iter += 1
assert num_iter == 10
def test_cv_minddataset_split_basic(add_and_remove_cv_file):
data = get_data(CV_DIR_NAME, True)
columns_list = ["data", "file_name", "label"]
num_readers = 4
d = ds.MindDataset(CV_FILE_NAME + "0", columns_list,
num_readers, shuffle=False)
d1, d2 = d.split([8, 2], randomize=False)
assert d.get_dataset_size() == 10
assert d1.get_dataset_size() == 8
assert d2.get_dataset_size() == 2
num_iter = 0
for item in d1.create_dict_iterator():
logger.info(
"-------------- item[data]: {} -----------------------------".format(item["data"]))
logger.info(
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"]))
assert item['file_name'] == np.array(data[num_iter]['file_name'],
dtype='S')
num_iter += 1
assert num_iter == 8
num_iter = 0
for item in d2.create_dict_iterator():
logger.info(
"-------------- item[data]: {} -----------------------------".format(item["data"]))
logger.info(
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"]))
assert item['file_name'] == np.array(data[num_iter + 8]['file_name'],
dtype='S')
num_iter += 1
assert num_iter == 2
def test_cv_minddataset_split_exact_percent(add_and_remove_cv_file):
data = get_data(CV_DIR_NAME, True)
columns_list = ["data", "file_name", "label"]
num_readers = 4
d = ds.MindDataset(CV_FILE_NAME + "0", columns_list,
num_readers, shuffle=False)
d1, d2 = d.split([0.8, 0.2], randomize=False)
assert d.get_dataset_size() == 10
assert d1.get_dataset_size() == 8
assert d2.get_dataset_size() == 2
num_iter = 0
for item in d1.create_dict_iterator():
logger.info(
"-------------- item[data]: {} -----------------------------".format(item["data"]))
logger.info(
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"]))
assert item['file_name'] == np.array(
data[num_iter]['file_name'], dtype='S')
num_iter += 1
assert num_iter == 8
num_iter = 0
for item in d2.create_dict_iterator():
logger.info(
"-------------- item[data]: {} -----------------------------".format(item["data"]))
logger.info(
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"]))
assert item['file_name'] == np.array(data[num_iter + 8]['file_name'],
dtype='S')
num_iter += 1
assert num_iter == 2
def test_cv_minddataset_split_fuzzy_percent(add_and_remove_cv_file):
data = get_data(CV_DIR_NAME, True)
columns_list = ["data", "file_name", "label"]
num_readers = 4
d = ds.MindDataset(CV_FILE_NAME + "0", columns_list,
num_readers, shuffle=False)
d1, d2 = d.split([0.41, 0.59], randomize=False)
assert d.get_dataset_size() == 10
assert d1.get_dataset_size() == 4
assert d2.get_dataset_size() == 6
num_iter = 0
for item in d1.create_dict_iterator():
logger.info(
"-------------- item[data]: {} -----------------------------".format(item["data"]))
logger.info(
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"]))
assert item['file_name'] == np.array(
data[num_iter]['file_name'], dtype='S')
num_iter += 1
assert num_iter == 4
num_iter = 0
for item in d2.create_dict_iterator():
logger.info(
"-------------- item[data]: {} -----------------------------".format(item["data"]))
logger.info(
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"]))
assert item['file_name'] == np.array(data[num_iter + 4]['file_name'],
dtype='S')
num_iter += 1
assert num_iter == 6
def test_cv_minddataset_split_deterministic(add_and_remove_cv_file):
columns_list = ["data", "file_name", "label"]
num_readers = 4
d = ds.MindDataset(CV_FILE_NAME + "0", columns_list,
num_readers, shuffle=False)
# should set seed to avoid data overlap
ds.config.set_seed(111)
d1, d2 = d.split([0.8, 0.2])
assert d.get_dataset_size() == 10
assert d1.get_dataset_size() == 8
assert d2.get_dataset_size() == 2
d1_dataset = []
d2_dataset = []
num_iter = 0
for item in d1.create_dict_iterator():
logger.info(
"-------------- item[data]: {} -----------------------------".format(item["data"]))
logger.info(
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"]))
d1_dataset.append(item['file_name'])
num_iter += 1
assert num_iter == 8
num_iter = 0
for item in d2.create_dict_iterator():
logger.info(
"-------------- item[data]: {} -----------------------------".format(item["data"]))
logger.info(
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"]))
d2_dataset.append(item['file_name'])
num_iter += 1
assert num_iter == 2
inter_dataset = [x for x in d1_dataset if x in d2_dataset]
assert inter_dataset == [] # intersection of d1 and d2
def test_cv_minddataset_split_sharding(add_and_remove_cv_file):
data = get_data(CV_DIR_NAME, True)
columns_list = ["data", "file_name", "label"]
num_readers = 4
d = ds.MindDataset(CV_FILE_NAME + "0", columns_list,
num_readers, shuffle=False)
# should set seed to avoid data overlap
ds.config.set_seed(111)
d1, d2 = d.split([0.8, 0.2])
assert d.get_dataset_size() == 10
assert d1.get_dataset_size() == 8
assert d2.get_dataset_size() == 2
distributed_sampler = ds.DistributedSampler(2, 0)
d1.use_sampler(distributed_sampler)
assert d1.get_dataset_size() == 4
num_iter = 0
d1_shard1 = []
for item in d1.create_dict_iterator():
logger.info(
"-------------- item[data]: {} -----------------------------".format(item["data"]))
logger.info(
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"]))
num_iter += 1
d1_shard1.append(item['file_name'])
assert num_iter == 4
assert d1_shard1 != [x['file_name'] for x in data[0:4]]
distributed_sampler = ds.DistributedSampler(2, 1)
d1.use_sampler(distributed_sampler)
assert d1.get_dataset_size() == 4
d1s = d1.repeat(3)
epoch1_dataset = []
epoch2_dataset = []
epoch3_dataset = []
num_iter = 0
for item in d1s.create_dict_iterator():
logger.info(
"-------------- item[data]: {} -----------------------------".format(item["data"]))
logger.info(
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"]))
num_iter += 1
if num_iter <= 4:
epoch1_dataset.append(item['file_name'])
elif num_iter <= 8:
epoch2_dataset.append(item['file_name'])
else:
epoch3_dataset.append(item['file_name'])
assert len(epoch1_dataset) == 4
assert len(epoch2_dataset) == 4
assert len(epoch3_dataset) == 4
inter_dataset = [x for x in d1_shard1 if x in epoch1_dataset]
assert inter_dataset == [] # intersection of d1's shard1 and d1's shard2
assert epoch1_dataset not in (epoch2_dataset, epoch3_dataset)
assert epoch2_dataset not in (epoch1_dataset, epoch3_dataset)
assert epoch3_dataset not in (epoch1_dataset, epoch2_dataset)
def get_data(dir_name, sampler=False):
"""
usage: get data from imagenet dataset