!13198 Fix get_dataset_size error with PKSampler

From: @ziruiwu
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-03-12 06:11:26 +08:00 committed by Gitee
commit f7040ca134
15 changed files with 72 additions and 4 deletions

View File

@ -66,6 +66,11 @@ class PKSamplerRT : public SamplerRT { // NOT YET FINISHED
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
/// \brief Report the sample count for PK sampling. An exact value cannot be given because num_classes is not
///     known until runtime, so -1 is returned instead.
/// \param[in] num_rows Row count supplied by the caller; not used by this override
/// \return -1, which means PKSampler doesn't know how much data there is
int64_t CalculateNumSamples(int64_t num_rows) override {
  constexpr int64_t kUnknownNumSamples = -1;
  return kUnknownNumSamples;
}
private:
bool shuffle_;
uint32_t seed_;

View File

@ -140,6 +140,8 @@ int64_t SamplerRT::CalculateNumSamples(int64_t num_rows) {
int64_t child_num_rows = num_rows;
if (!child_.empty()) {
child_num_rows = child_[0]->CalculateNumSamples(num_rows);
// return -1 if child_num_rows is undetermined
if (child_num_rows == -1) return child_num_rows;
}
return (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows;

View File

@ -108,7 +108,7 @@ class SamplerRT {
// Calculate num samples. Unlike GetNumSamples, it is not a getter and doesn't necessarily return the value of
// num_samples_
// @param num_rows - number of rows the calculation is based on
// @return number of samples, or -1 if the sampler cannot determine this value (e.g. PKSampler)
virtual int64_t CalculateNumSamples(int64_t num_rows);
// setter for the number of records in the dataset

View File

@ -109,6 +109,8 @@ int64_t SequentialSamplerRT::CalculateNumSamples(int64_t num_rows) {
int64_t child_num_rows = num_rows;
if (!child_.empty()) {
child_num_rows = child_[0]->CalculateNumSamples(num_rows);
// return -1 if child_num_rows is undetermined
if (child_num_rows == -1) return child_num_rows;
}
int64_t num_samples = (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows;
// For this sampler we need to take start_index into account. Because for example in the case we are given n rows

View File

@ -139,6 +139,8 @@ int64_t SubsetSamplerRT::CalculateNumSamples(int64_t num_rows) {
int64_t child_num_rows = num_rows;
if (!child_.empty()) {
child_num_rows = child_[0]->CalculateNumSamples(num_rows);
// return -1 if child_num_rows is undetermined
if (child_num_rows == -1) return child_num_rows;
}
int64_t res = (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows;
res = std::min(res, static_cast<int64_t>(indices_.size()));

View File

@ -144,6 +144,9 @@ Status CelebANode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
sample_size = sampler_rt->CalculateNumSamples(num_rows);
if (sample_size == -1) {
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
}
*dataset_size = sample_size;
return Status::OK();
}

View File

@ -95,7 +95,9 @@ Status Cifar100Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &si
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
sample_size = sampler_rt->CalculateNumSamples(num_rows);
if (sample_size == -1) {
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
}
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();

View File

@ -88,12 +88,17 @@ Status Cifar10Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &siz
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size;
RETURN_IF_NOT_OK(CifarOp::CountTotalRows(dataset_dir_, usage_, true, &num_rows));
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
sample_size = sampler_rt->CalculateNumSamples(num_rows);
if (sample_size == -1) {
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
}
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();

View File

@ -151,6 +151,9 @@ Status CocoNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_g
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
sample_size = sampler_rt->CalculateNumSamples(num_rows);
if (sample_size == -1) {
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
}
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();

View File

@ -100,6 +100,9 @@ Status ImageFolderNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter>
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
sample_size = sampler_rt->CalculateNumSamples(num_rows);
if (sample_size == -1) {
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
}
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();

View File

@ -123,6 +123,9 @@ Status ManifestNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &si
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
sample_size = sampler_rt->CalculateNumSamples(num_rows);
if (sample_size == -1) {
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
}
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();

View File

@ -88,6 +88,9 @@ Status MnistNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
sample_size = sampler_rt->CalculateNumSamples(num_rows);
if (sample_size == -1) {
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
}
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();

View File

@ -139,6 +139,9 @@ Status VOCNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_ge
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
sample_size = sampler_rt->CalculateNumSamples(num_rows);
if (sample_size == -1) {
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
}
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();

View File

@ -36,7 +36,7 @@ TEST_F(MindDataTestIrSampler, TestCalculateNumSamples) {
sampl = std::make_shared<PKSamplerObj>(3, false, 0);
EXPECT_NE(sampl, nullptr);
sampl->SamplerBuild(&sampler_rt);
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 30);
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), -1);
sampl = std::make_shared<RandomSamplerObj>(false, 12);
EXPECT_NE(sampl, nullptr);
@ -98,7 +98,7 @@ TEST_F(MindDataTestIrSampler, TestCalculateNumSamples) {
std::shared_ptr<SamplerRT> sampler_rt6;
sampl6->SamplerBuild(&sampler_rt6);
sampler_rt6->AddChild(sampler_rt5);
EXPECT_EQ(sampler_rt6->CalculateNumSamples(num_rows), 7);
EXPECT_EQ(sampler_rt6->CalculateNumSamples(num_rows), -1);
}
TEST_F(MindDataTestIrSampler, TestSamplersMoveParameters) {

View File

@ -501,6 +501,35 @@ def test_cifar_exception_file_path():
assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
def test_cifar10_pk_sampler_get_dataset_size():
    """
    Check that get_dataset_size() on a Cifar10Dataset driven by PKSampler(3) matches the
    number of rows produced by one pass over the dataset (30 rows expected).
    """
    pk_sampler = ds.PKSampler(3)
    dataset = ds.Cifar10Dataset(DATA_DIR_10, sampler=pk_sampler)
    # Query the size first, then count the rows actually yielded by the iterator.
    reported_size = dataset.get_dataset_size()
    rows_seen = sum(1 for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True))
    assert reported_size == rows_seen == 30
def test_cifar10_with_chained_sampler_get_dataset_size():
    """
    Check get_dataset_size() on a Cifar10Dataset whose SequentialSampler (num_samples=5) has a
    PKSampler(4) attached as its child; both the reported size and the iterated row count
    are expected to be 5.
    """
    parent_sampler = ds.SequentialSampler(start_index=0, num_samples=5)
    parent_sampler.add_child(ds.PKSampler(4))
    dataset = ds.Cifar10Dataset(DATA_DIR_10, sampler=parent_sampler)
    # Query the size first, then count the rows actually yielded by the iterator.
    reported_size = dataset.get_dataset_size()
    rows_seen = sum(1 for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True))
    assert reported_size == rows_seen == 5
if __name__ == '__main__':
test_cifar10_content_check()
test_cifar10_basic()
@ -517,3 +546,6 @@ if __name__ == '__main__':
test_cifar_usage()
test_cifar_exception_file_path()
test_cifar10_with_chained_sampler_get_dataset_size()
test_cifar10_pk_sampler_get_dataset_size()