forked from mindspore-Ecosystem/mindspore
Fixed repeat issue with GetDatasetSize
This commit is contained in:
parent
cacebd1211
commit
5005b434e6
|
@ -194,7 +194,7 @@ Status RepeatOp::Accept(NodePass *p, bool *modified) {
|
|||
|
||||
// Get Dataset size
|
||||
Status RepeatOp::GetDatasetSize(int64_t *dataset_size) {
|
||||
if (dataset_size_ > 0 || num_repeats_ == -1) {
|
||||
if (dataset_size_ > 0) {
|
||||
*dataset_size = dataset_size_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -481,10 +481,13 @@ Status ManifestOp::GetNumClasses(int64_t *num_classes) {
|
|||
*num_classes = num_classes_;
|
||||
return Status::OK();
|
||||
}
|
||||
int64_t classes_count;
|
||||
std::shared_ptr<ManifestOp> op;
|
||||
RETURN_IF_NOT_OK(Builder().SetManifestFile(file_).SetClassIndex(class_index_).SetUsage(usage_).Build(&op));
|
||||
RETURN_IF_NOT_OK(op->ParseManifestFile());
|
||||
*num_classes = num_classes_;
|
||||
classes_count = static_cast<int64_t>(op->label_index_.size());
|
||||
*num_classes = classes_count;
|
||||
num_classes_ = classes_count;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -79,8 +79,6 @@ TEST_F(MindDataTestPipeline, TestAlbumgetters) {
|
|||
std::shared_ptr<Dataset> ds = Album(folder_path, schema_file, column_names);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
int64_t dataset_size = ds->GetDatasetSize();
|
||||
EXPECT_EQ(dataset_size, 7);
|
||||
int64_t num_classes = ds->GetNumClasses();
|
||||
EXPECT_EQ(num_classes, -1);
|
||||
int64_t batch_size = ds->GetBatchSize();
|
||||
|
|
|
@ -99,6 +99,9 @@ TEST_F(MindDataTestPipeline, TestVOCGetDatasetSize) {
|
|||
std::shared_ptr<Dataset> ds = VOC(folder_path, "Detection", "train", class_index, false, SequentialSampler(0, 6));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
ds = ds->Batch(2);
|
||||
ds = ds->Repeat(2);
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 6);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue