forked from mindspore-Ecosystem/mindspore
Fixed GetDatasetSize for TextFile
This commit is contained in:
parent
fedb225a96
commit
449e1526dc
|
@ -529,7 +529,7 @@ Status TextFileOp::GetDatasetSize(int64_t *dataset_size) {
|
|||
int64_t num_rows, sample_size;
|
||||
sample_size = total_rows_;
|
||||
if (num_rows_per_shard_ <= 0) RETURN_IF_NOT_OK(CalculateNumRowsPerShard());
|
||||
num_rows = total_rows_;
|
||||
num_rows = num_rows_per_shard_;
|
||||
*dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
|
||||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
|
|
|
@ -105,6 +105,10 @@ TEST_F(MindDataTestPipeline, TestTextFileGetters) {
|
|||
EXPECT_EQ(ds->GetDatasetSize(), 2);
|
||||
EXPECT_EQ(ds->GetColumnNames(), column_names);
|
||||
|
||||
ds = TextFile({tf_file1}, 0);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 3);
|
||||
// Restore configuration
|
||||
GlobalContext::config_manager()->set_seed(original_seed);
|
||||
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
|
||||
|
|
Loading…
Reference in New Issue