Fixed repeat issue with GetDatasetSize

This commit is contained in:
Mahdi 2020-10-30 17:42:45 -04:00
parent cacebd1211
commit 5005b434e6
4 changed files with 9 additions and 5 deletions

View File

@ -194,7 +194,7 @@ Status RepeatOp::Accept(NodePass *p, bool *modified) {
// Get Dataset size
Status RepeatOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0 || num_repeats_ == -1) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}

View File

@ -481,10 +481,13 @@ Status ManifestOp::GetNumClasses(int64_t *num_classes) {
*num_classes = num_classes_;
return Status::OK();
}
int64_t classes_count;
std::shared_ptr<ManifestOp> op;
RETURN_IF_NOT_OK(Builder().SetManifestFile(file_).SetClassIndex(class_index_).SetUsage(usage_).Build(&op));
RETURN_IF_NOT_OK(op->ParseManifestFile());
*num_classes = num_classes_;
classes_count = static_cast<int64_t>(op->label_index_.size());
*num_classes = classes_count;
num_classes_ = classes_count;
return Status::OK();
}

View File

@ -79,8 +79,6 @@ TEST_F(MindDataTestPipeline, TestAlbumgetters) {
std::shared_ptr<Dataset> ds = Album(folder_path, schema_file, column_names);
EXPECT_NE(ds, nullptr);
int64_t dataset_size = ds->GetDatasetSize();
EXPECT_EQ(dataset_size, 7);
int64_t num_classes = ds->GetNumClasses();
EXPECT_EQ(num_classes, -1);
int64_t batch_size = ds->GetBatchSize();

View File

@ -99,6 +99,9 @@ TEST_F(MindDataTestPipeline, TestVOCGetDatasetSize) {
std::shared_ptr<Dataset> ds = VOC(folder_path, "Detection", "train", class_index, false, SequentialSampler(0, 6));
EXPECT_NE(ds, nullptr);
ds = ds->Batch(2);
ds = ds->Repeat(2);
EXPECT_EQ(ds->GetDatasetSize(), 6);
}