forked from mindspore-Ecosystem/mindspore
!13198 Fix get_dataset_size error with PKSampler
From: @ziruiwu Reviewed-by: Signed-off-by:
This commit is contained in:
commit
f7040ca134
|
@ -66,6 +66,11 @@ class PKSamplerRT : public SamplerRT { // NOT YET FINISHED
|
|||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
/// \brief PK cannot return an exact value because num_classes is not known until runtime, hence -1 is used
|
||||
/// \param[out] num_rows
|
||||
/// \return -1, which means PKSampler doesn't know how much data
|
||||
int64_t CalculateNumSamples(int64_t num_rows) override { return -1; }
|
||||
|
||||
private:
|
||||
bool shuffle_;
|
||||
uint32_t seed_;
|
||||
|
|
|
@ -140,6 +140,8 @@ int64_t SamplerRT::CalculateNumSamples(int64_t num_rows) {
|
|||
int64_t child_num_rows = num_rows;
|
||||
if (!child_.empty()) {
|
||||
child_num_rows = child_[0]->CalculateNumSamples(num_rows);
|
||||
// return -1 if child_num_rows is undetermined
|
||||
if (child_num_rows == -1) return child_num_rows;
|
||||
}
|
||||
|
||||
return (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows;
|
||||
|
|
|
@ -108,7 +108,7 @@ class SamplerRT {
|
|||
|
||||
// Calculate num samples. Unlike GetNumSamples, it is not a getter and doesn't necessarily return the value of
|
||||
// num_samples_
|
||||
// @return number of samples
|
||||
// @return number of samples, return -1 if sampler cannot determine this value (e.g. PKSampler)
|
||||
virtual int64_t CalculateNumSamples(int64_t num_rows);
|
||||
|
||||
// setter for num or records in the dataset
|
||||
|
|
|
@ -109,6 +109,8 @@ int64_t SequentialSamplerRT::CalculateNumSamples(int64_t num_rows) {
|
|||
int64_t child_num_rows = num_rows;
|
||||
if (!child_.empty()) {
|
||||
child_num_rows = child_[0]->CalculateNumSamples(num_rows);
|
||||
// return -1 if child_num_rows is undetermined
|
||||
if (child_num_rows == -1) return child_num_rows;
|
||||
}
|
||||
int64_t num_samples = (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows;
|
||||
// For this sampler we need to take start_index into account. Because for example in the case we are given n rows
|
||||
|
|
|
@ -139,6 +139,8 @@ int64_t SubsetSamplerRT::CalculateNumSamples(int64_t num_rows) {
|
|||
int64_t child_num_rows = num_rows;
|
||||
if (!child_.empty()) {
|
||||
child_num_rows = child_[0]->CalculateNumSamples(num_rows);
|
||||
// return -1 if child_num_rows is undetermined
|
||||
if (child_num_rows == -1) return child_num_rows;
|
||||
}
|
||||
int64_t res = (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows;
|
||||
res = std::min(res, static_cast<int64_t>(indices_.size()));
|
||||
|
|
|
@ -144,6 +144,9 @@ Status CelebANode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size
|
|||
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
|
||||
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
|
||||
sample_size = sampler_rt->CalculateNumSamples(num_rows);
|
||||
if (sample_size == -1) {
|
||||
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
|
||||
}
|
||||
*dataset_size = sample_size;
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -95,7 +95,9 @@ Status Cifar100Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &si
|
|||
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
|
||||
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
|
||||
sample_size = sampler_rt->CalculateNumSamples(num_rows);
|
||||
|
||||
if (sample_size == -1) {
|
||||
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
|
||||
}
|
||||
*dataset_size = sample_size;
|
||||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
|
|
|
@ -88,12 +88,17 @@ Status Cifar10Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &siz
|
|||
*dataset_size = dataset_size_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
int64_t num_rows, sample_size;
|
||||
RETURN_IF_NOT_OK(CifarOp::CountTotalRows(dataset_dir_, usage_, true, &num_rows));
|
||||
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
|
||||
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
|
||||
sample_size = sampler_rt->CalculateNumSamples(num_rows);
|
||||
|
||||
if (sample_size == -1) {
|
||||
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
|
||||
}
|
||||
|
||||
*dataset_size = sample_size;
|
||||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
|
|
|
@ -151,6 +151,9 @@ Status CocoNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_g
|
|||
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
|
||||
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
|
||||
sample_size = sampler_rt->CalculateNumSamples(num_rows);
|
||||
if (sample_size == -1) {
|
||||
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
|
||||
}
|
||||
*dataset_size = sample_size;
|
||||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
|
|
|
@ -100,6 +100,9 @@ Status ImageFolderNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter>
|
|||
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
|
||||
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
|
||||
sample_size = sampler_rt->CalculateNumSamples(num_rows);
|
||||
if (sample_size == -1) {
|
||||
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
|
||||
}
|
||||
*dataset_size = sample_size;
|
||||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
|
|
|
@ -123,6 +123,9 @@ Status ManifestNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &si
|
|||
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
|
||||
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
|
||||
sample_size = sampler_rt->CalculateNumSamples(num_rows);
|
||||
if (sample_size == -1) {
|
||||
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
|
||||
}
|
||||
*dataset_size = sample_size;
|
||||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
|
|
|
@ -88,6 +88,9 @@ Status MnistNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_
|
|||
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
|
||||
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
|
||||
sample_size = sampler_rt->CalculateNumSamples(num_rows);
|
||||
if (sample_size == -1) {
|
||||
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
|
||||
}
|
||||
*dataset_size = sample_size;
|
||||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
|
|
|
@ -139,6 +139,9 @@ Status VOCNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_ge
|
|||
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
|
||||
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
|
||||
sample_size = sampler_rt->CalculateNumSamples(num_rows);
|
||||
if (sample_size == -1) {
|
||||
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
|
||||
}
|
||||
*dataset_size = sample_size;
|
||||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
|
|
|
@ -36,7 +36,7 @@ TEST_F(MindDataTestIrSampler, TestCalculateNumSamples) {
|
|||
sampl = std::make_shared<PKSamplerObj>(3, false, 0);
|
||||
EXPECT_NE(sampl, nullptr);
|
||||
sampl->SamplerBuild(&sampler_rt);
|
||||
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 30);
|
||||
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), -1);
|
||||
|
||||
sampl = std::make_shared<RandomSamplerObj>(false, 12);
|
||||
EXPECT_NE(sampl, nullptr);
|
||||
|
@ -98,7 +98,7 @@ TEST_F(MindDataTestIrSampler, TestCalculateNumSamples) {
|
|||
std::shared_ptr<SamplerRT> sampler_rt6;
|
||||
sampl6->SamplerBuild(&sampler_rt6);
|
||||
sampler_rt6->AddChild(sampler_rt5);
|
||||
EXPECT_EQ(sampler_rt6->CalculateNumSamples(num_rows), 7);
|
||||
EXPECT_EQ(sampler_rt6->CalculateNumSamples(num_rows), -1);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestIrSampler, TestSamplersMoveParameters) {
|
||||
|
|
|
@ -501,6 +501,35 @@ def test_cifar_exception_file_path():
|
|||
assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
|
||||
|
||||
|
||||
def test_cifar10_pk_sampler_get_dataset_size():
|
||||
"""
|
||||
Test Cifar10Dataset with PKSampler and get_dataset_size
|
||||
"""
|
||||
sampler = ds.PKSampler(3)
|
||||
data = ds.Cifar10Dataset(DATA_DIR_10, sampler=sampler)
|
||||
num_iter = 0
|
||||
ds_sz = data.get_dataset_size()
|
||||
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
num_iter += 1
|
||||
|
||||
assert ds_sz == num_iter == 30
|
||||
|
||||
|
||||
def test_cifar10_with_chained_sampler_get_dataset_size():
|
||||
"""
|
||||
Test Cifar10Dataset with PKSampler chained with a SequentialSampler and get_dataset_size
|
||||
"""
|
||||
sampler = ds.SequentialSampler(start_index=0, num_samples=5)
|
||||
child_sampler = ds.PKSampler(4)
|
||||
sampler.add_child(child_sampler)
|
||||
data = ds.Cifar10Dataset(DATA_DIR_10, sampler=sampler)
|
||||
num_iter = 0
|
||||
ds_sz = data.get_dataset_size()
|
||||
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
num_iter += 1
|
||||
assert ds_sz == num_iter == 5
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_cifar10_content_check()
|
||||
test_cifar10_basic()
|
||||
|
@ -517,3 +546,6 @@ if __name__ == '__main__':
|
|||
|
||||
test_cifar_usage()
|
||||
test_cifar_exception_file_path()
|
||||
|
||||
test_cifar10_with_chained_sampler_get_dataset_size()
|
||||
test_cifar10_pk_sampler_get_dataset_size()
|
||||
|
|
Loading…
Reference in New Issue