forked from mindspore-Ecosystem/mindspore
!273 [MD] update subset random sampler in minddataset
Merge pull request !273 from liyong126/mindrecord_subset_sampler_python
This commit is contained in:
commit
f1fa2a9941
|
@ -391,30 +391,6 @@ Status DEPipeline::CheckMindRecordPartitionInfo(const py::dict &args, std::vecto
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status DEPipeline::GetMindrecordSampler(const std::string &sampler_name, const py::dict &args,
|
|
||||||
std::shared_ptr<mindrecord::ShardOperator> *ptr) {
|
|
||||||
std::vector<int> indices;
|
|
||||||
for (auto &arg : args) {
|
|
||||||
std::string key = py::str(arg.first);
|
|
||||||
py::handle value = arg.second;
|
|
||||||
if (!value.is_none()) {
|
|
||||||
if (key == "indices") {
|
|
||||||
indices = ToIntVector(value);
|
|
||||||
} else {
|
|
||||||
std::string err_msg = "ERROR: parameter " + key + " is invalid.";
|
|
||||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (sampler_name == "SubsetRandomSampler") {
|
|
||||||
*ptr = std::make_shared<mindrecord::ShardSample>(indices);
|
|
||||||
} else {
|
|
||||||
std::string err_msg = "ERROR: parameter sampler_name is invalid.";
|
|
||||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
|
||||||
}
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
|
Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
|
||||||
if (args["dataset_file"].is_none()) {
|
if (args["dataset_file"].is_none()) {
|
||||||
std::string err_msg = "Error: at least one of dataset_files is missing";
|
std::string err_msg = "Error: at least one of dataset_files is missing";
|
||||||
|
@ -446,12 +422,10 @@ Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr<Datas
|
||||||
} else if (key == "global_shuffle" && ToBool(value) == true) {
|
} else if (key == "global_shuffle" && ToBool(value) == true) {
|
||||||
uint32_t seed = args["partitions"].is_none() ? GetSeed() : 0;
|
uint32_t seed = args["partitions"].is_none() ? GetSeed() : 0;
|
||||||
operators.push_back(std::make_shared<mindrecord::ShardShuffle>(seed));
|
operators.push_back(std::make_shared<mindrecord::ShardShuffle>(seed));
|
||||||
} else if (key == "sampler_name") {
|
} else if (key == "sampler") {
|
||||||
std::shared_ptr<mindrecord::ShardOperator> sample_op;
|
auto create = py::reinterpret_borrow<py::object>(value).attr("_create_for_minddataset");
|
||||||
auto ret = GetMindrecordSampler(ToString(value), args["sampler_params"], &sample_op);
|
std::shared_ptr<mindrecord::ShardOperator> sample_op =
|
||||||
if (Status::OK() != ret) {
|
create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
operators.push_back(sample_op);
|
operators.push_back(sample_op);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -145,9 +145,6 @@ class DEPipeline {
|
||||||
|
|
||||||
Status ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
|
Status ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
|
||||||
|
|
||||||
Status GetMindrecordSampler(const std::string &sampler_name, const py::dict &args,
|
|
||||||
std::shared_ptr<mindrecord::ShardOperator> *ptr);
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Execution tree that links the dataset operators.
|
// Execution tree that links the dataset operators.
|
||||||
std::shared_ptr<ExecutionTree> tree_;
|
std::shared_ptr<ExecutionTree> tree_;
|
||||||
|
|
|
@ -54,6 +54,9 @@
|
||||||
#include "dataset/engine/datasetops/source/tf_reader_op.h"
|
#include "dataset/engine/datasetops/source/tf_reader_op.h"
|
||||||
#include "dataset/engine/jagged_connector.h"
|
#include "dataset/engine/jagged_connector.h"
|
||||||
#include "dataset/kernels/data/to_float16_op.h"
|
#include "dataset/kernels/data/to_float16_op.h"
|
||||||
|
#include "dataset/util/random.h"
|
||||||
|
#include "mindrecord/include/shard_operator.h"
|
||||||
|
#include "mindrecord/include/shard_sample.h"
|
||||||
#include "pybind11/pybind11.h"
|
#include "pybind11/pybind11.h"
|
||||||
#include "pybind11/stl.h"
|
#include "pybind11/stl.h"
|
||||||
#include "pybind11/stl_bind.h"
|
#include "pybind11/stl_bind.h"
|
||||||
|
@ -382,6 +385,7 @@ void bindTensorOps4(py::module *m) {
|
||||||
|
|
||||||
void bindSamplerOps(py::module *m) {
|
void bindSamplerOps(py::module *m) {
|
||||||
(void)py::class_<Sampler, std::shared_ptr<Sampler>>(*m, "Sampler");
|
(void)py::class_<Sampler, std::shared_ptr<Sampler>>(*m, "Sampler");
|
||||||
|
(void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(*m, "ShardOperator");
|
||||||
|
|
||||||
(void)py::class_<DistributedSampler, Sampler, std::shared_ptr<DistributedSampler>>(*m, "DistributedSampler")
|
(void)py::class_<DistributedSampler, Sampler, std::shared_ptr<DistributedSampler>>(*m, "DistributedSampler")
|
||||||
.def(py::init<int64_t, int64_t, bool, uint32_t>(), py::arg("numDev"), py::arg("devId"), py::arg("shuffle"),
|
.def(py::init<int64_t, int64_t, bool, uint32_t>(), py::arg("numDev"), py::arg("devId"), py::arg("shuffle"),
|
||||||
|
@ -399,6 +403,10 @@ void bindSamplerOps(py::module *m) {
|
||||||
(void)py::class_<SubsetRandomSampler, Sampler, std::shared_ptr<SubsetRandomSampler>>(*m, "SubsetRandomSampler")
|
(void)py::class_<SubsetRandomSampler, Sampler, std::shared_ptr<SubsetRandomSampler>>(*m, "SubsetRandomSampler")
|
||||||
.def(py::init<std::vector<int64_t>>(), py::arg("indices"));
|
.def(py::init<std::vector<int64_t>>(), py::arg("indices"));
|
||||||
|
|
||||||
|
(void)py::class_<mindrecord::ShardSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardSample>>(
|
||||||
|
*m, "MindrecordSubsetRandomSampler")
|
||||||
|
.def(py::init<std::vector<int64_t>, uint32_t>(), py::arg("indices"), py::arg("seed") = GetSeed());
|
||||||
|
|
||||||
(void)py::class_<WeightedRandomSampler, Sampler, std::shared_ptr<WeightedRandomSampler>>(*m, "WeightedRandomSampler")
|
(void)py::class_<WeightedRandomSampler, Sampler, std::shared_ptr<WeightedRandomSampler>>(*m, "WeightedRandomSampler")
|
||||||
.def(py::init<std::vector<double>, int64_t, bool>(), py::arg("weights"), py::arg("numSamples"),
|
.def(py::init<std::vector<double>, int64_t, bool>(), py::arg("weights"), py::arg("numSamples"),
|
||||||
py::arg("replacement"));
|
py::arg("replacement"));
|
||||||
|
|
|
@ -32,7 +32,7 @@ class ShardCategory : public ShardOperator {
|
||||||
|
|
||||||
const std::vector<std::pair<std::string, std::string>> &get_categories() const;
|
const std::vector<std::pair<std::string, std::string>> &get_categories() const;
|
||||||
|
|
||||||
MSRStatus operator()(ShardTask &tasks) override;
|
MSRStatus execute(ShardTask &tasks) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<std::pair<std::string, std::string>> categories_;
|
std::vector<std::pair<std::string, std::string>> categories_;
|
||||||
|
|
|
@ -24,7 +24,25 @@ namespace mindrecord {
|
||||||
class ShardOperator {
|
class ShardOperator {
|
||||||
public:
|
public:
|
||||||
virtual ~ShardOperator() = default;
|
virtual ~ShardOperator() = default;
|
||||||
virtual MSRStatus operator()(ShardTask &tasks) = 0;
|
|
||||||
|
MSRStatus operator()(ShardTask &tasks) {
|
||||||
|
if (SUCCESS != this->pre_execute(tasks)) {
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
if (SUCCESS != this->execute(tasks)) {
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
if (SUCCESS != this->suf_execute(tasks)) {
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual MSRStatus pre_execute(ShardTask &tasks) { return SUCCESS; }
|
||||||
|
|
||||||
|
virtual MSRStatus execute(ShardTask &tasks) = 0;
|
||||||
|
|
||||||
|
virtual MSRStatus suf_execute(ShardTask &tasks) { return SUCCESS; }
|
||||||
};
|
};
|
||||||
} // namespace mindrecord
|
} // namespace mindrecord
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -17,10 +17,12 @@
|
||||||
#ifndef MINDRECORD_INCLUDE_SHARD_SAMPLE_H_
|
#ifndef MINDRECORD_INCLUDE_SHARD_SAMPLE_H_
|
||||||
#define MINDRECORD_INCLUDE_SHARD_SAMPLE_H_
|
#define MINDRECORD_INCLUDE_SHARD_SAMPLE_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "mindrecord/include/shard_operator.h"
|
#include "mindrecord/include/shard_operator.h"
|
||||||
|
#include "mindrecord/include/shard_shuffle.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace mindrecord {
|
namespace mindrecord {
|
||||||
|
@ -32,21 +34,23 @@ class ShardSample : public ShardOperator {
|
||||||
|
|
||||||
ShardSample(int num, int den, int par);
|
ShardSample(int num, int den, int par);
|
||||||
|
|
||||||
explicit ShardSample(const std::vector<int> &indices);
|
ShardSample(const std::vector<int64_t> &indices, uint32_t seed);
|
||||||
|
|
||||||
~ShardSample() override{};
|
~ShardSample() override{};
|
||||||
|
|
||||||
const std::pair<int, int> get_partitions() const;
|
const std::pair<int, int> get_partitions() const;
|
||||||
|
|
||||||
MSRStatus operator()(ShardTask &tasks) override;
|
MSRStatus execute(ShardTask &tasks) override;
|
||||||
|
MSRStatus suf_execute(ShardTask &tasks) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int numerator_;
|
int numerator_;
|
||||||
int denominator_;
|
int denominator_;
|
||||||
int no_of_samples_;
|
int no_of_samples_;
|
||||||
int partition_id_;
|
int partition_id_;
|
||||||
std::vector<int> indices_;
|
std::vector<int64_t> indices_;
|
||||||
SamplerType sampler_type_;
|
SamplerType sampler_type_;
|
||||||
|
std::shared_ptr<ShardShuffle> shuffle_op_;
|
||||||
};
|
};
|
||||||
} // namespace mindrecord
|
} // namespace mindrecord
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -28,7 +28,7 @@ class ShardShuffle : public ShardOperator {
|
||||||
|
|
||||||
~ShardShuffle() override{};
|
~ShardShuffle() override{};
|
||||||
|
|
||||||
MSRStatus operator()(ShardTask &tasks) override;
|
MSRStatus execute(ShardTask &tasks) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
uint32_t shuffle_seed_;
|
uint32_t shuffle_seed_;
|
||||||
|
|
|
@ -779,8 +779,12 @@ MSRStatus ShardReader::Launch(bool isSimpleReader) {
|
||||||
|
|
||||||
// Sort row group by (group_id, shard_id), prepare for parallel reading
|
// Sort row group by (group_id, shard_id), prepare for parallel reading
|
||||||
std::sort(row_group_summary.begin(), row_group_summary.end(), ResortRowGroups);
|
std::sort(row_group_summary.begin(), row_group_summary.end(), ResortRowGroups);
|
||||||
CreateTasks(row_group_summary, operators_);
|
if (CreateTasks(row_group_summary, operators_) != SUCCESS) {
|
||||||
MS_LOG(INFO) << "Launching read threads";
|
MS_LOG(ERROR) << "Failed to launch read threads.";
|
||||||
|
interrupt_ = true;
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "Launching read threads.";
|
||||||
|
|
||||||
if (isSimpleReader) return SUCCESS;
|
if (isSimpleReader) return SUCCESS;
|
||||||
|
|
||||||
|
@ -1152,6 +1156,9 @@ std::vector<std::tuple<std::vector<uint8_t>, json>> ShardReader::GetBlockNext()
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::tuple<std::vector<uint8_t>, json>> ShardReader::GetNext() {
|
std::vector<std::tuple<std::vector<uint8_t>, json>> ShardReader::GetNext() {
|
||||||
|
if (interrupt_) {
|
||||||
|
return std::vector<std::tuple<std::vector<uint8_t>, json>>();
|
||||||
|
}
|
||||||
if (block_reader_) return GetBlockNext();
|
if (block_reader_) return GetBlockNext();
|
||||||
if (deliver_id_ >= static_cast<int>(tasks_.Size())) {
|
if (deliver_id_ >= static_cast<int>(tasks_.Size())) {
|
||||||
return std::vector<std::tuple<std::vector<uint8_t>, json>>();
|
return std::vector<std::tuple<std::vector<uint8_t>, json>>();
|
||||||
|
|
|
@ -23,6 +23,6 @@ ShardCategory::ShardCategory(const std::vector<std::pair<std::string, std::strin
|
||||||
|
|
||||||
const std::vector<std::pair<std::string, std::string>> &ShardCategory::get_categories() const { return categories_; }
|
const std::vector<std::pair<std::string, std::string>> &ShardCategory::get_categories() const { return categories_; }
|
||||||
|
|
||||||
MSRStatus ShardCategory::operator()(ShardTask &tasks) { return SUCCESS; }
|
MSRStatus ShardCategory::execute(ShardTask &tasks) { return SUCCESS; }
|
||||||
} // namespace mindrecord
|
} // namespace mindrecord
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -46,13 +46,15 @@ ShardSample::ShardSample(int num, int den, int par)
|
||||||
indices_({}),
|
indices_({}),
|
||||||
sampler_type_(kCustomTopPercentSampler) {}
|
sampler_type_(kCustomTopPercentSampler) {}
|
||||||
|
|
||||||
ShardSample::ShardSample(const std::vector<int> &indices)
|
ShardSample::ShardSample(const std::vector<int64_t> &indices, uint32_t seed)
|
||||||
: numerator_(0),
|
: numerator_(0),
|
||||||
denominator_(0),
|
denominator_(0),
|
||||||
no_of_samples_(0),
|
no_of_samples_(0),
|
||||||
partition_id_(0),
|
partition_id_(0),
|
||||||
indices_(indices),
|
indices_(indices),
|
||||||
sampler_type_(kSubsetRandomSampler) {}
|
sampler_type_(kSubsetRandomSampler) {
|
||||||
|
shuffle_op_ = std::make_shared<ShardShuffle>(seed);
|
||||||
|
}
|
||||||
|
|
||||||
const std::pair<int, int> ShardSample::get_partitions() const {
|
const std::pair<int, int> ShardSample::get_partitions() const {
|
||||||
if (numerator_ == 1 && denominator_ > 1) {
|
if (numerator_ == 1 && denominator_ > 1) {
|
||||||
|
@ -61,7 +63,7 @@ const std::pair<int, int> ShardSample::get_partitions() const {
|
||||||
return std::pair<int, int>(-1, -1);
|
return std::pair<int, int>(-1, -1);
|
||||||
}
|
}
|
||||||
|
|
||||||
MSRStatus ShardSample::operator()(ShardTask &tasks) {
|
MSRStatus ShardSample::execute(ShardTask &tasks) {
|
||||||
int no_of_categories = static_cast<int>(tasks.categories);
|
int no_of_categories = static_cast<int>(tasks.categories);
|
||||||
int total_no = static_cast<int>(tasks.Size());
|
int total_no = static_cast<int>(tasks.Size());
|
||||||
|
|
||||||
|
@ -115,5 +117,14 @@ MSRStatus ShardSample::operator()(ShardTask &tasks) {
|
||||||
}
|
}
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MSRStatus ShardSample::suf_execute(ShardTask &tasks) {
|
||||||
|
if (sampler_type_ == kSubsetRandomSampler) {
|
||||||
|
if (SUCCESS != (*shuffle_op_)(tasks)) {
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
} // namespace mindrecord
|
} // namespace mindrecord
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -22,7 +22,7 @@ namespace mindspore {
|
||||||
namespace mindrecord {
|
namespace mindrecord {
|
||||||
ShardShuffle::ShardShuffle(uint32_t seed) : shuffle_seed_(seed) {}
|
ShardShuffle::ShardShuffle(uint32_t seed) : shuffle_seed_(seed) {}
|
||||||
|
|
||||||
MSRStatus ShardShuffle::operator()(ShardTask &tasks) {
|
MSRStatus ShardShuffle::execute(ShardTask &tasks) {
|
||||||
if (tasks.categories < 1) {
|
if (tasks.categories < 1) {
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
|
|
|
@ -1683,9 +1683,7 @@ class MindDataset(SourceDataset):
|
||||||
args["block_reader"] = self.block_reader
|
args["block_reader"] = self.block_reader
|
||||||
args["num_shards"] = self.num_shards
|
args["num_shards"] = self.num_shards
|
||||||
args["shard_id"] = self.shard_id
|
args["shard_id"] = self.shard_id
|
||||||
if self.sampler:
|
args["sampler"] = self.sampler
|
||||||
args["sampler_name"] = self.sampler.__class__.__name__
|
|
||||||
args["sampler_params"] = self.sampler.__dict__
|
|
||||||
return args
|
return args
|
||||||
|
|
||||||
def get_dataset_size(self):
|
def get_dataset_size(self):
|
||||||
|
|
|
@ -195,6 +195,8 @@ class SubsetRandomSampler():
|
||||||
def create(self):
|
def create(self):
|
||||||
return cde.SubsetRandomSampler(self.indices)
|
return cde.SubsetRandomSampler(self.indices)
|
||||||
|
|
||||||
|
def _create_for_minddataset(self):
|
||||||
|
return cde.MindrecordSubsetRandomSampler(self.indices)
|
||||||
|
|
||||||
class WeightedRandomSampler():
|
class WeightedRandomSampler():
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -30,9 +30,9 @@
|
||||||
#include "mindrecord/include/shard_shuffle.h"
|
#include "mindrecord/include/shard_shuffle.h"
|
||||||
#include "ut_common.h"
|
#include "ut_common.h"
|
||||||
|
|
||||||
using mindspore::MsLogLevel::INFO;
|
|
||||||
using mindspore::ExceptionType::NoExceptionType;
|
|
||||||
using mindspore::LogStream;
|
using mindspore::LogStream;
|
||||||
|
using mindspore::ExceptionType::NoExceptionType;
|
||||||
|
using mindspore::MsLogLevel::INFO;
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace mindrecord {
|
namespace mindrecord {
|
||||||
|
@ -65,31 +65,31 @@ TEST_F(TestShardOperator, TestShardSampleBasic) {
|
||||||
ASSERT_TRUE(i <= kSampleCount);
|
ASSERT_TRUE(i <= kSampleCount);
|
||||||
}
|
}
|
||||||
|
|
||||||
// TEST_F(TestShardOperator, TestShardSampleWrongNumber) {
|
TEST_F(TestShardOperator, TestShardSampleWrongNumber) {
|
||||||
// MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test read imageNet"));
|
MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test read imageNet"));
|
||||||
//
|
|
||||||
// std::string file_name = "./imagenet.shard01";
|
std::string file_name = "./imagenet.shard01";
|
||||||
// auto column_list = std::vector<std::string>{"file_name"};
|
auto column_list = std::vector<std::string>{"file_name"};
|
||||||
//
|
|
||||||
// const int kNum = 5;
|
const int kNum = 5;
|
||||||
// const int kDen = 0;
|
const int kDen = 0;
|
||||||
// std::vector<std::shared_ptr<ShardOperator>> ops;
|
std::vector<std::shared_ptr<ShardOperator>> ops;
|
||||||
// ops.push_back(std::make_shared<ShardSample>(kNum, kDen));
|
ops.push_back(std::make_shared<ShardSample>(kNum, kDen));
|
||||||
//
|
|
||||||
// ShardReader dataset;
|
ShardReader dataset;
|
||||||
// dataset.Open(file_name, 4, column_list, ops);
|
dataset.Open(file_name, 4, column_list, ops);
|
||||||
// dataset.Launch();
|
dataset.Launch();
|
||||||
//
|
|
||||||
// int i = 0;
|
int i = 0;
|
||||||
// while (true) {
|
while (true) {
|
||||||
// auto x = dataset.GetNext();
|
auto x = dataset.GetNext();
|
||||||
// if (x.empty()) break;
|
if (x.empty()) break;
|
||||||
// MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]);
|
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]);
|
||||||
// i++;
|
i++;
|
||||||
// }
|
}
|
||||||
// dataset.Finish();
|
dataset.Finish();
|
||||||
// ASSERT_TRUE(i <= 5);
|
ASSERT_TRUE(i <= 5);
|
||||||
// }
|
}
|
||||||
|
|
||||||
TEST_F(TestShardOperator, TestShardSampleRatio) {
|
TEST_F(TestShardOperator, TestShardSampleRatio) {
|
||||||
MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test read imageNet"));
|
MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test read imageNet"));
|
||||||
|
@ -117,7 +117,6 @@ TEST_F(TestShardOperator, TestShardSampleRatio) {
|
||||||
ASSERT_TRUE(i <= 10);
|
ASSERT_TRUE(i <= 10);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
TEST_F(TestShardOperator, TestShardSamplePartition) {
|
TEST_F(TestShardOperator, TestShardSamplePartition) {
|
||||||
MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test read imageNet"));
|
MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test read imageNet"));
|
||||||
std::string file_name = "./imagenet.shard01";
|
std::string file_name = "./imagenet.shard01";
|
||||||
|
@ -170,8 +169,8 @@ TEST_F(TestShardOperator, TestShardCategory) {
|
||||||
auto x = dataset.GetNext();
|
auto x = dataset.GetNext();
|
||||||
if (x.empty()) break;
|
if (x.empty()) break;
|
||||||
|
|
||||||
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) <<
|
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"])
|
||||||
", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
|
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
|
||||||
i++;
|
i++;
|
||||||
|
|
||||||
ASSERT_TRUE((std::get<1>(x[0]))["label"] == categories[category_no].second);
|
ASSERT_TRUE((std::get<1>(x[0]))["label"] == categories[category_no].second);
|
||||||
|
@ -199,8 +198,8 @@ TEST_F(TestShardOperator, TestShardShuffle) {
|
||||||
while (true) {
|
while (true) {
|
||||||
auto x = dataset.GetNext();
|
auto x = dataset.GetNext();
|
||||||
if (x.empty()) break;
|
if (x.empty()) break;
|
||||||
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) <<
|
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"])
|
||||||
", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
|
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
|
||||||
i++;
|
i++;
|
||||||
}
|
}
|
||||||
dataset.Finish();
|
dataset.Finish();
|
||||||
|
@ -224,8 +223,8 @@ TEST_F(TestShardOperator, TestShardSampleShuffle) {
|
||||||
while (true) {
|
while (true) {
|
||||||
auto x = dataset.GetNext();
|
auto x = dataset.GetNext();
|
||||||
if (x.empty()) break;
|
if (x.empty()) break;
|
||||||
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) <<
|
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"])
|
||||||
", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
|
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
|
||||||
i++;
|
i++;
|
||||||
}
|
}
|
||||||
dataset.Finish();
|
dataset.Finish();
|
||||||
|
@ -251,8 +250,8 @@ TEST_F(TestShardOperator, TestShardShuffleSample) {
|
||||||
while (true) {
|
while (true) {
|
||||||
auto x = dataset.GetNext();
|
auto x = dataset.GetNext();
|
||||||
if (x.empty()) break;
|
if (x.empty()) break;
|
||||||
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) <<
|
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"])
|
||||||
", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
|
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
|
||||||
i++;
|
i++;
|
||||||
}
|
}
|
||||||
dataset.Finish();
|
dataset.Finish();
|
||||||
|
@ -278,8 +277,8 @@ TEST_F(TestShardOperator, TestShardSampleShuffleSample) {
|
||||||
while (true) {
|
while (true) {
|
||||||
auto x = dataset.GetNext();
|
auto x = dataset.GetNext();
|
||||||
if (x.empty()) break;
|
if (x.empty()) break;
|
||||||
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) <<
|
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"])
|
||||||
", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
|
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
|
||||||
i++;
|
i++;
|
||||||
}
|
}
|
||||||
dataset.Finish();
|
dataset.Finish();
|
||||||
|
@ -307,8 +306,8 @@ TEST_F(TestShardOperator, TestShardShuffleCompare) {
|
||||||
while (true) {
|
while (true) {
|
||||||
auto x = dataset.GetNext();
|
auto x = dataset.GetNext();
|
||||||
if (x.empty()) break;
|
if (x.empty()) break;
|
||||||
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) <<
|
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"])
|
||||||
", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
|
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
|
||||||
i++;
|
i++;
|
||||||
|
|
||||||
auto y = compare_dataset.GetNext();
|
auto y = compare_dataset.GetNext();
|
||||||
|
@ -342,8 +341,8 @@ TEST_F(TestShardOperator, TestShardCategoryShuffle1) {
|
||||||
while (true) {
|
while (true) {
|
||||||
auto x = dataset.GetNext();
|
auto x = dataset.GetNext();
|
||||||
if (x.empty()) break;
|
if (x.empty()) break;
|
||||||
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) <<
|
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"])
|
||||||
", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
|
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
|
||||||
i++;
|
i++;
|
||||||
|
|
||||||
ASSERT_TRUE((std::get<1>(x[0]))["label"] == categories[category_no].second);
|
ASSERT_TRUE((std::get<1>(x[0]))["label"] == categories[category_no].second);
|
||||||
|
@ -376,8 +375,8 @@ TEST_F(TestShardOperator, TestShardCategoryShuffle2) {
|
||||||
while (true) {
|
while (true) {
|
||||||
auto x = dataset.GetNext();
|
auto x = dataset.GetNext();
|
||||||
if (x.empty()) break;
|
if (x.empty()) break;
|
||||||
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) <<
|
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"])
|
||||||
", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
|
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
|
||||||
i++;
|
i++;
|
||||||
ASSERT_TRUE((std::get<1>(x[0]))["label"] == categories[category_no].second);
|
ASSERT_TRUE((std::get<1>(x[0]))["label"] == categories[category_no].second);
|
||||||
category_no++;
|
category_no++;
|
||||||
|
@ -410,8 +409,8 @@ TEST_F(TestShardOperator, TestShardCategorySample) {
|
||||||
while (true) {
|
while (true) {
|
||||||
auto x = dataset.GetNext();
|
auto x = dataset.GetNext();
|
||||||
if (x.empty()) break;
|
if (x.empty()) break;
|
||||||
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) <<
|
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"])
|
||||||
", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
|
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
|
||||||
i++;
|
i++;
|
||||||
|
|
||||||
ASSERT_TRUE((std::get<1>(x[0]))["label"] == categories[category_no].second);
|
ASSERT_TRUE((std::get<1>(x[0]))["label"] == categories[category_no].second);
|
||||||
|
@ -448,8 +447,8 @@ TEST_F(TestShardOperator, TestShardCategorySampleShuffle) {
|
||||||
while (true) {
|
while (true) {
|
||||||
auto x = dataset.GetNext();
|
auto x = dataset.GetNext();
|
||||||
if (x.empty()) break;
|
if (x.empty()) break;
|
||||||
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) <<
|
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"])
|
||||||
", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
|
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
|
||||||
i++;
|
i++;
|
||||||
|
|
||||||
ASSERT_TRUE((std::get<1>(x[0]))["label"] == categories[category_no].second);
|
ASSERT_TRUE((std::get<1>(x[0]))["label"] == categories[category_no].second);
|
||||||
|
|
|
@ -81,8 +81,6 @@ def test_cv_minddataset_subset_random_sample_basic(add_and_remove_cv_file):
|
||||||
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
||||||
logger.info(
|
logger.info(
|
||||||
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||||
assert data[indices[num_iter]]['file_name'] == "".join(
|
|
||||||
[chr(x) for x in item['file_name']])
|
|
||||||
num_iter += 1
|
num_iter += 1
|
||||||
assert num_iter == 5
|
assert num_iter == 5
|
||||||
|
|
||||||
|
@ -107,8 +105,6 @@ def test_cv_minddataset_subset_random_sample_replica(add_and_remove_cv_file):
|
||||||
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
||||||
logger.info(
|
logger.info(
|
||||||
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||||
assert data[indices[num_iter]]['file_name'] == "".join(
|
|
||||||
[chr(x) for x in item['file_name']])
|
|
||||||
num_iter += 1
|
num_iter += 1
|
||||||
assert num_iter == 6
|
assert num_iter == 6
|
||||||
|
|
||||||
|
@ -133,8 +129,6 @@ def test_cv_minddataset_subset_random_sample_empty(add_and_remove_cv_file):
|
||||||
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
||||||
logger.info(
|
logger.info(
|
||||||
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||||
assert data[indices[num_iter]]['file_name'] == "".join(
|
|
||||||
[chr(x) for x in item['file_name']])
|
|
||||||
num_iter += 1
|
num_iter += 1
|
||||||
assert num_iter == 0
|
assert num_iter == 0
|
||||||
|
|
||||||
|
@ -159,8 +153,6 @@ def test_cv_minddataset_subset_random_sample_out_range(add_and_remove_cv_file):
|
||||||
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
||||||
logger.info(
|
logger.info(
|
||||||
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||||
assert data[indices[num_iter] % len(data)]['file_name'] == "".join([
|
|
||||||
chr(x) for x in item['file_name']])
|
|
||||||
num_iter += 1
|
num_iter += 1
|
||||||
assert num_iter == 5
|
assert num_iter == 5
|
||||||
|
|
||||||
|
@ -185,8 +177,6 @@ def test_cv_minddataset_subset_random_sample_negative(add_and_remove_cv_file):
|
||||||
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
|
||||||
logger.info(
|
logger.info(
|
||||||
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||||
assert data[indices[num_iter] % len(data)]['file_name'] == "".join([
|
|
||||||
chr(x) for x in item['file_name']])
|
|
||||||
num_iter += 1
|
num_iter += 1
|
||||||
assert num_iter == 5
|
assert num_iter == 5
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue