diff --git a/mindspore/ccsrc/minddata/dataset/api/datasets.cc b/mindspore/ccsrc/minddata/dataset/api/datasets.cc index 84176651398..8ffbdd83b60 100644 --- a/mindspore/ccsrc/minddata/dataset/api/datasets.cc +++ b/mindspore/ccsrc/minddata/dataset/api/datasets.cc @@ -199,12 +199,13 @@ bool Dataset::Save(std::string dataset_path, int32_t num_files, std::string data // Constructor Dataset::Dataset() { tree_getters_ = std::make_shared(); } -int64_t Dataset::GetDatasetSize() { +int64_t Dataset::GetDatasetSize(bool estimate) { int64_t dataset_size; std::unique_ptr runtime_context = std::make_unique(); RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1); - RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), -1); - RETURN_SECOND_IF_ERROR(tree_getters_->GetDatasetSize(&dataset_size), -1); + std::shared_ptr size_getter = std::make_shared(); + RETURN_SECOND_IF_ERROR(size_getter->Init(this->IRNode()), -1); + RETURN_SECOND_IF_ERROR(size_getter->GetDatasetSize(&dataset_size, estimate), -1); return dataset_size; } diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/datasetops/source/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/datasetops/source/bindings.cc index e3de11d813f..8844cec53b6 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/datasetops/source/bindings.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/datasetops/source/bindings.cc @@ -106,19 +106,7 @@ PYBIND_REGISTER(ImageFolderOp, 1, ([](const py::module *m) { })); PYBIND_REGISTER(ManifestOp, 1, ([](const py::module *m) { - (void)py::class_>(*m, "ManifestOp") - .def_static("get_num_rows_and_classes", - [](const std::string &file, const py::dict &dict, const std::string &usage) { - int64_t count = 0, num_classes = 0; - THROW_IF_ERROR(ManifestOp::CountTotalRows(file, dict, usage, &count, &num_classes)); - return py::make_tuple(count, num_classes); - }) - .def_static("get_class_indexing", 
[](const std::string &file, const py::dict &dict, - const std::string &usage) { - std::map output_class_indexing; - THROW_IF_ERROR(ManifestOp::GetClassIndexing(file, dict, usage, &output_class_indexing)); - return output_class_indexing; - }); + (void)py::class_>(*m, "ManifestOp"); })); PYBIND_REGISTER(MindRecordOp, 1, ([](const py::module *m) { (void)py::class_>(*m, "MindRecordOp") @@ -173,13 +161,6 @@ PYBIND_REGISTER(TFReaderOp, 1, ([](const py::module *m) { PYBIND_REGISTER(VOCOp, 1, ([](const py::module *m) { (void)py::class_>(*m, "VOCOp") - .def_static("get_num_rows", - [](const std::string &dir, const std::string &task_type, const std::string &task_mode, - const py::dict &dict, int64_t numSamples) { - int64_t count = 0; - THROW_IF_ERROR(VOCOp::CountTotalRows(dir, task_type, task_mode, dict, &count)); - return count; - }) .def_static("get_class_indexing", [](const std::string &dir, const std::string &task_type, const std::string &task_mode, const py::dict &dict) { std::map output_class_indexing; diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/include/datasets_bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/include/datasets_bindings.cc index c2bcb7d4264..996c7429d84 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/include/datasets_bindings.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/include/datasets_bindings.cc @@ -184,7 +184,11 @@ PYBIND_REGISTER(GeneratorNode, 2, ([](const py::module *m) { auto gen = std::make_shared(generator_function, schema); THROW_IF_ERROR(gen->ValidateParams()); return gen; - })); + })) + .def("SetGeneratorDatasetSize", [](std::shared_ptr self, int64_t sz) { + self->SetGeneratorDatasetSize(sz); + return self; + }); })); PYBIND_REGISTER(ImageFolderNode, 2, ([](const py::module *m) { diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/include/iterator_bindings.cc 
b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/include/iterator_bindings.cc index 6ab9ac6d05e..36424e4b25c 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/include/iterator_bindings.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/include/iterator_bindings.cc @@ -93,12 +93,6 @@ PYBIND_REGISTER(TreeGetters, 1, ([](const py::module *m) { THROW_IF_ERROR(self.GetClassIndexing(&output_class_indexing)); return output_class_indexing; }) - .def("GetDatasetSize", - [](PythonTreeGetters &self) { - int64_t dataset_size; - THROW_IF_ERROR(self.GetDatasetSize(&dataset_size)); - return dataset_size; - }) .def("__deepcopy__", [](py::object &tree_getter, py::dict memo) { return tree_getter; }); })); @@ -164,5 +158,18 @@ PYBIND_REGISTER(PythonSaveToDisk, 1, ([](const py::module *m) { .def("Save", [](PythonSaveToDisk &self) { THROW_IF_ERROR(self.Save()); }); })); +PYBIND_REGISTER(PythonDatasetSizeGetter, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "DatasetSizeGetters") + .def(py::init<>()) + .def("Init", [](PythonDatasetSizeGetter &self, + std::shared_ptr d) { THROW_IF_ERROR(self.Init(d)); }) + .def("GetDatasetSize", [](PythonDatasetSizeGetter &self, bool estimate) { + int64_t size; + THROW_IF_ERROR(self.GetDatasetSize(&size, estimate)); + return size; + }); + })); + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/consumers/python_tree_consumer.cc b/mindspore/ccsrc/minddata/dataset/engine/consumers/python_tree_consumer.cc index c728925dd0f..db36286df31 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/consumers/python_tree_consumer.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/consumers/python_tree_consumer.cc @@ -65,4 +65,8 @@ Status PythonTreeGetters::GetRow(TensorRow *r) { py::gil_scoped_release gil_release; return TreeGetters::GetRow(r); } +Status PythonDatasetSizeGetter::GetRow(const std::shared_ptr &tree_adapter, TensorRow *r) { + 
py::gil_scoped_release gil_release; + return DatasetSizeGetter::GetRow(tree_adapter, r); +} } // namespace mindspore::dataset diff --git a/mindspore/ccsrc/minddata/dataset/engine/consumers/python_tree_consumer.h b/mindspore/ccsrc/minddata/dataset/engine/consumers/python_tree_consumer.h index 94e5ad22346..b200fe06049 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/consumers/python_tree_consumer.h +++ b/mindspore/ccsrc/minddata/dataset/engine/consumers/python_tree_consumer.h @@ -60,5 +60,9 @@ class PythonTreeGetters : public TreeGetters { public: Status GetRow(TensorRow *r) override; }; +class PythonDatasetSizeGetter : public DatasetSizeGetter { + public: + Status GetRow(const std::shared_ptr &tree_adapter, TensorRow *r) override; +}; } // namespace mindspore::dataset #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMERS_PYTHON_TREE_CONSUMER_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc index 49e107a9809..02a79c7dbda 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc @@ -451,29 +451,6 @@ Status TreeGetters::Init(std::shared_ptr d) { Status TreeGetters::GetRow(TensorRow *row) { return tree_adapter_->GetNext(row); } -Status TreeGetters::GetDatasetSize(int64_t *dataset_size) { - if (dataset_size_ == -1) { - RETURN_IF_NOT_OK(InternalInit(static_cast(GetterPass::kDatasetSize))); - std::shared_ptr root = std::shared_ptr(tree_adapter_->GetRoot()); - RETURN_UNEXPECTED_IF_NULL(root); - RETURN_IF_NOT_OK(root->GetDatasetSize(dataset_size)); - if (*dataset_size == -1) { // run through the tree and get everything - TensorRow row; - RETURN_IF_NOT_OK(GetRow(&row)); - int64_t row_cnt = 0; - while (!row.empty()) { - ++row_cnt; - RETURN_IF_NOT_OK(GetRow(&row)); - } - *dataset_size = row_cnt; - } - dataset_size_ = *dataset_size; // save the previous result - } - - 
*dataset_size = dataset_size_; - return Status::OK(); -} - Status TreeGetters::GetOutputTypes(std::vector *types) { RETURN_IF_NOT_OK(GetFirstRowShapeAndType()); *types = first_row_type_; @@ -573,5 +550,46 @@ Status BuildVocabConsumer::Start() { CHECK_FAIL_RETURN_UNEXPECTED(row.empty(), "The fetched row from BuildVocab should be an EOE."); return Status::OK(); } +Status DatasetSizeGetter::GetDatasetSize(int64_t *size, bool estimate) { + if (dataset_size_ == -1) { + RETURN_IF_NOT_OK(root_->GetDatasetSize(shared_from_this(), estimate, size)); + dataset_size_ = *size; // save the previous result + } + + *size = dataset_size_; + return Status::OK(); +} +Status DatasetSizeGetter::Init(std::shared_ptr d) { + root_ = std::move(d); + return Status::OK(); +} +Status DatasetSizeGetter::DryRun(std::shared_ptr ir_node, int64_t *dataset_size) { + std::shared_ptr tree_adapter = std::make_shared(); + tree_adapters_.push_back(tree_adapter); + tree_adapter->SetPrePassOverride([](OptPass pre) { + pre.push_back( + std::make_unique(static_cast(GetterPass::GetterType::kDatasetSize))); + return pre; + }); + RETURN_IF_NOT_OK(tree_adapter->Compile(std::move(ir_node), 1)); + TensorRow row; + RETURN_IF_NOT_OK(GetRow(tree_adapter, &row)); + int64_t row_cnt = 0; + while (!row.empty()) { + ++row_cnt; + RETURN_IF_NOT_OK(GetRow(tree_adapter, &row)); + } + *dataset_size = row_cnt; + return Status::OK(); +} +Status DatasetSizeGetter::GetRow(const std::shared_ptr &tree_adapter, TensorRow *row) { + return tree_adapter->GetNext(row); +} +Status DatasetSizeGetter::Terminate() { + for (const auto &tree : tree_adapters_) { + RETURN_IF_NOT_OK(tree->AllTasks()->ServiceStop()); + } + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h index 9e770f2a458..c733eb2d419 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h 
+++ b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h @@ -177,7 +177,6 @@ class TreeGetters : public TreeConsumer { ~TreeGetters() = default; Status Init(std::shared_ptr d) override; - Status GetDatasetSize(int64_t *size); Status GetOutputTypes(std::vector *types); Status GetOutputShapes(std::vector *shapes); Status GetBatchSize(int64_t *batch_size); @@ -186,7 +185,7 @@ class TreeGetters : public TreeConsumer { Status GetColumnNames(std::vector *output); Status GetClassIndexing(std::vector>> *output_class_indexing); std::string Name() override { return "TreeGetters"; } - virtual Status GetRow(TensorRow *r); + virtual Status GetRow(TensorRow *row); private: Status GetFirstRowShapeAndType(); @@ -202,6 +201,35 @@ class TreeGetters : public TreeConsumer { Status InternalInit(); }; +/// Consumer that is used to get some pipeline information +class DatasetSizeGetter : public TreeConsumer, public std::enable_shared_from_this { + public: + DatasetSizeGetter() : dataset_size_(-1) {} + ~DatasetSizeGetter() = default; + Status Init(std::shared_ptr d) override; + Status Terminate() override; + + /// \brief Function to get the dataset size + /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting + /// dataset size at the expense of accuracy. 
+ /// \param[out] dataset_size the size of the dataset + /// \return Status of the function + Status GetDatasetSize(int64_t *size, bool estimate = false); + + virtual Status GetRow(const std::shared_ptr &tree_adapter, TensorRow *row); + std::string Name() override { return "DatasetSizeGetter"; } + + /// \brief Gets the dataset size by iterating over the entire dataset on a sub tree starting from ir_node + /// param[in] ir_node The node that marks the top most of the sub tree on which we want to iterate + /// \return Status - The status code return + Status DryRun(std::shared_ptr ir_node, int64_t *dataset_size); + + private: + std::shared_ptr root_; + std::vector> tree_adapters_; + int64_t dataset_size_; +}; + class BuildVocabConsumer : public TreeConsumer { public: /// BuildVocabConsumer Constructor which will call the base class default constructor. diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc index 5a016a0fcb9..60559c86409 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc @@ -537,30 +537,6 @@ Status BatchOp::ComputeColMap() { return Status::OK(); } -Status BatchOp::GetDatasetSize(int64_t *dataset_size) { - if (dataset_size_ > 0) { - *dataset_size = dataset_size_; - return Status::OK(); - } -#ifdef ENABLE_PYTHON - if (batch_size_func_) { - *dataset_size = -1; - return Status::OK(); - } -#endif - int64_t num_rows; - RETURN_IF_NOT_OK(child_[0]->GetDatasetSize(&num_rows)); - if (num_rows > 0 && start_batch_size_ > 0) { - if (drop_) { - num_rows = static_cast(floor(num_rows / (1.0 * start_batch_size_))); - } else { - num_rows = static_cast(ceil(num_rows / (1.0 * start_batch_size_))); - } - } - *dataset_size = num_rows; - dataset_size_ = num_rows; - return Status::OK(); -} int64_t BatchOp::GetTreeBatchSize() { #ifdef ENABLE_PYTHON if (batch_size_func_) { diff --git 
a/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.h index 94dd3e82022..ec25495bede 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.h @@ -225,11 +225,6 @@ class BatchOp : public ParallelOp { static Status PadColumns(std::unique_ptr *table, const PadInfo &pad_info, const std::unordered_map &column_name_id_map); - /// \brief Base-class override for GetDatasetSize - /// \param[out] dataset_size the size of the dataset - /// \return Status of the function - Status GetDatasetSize(int64_t *dataset_size) override; - int64_t GetTreeBatchSize() override; protected: diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.cc index a8e6793c0fc..d3057353f3a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.cc @@ -232,12 +232,5 @@ Status BucketBatchByLengthOp::ComputeColMap() { return Status::OK(); } -// Get Dataset size -Status BucketBatchByLengthOp::GetDatasetSize(int64_t *dataset_size) { - // We are returning -1 because we can't easily calculate GetDatasetSize. 
Returning -1 will make TreeGetters to - // iterate over the dataset and count the size - *dataset_size = dataset_size_; - return Status::OK(); -} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h index 6ac3700baeb..3fd446322b1 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h @@ -112,11 +112,6 @@ class BucketBatchByLengthOp : public PipelineOp { std::string Name() const override { return kBucketBatchByLengthOp; } - /// \brief Base-class override for GetDatasetSize - /// \param[out] dataset_size the size of the dataset - /// \return Status of the function - Status GetDatasetSize(int64_t *dataset_size) override; - // << Stream output operator overload // @notes This allows you to write the debug print info using stream operators // @param out - reference to the output stream being overloaded diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.cc index 380516d0e71..9cba1119aba 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.cc @@ -196,12 +196,5 @@ Status ConcatOp::PreAccept(NodePass *p, bool *modified) { return p->PreRunOnNode(shared_from_base(), modified); } -// Get Dataset size -Status ConcatOp::GetDatasetSize(int64_t *dataset_size) { - // We are returning -1 because we can't easily calculate GetDatasetSize. 
Returning -1 will make TreeGetters to - // iterate over the dataset and count the size - *dataset_size = dataset_size_; - return Status::OK(); -} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.h index e81484a0fac..90cf8d7a728 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.h @@ -111,11 +111,6 @@ class ConcatOp : public PipelineOp { /// \return Status of the node visit Status PreAccept(NodePass *p, bool *modified) override; - /// \brief Base-class override for GetDatasetSize - /// \param[out] dataset_size the size of the dataset - /// \return Status of the function - Status GetDatasetSize(int64_t *dataset_size) override; - private: Status Verify(int32_t id, const std::unique_ptr &buf); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc index 92727fb9067..402cd30888c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc @@ -294,24 +294,6 @@ Status DatasetOp::GetNextInput(std::unique_ptr *p_buffer, int32_t wo return Status::OK(); } -// Gets the dataset size -Status DatasetOp::GetDatasetSize(int64_t *dataset_size) { - if (dataset_size_ > 0) { - *dataset_size = dataset_size_; - return Status::OK(); - } - if (child_.size() == 1) { - return child_[0]->GetDatasetSize(dataset_size); - } else if (child_.size() > 1) { - // It is okay for dataset to have more than 1 child, GetDatasetSize shouldn't fail in this case. - // This is done mostly for cache, which injects cache lookup/merge operators. Cache path will - // always be in front of the child_ structure, so we get the dataset size from the last child. 
- return child_[child_.size() - 1]->GetDatasetSize(dataset_size); - } else { - RETURN_STATUS_UNEXPECTED("Trying to get dataset size from leaf node, missing override"); - } -} - // Gets the number of classes Status DatasetOp::GetNumClasses(int64_t *num_classes) { if (child_.size() == 1) { diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h index fa04d07e8d6..e77cee56ff3 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h @@ -180,10 +180,6 @@ class DatasetOp : public std::enable_shared_from_this { /// \return Status - The error code return Status GetNextInput(std::unique_ptr *p_buffer, int32_t worker_id = 0, int32_t child_index = 0); - /// \brief Gets the dataset size - /// \return Status - The status code return - virtual Status GetDatasetSize(int64_t *dataset_size); - /// \brief Gets the batch size /// \return Status - The status code return virtual int64_t GetTreeBatchSize(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.cc index e02eafffb8b..19e518e9937 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.cc @@ -258,13 +258,5 @@ Status FilterOp::PreAccept(NodePass *p, bool *modified) { return p->PreRunOnNode(shared_from_base(), modified); } -// Get Dataset size -Status FilterOp::GetDatasetSize(int64_t *dataset_size) { - // We are returning -1 because we can't easily calculate GetDatasetSize. 
Returning -1 will make TreeGetters to - // iterate over the dataset and count the size - *dataset_size = dataset_size_; - return Status::OK(); -} - } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.h index 5a19bbbd367..d1a884b45de 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.h @@ -137,11 +137,6 @@ class FilterOp : public ParallelOp { // @return Name of the current Op std::string Name() const override { return kFilterOp; } - /// \brief Base-class override for GetDatasetSize - /// \param[out] dataset_size the size of the dataset - /// \return Status of the function - Status GetDatasetSize(int64_t *dataset_size) override; - private: // predicate_func python callable which returns a boolean value. std::shared_ptr predicate_func_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc index 9aa5e301e54..aef81e5526b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc @@ -187,21 +187,6 @@ Status RepeatOp::Accept(NodePass *p, bool *modified) { return p->RunOnNode(shared_from_base(), modified); } -// Get Dataset size -Status RepeatOp::GetDatasetSize(int64_t *dataset_size) { - if (dataset_size_ > 0) { - *dataset_size = dataset_size_; - return Status::OK(); - } - int64_t num_rows; - RETURN_IF_NOT_OK(child_[0]->GetDatasetSize(&num_rows)); - if (num_rows > 0 && num_repeats_ > 0) { - num_rows = num_rows * num_repeats_; - } - *dataset_size = num_rows; - dataset_size_ = num_rows; - return Status::OK(); -} int64_t RepeatOp::GetTreeRepeatCount() { return num_repeats_; } } // namespace dataset } // namespace mindspore diff --git 
a/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h index 74339d725a2..19d4e432458 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h @@ -133,11 +133,6 @@ class RepeatOp : public PipelineOp { /// \@return Status - The error code return Status Reset() override; - /// \brief Base-class override for GetDatasetSize - /// \param[out] dataset_size the size of the dataset - /// \return Status of the function - Status GetDatasetSize(int64_t *dataset_size) override; - int64_t GetTreeRepeatCount() override; // \brief Adds an operator to the repeat ops list of tracked leaf/eoe nodes diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.cc index cd5b5296b8c..a1986c700ae 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.cc @@ -136,20 +136,5 @@ Status SkipOp::PreAccept(NodePass *p, bool *modified) { return p->PreRunOnNode(shared_from_base(), modified); } -// Get Dataset size -Status SkipOp::GetDatasetSize(int64_t *dataset_size) { - if (dataset_size_ > 0) { - *dataset_size = dataset_size_; - return Status::OK(); - } - int64_t num_rows; - RETURN_IF_NOT_OK(child_[0]->GetDatasetSize(&num_rows)); - *dataset_size = 0; - if (max_skips_ >= 0 && max_skips_ < num_rows) { - *dataset_size = num_rows - max_skips_; - } - dataset_size_ = *dataset_size; - return Status::OK(); -} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.h index ec67c1e4a5b..983502caa32 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.h @@ -86,11 +86,6 @@ class 
SkipOp : public PipelineOp { /// \return Status of the node visit Status PreAccept(NodePass *p, bool *modified) override; - /// \brief Base-class override for GetDatasetSize - /// \param[out] dataset_size the size of the dataset - /// \return Status of the function - Status GetDatasetSize(int64_t *dataset_size) override; - // Op name getter // @return Name of the current Op std::string Name() const override { return kSkipOp; } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc index 4bd915e86b6..b4b6e58cf73 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc @@ -452,63 +452,5 @@ Status CelebAOp::ComputeColMap() { return Status::OK(); } -// Get Dataset size -Status CelebAOp::GetDatasetSize(int64_t *dataset_size) { - int64_t num_rows, sample_size; - std::string line; - Path folder_path(folder_path_); - std::ifstream attr_file((folder_path / "list_attr_celeba.txt").toString()); - if (!attr_file.is_open()) { - std::string attr_file_name = (folder_path / "list_attr_celeba.txt").toString(); - RETURN_STATUS_UNEXPECTED("Invalid file, failed to open Celeba attr file: " + attr_file_name); - } - - std::string rows_num; - (void)getline(attr_file, rows_num); - try { - num_rows = static_cast(std::stoul(rows_num)); // First line is rows number in attr file - } catch (std::invalid_argument &e) { - RETURN_STATUS_UNEXPECTED( - "Invalid data, failed to convert rows_num from attr_file to unsigned long, invalid argument: " + rows_num); - } catch (std::out_of_range &e) { - RETURN_STATUS_UNEXPECTED( - "Invalid data, failed to convert rows_num from attr_file to unsigned long, out of range: " + rows_num); - } - if (usage_ != "all") { - int64_t partition_num = 0; - char usage_type; - if (usage_ == "train") { - usage_type = '0'; - } else { - if (usage_ == "valid") { - usage_type = 
'1'; - } else { - if (usage_ == "test") - usage_type = '2'; - else - RETURN_STATUS_UNEXPECTED("Invalid usage."); - } - } - if (!partition_file_.is_open()) { - partition_file_.open((folder_path / "list_eval_partition.txt").toString()); - } - if (partition_file_.is_open()) { - while (getline(partition_file_, line)) { - int start = line.find(' '); - if (line.at(start + 1) == usage_type) { - partition_num++; - } - } - } else { - std::string partition_file_name = "list_eval_partition.txt"; - RETURN_STATUS_UNEXPECTED("Invalid file, failed to open Celeba partition file: " + partition_file_name); - } - num_rows = std::min(num_rows, partition_num); - } - - sample_size = sampler_->CalculateNumSamples(num_rows); - *dataset_size = sample_size; - return Status::OK(); -} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.h index 2a9de5e493c..cfadac0387a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.h @@ -179,11 +179,6 @@ class CelebAOp : public ParallelOp, RandomAccessOp { // @return Name of the current Op std::string Name() const override { return "CelebAOp"; } - /// \brief Base-class override for GetDatasetSize - /// \param[out] dataset_size the size of the dataset - /// \return Status of the function - Status GetDatasetSize(int64_t *dataset_size) override; - private: // Called first when function is called // @return diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc index ca4b4b7f31f..aabf6da55ed 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc @@ -508,20 +508,5 @@ Status CifarOp::ComputeColMap() { return 
Status::OK(); } -// Get Dataset size -Status CifarOp::GetDatasetSize(int64_t *dataset_size) { - if (dataset_size_ > 0) { - *dataset_size = dataset_size_; - return Status::OK(); - } - int64_t num_rows, sample_size; - num_rows = num_rows_; - if (num_rows_ <= 0) - RETURN_IF_NOT_OK(CountTotalRows(folder_path_, usage_, cifar_type_ == CifarType::kCifar10, &num_rows)); - sample_size = sampler_->CalculateNumSamples(num_rows); - *dataset_size = sample_size; - dataset_size_ = *dataset_size; - return Status::OK(); -} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.h index 60ee4848f0e..450edddf0d0 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.h @@ -175,11 +175,6 @@ class CifarOp : public ParallelOp, public RandomAccessOp { // @return Name of the current Op std::string Name() const override { return "CifarOp"; } - /// \brief Base-class override for GetDatasetSize - /// \param[out] dataset_size the size of the dataset - /// \return Status of the function - Status GetDatasetSize(int64_t *dataset_size) override; - private: // Initialize Sampler, calls sampler->Init() within // @return Status - The error code return diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc index f3e9bf40d4d..bc2c1c67ba0 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc @@ -565,19 +565,5 @@ Status ClueOp::Accept(NodePass *p, bool *modified) { return p->RunOnNode(shared_from_base(), modified); } -// Get Dataset size -Status ClueOp::GetDatasetSize(int64_t *dataset_size) { - if (dataset_size_ > 0) { - *dataset_size = dataset_size_; - return 
Status::OK(); - } - int64_t num_rows, sample_size; - if (num_rows_per_shard_ <= 0) RETURN_IF_NOT_OK(CalculateNumRowsPerShard()); - sample_size = num_samples_; - num_rows = num_rows_per_shard_; - *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows; - dataset_size_ = *dataset_size; - return Status::OK(); -} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.h index 09b7c9e6a98..993b5491037 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.h @@ -197,11 +197,6 @@ class ClueOp : public ParallelOp { // @return - Status of the node visit. Status Accept(NodePass *p, bool *modified) override; - /// \brief Base-class override for GetDatasetSize - /// \param[out] dataset_size the size of the dataset - /// \return Status of the function - Status GetDatasetSize(int64_t *dataset_size) override; - private: // The entry point for when workers are launched. // @param worker_id - the id of the worker that is executing this function. 
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc index 9bc25f53dd4..0fe0adf64b7 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc @@ -681,39 +681,6 @@ Status CocoOp::ComputeColMap() { return Status::OK(); } -// Get Dataset size -Status CocoOp::GetDatasetSize(int64_t *dataset_size) { - if (dataset_size_ > 0) { - *dataset_size = dataset_size_; - return Status::OK(); - } - int64_t num_rows = 0, sample_size; - std::string task_type; - switch (task_type_) { - case TaskType::Detection: - task_type = "Detection"; - break; - case TaskType::Keypoint: - task_type = "Keypoint"; - break; - case TaskType::Panoptic: - task_type = "Panoptic"; - break; - case TaskType::Stuff: - task_type = "Stuff"; - break; - } - if (image_ids_.size() == 0) { - RETURN_IF_NOT_OK(CountTotalRows(image_folder_path_, annotation_path_, task_type, &num_rows)); - } else { - num_rows = image_ids_.size(); - } - sample_size = sampler_->CalculateNumSamples(num_rows); - *dataset_size = sample_size; - dataset_size_ = *dataset_size; - return Status::OK(); -} - Status CocoOp::GetClassIndexing(std::vector>> *output_class_indexing) { if ((*output_class_indexing).empty()) { if ((task_type_ != TaskType::Detection) && (task_type_ != TaskType::Panoptic)) { diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.h index bed009412e5..3ea33dcddc1 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.h @@ -213,11 +213,6 @@ class CocoOp : public ParallelOp, public RandomAccessOp { // @return Name of the current Op std::string Name() const override { return "CocoOp"; } - /// \brief Base-class override for GetDatasetSize - /// 
\param[out] dataset_size the size of the dataset - /// \return Status of the function - Status GetDatasetSize(int64_t *dataset_size) override; - /// \brief Gets the class indexing /// \return Status - The status code return Status GetClassIndexing(std::vector>> *output_class_indexing) override; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc index 333a374c3b2..c5bd29cb279 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc @@ -916,19 +916,5 @@ Status CsvOp::Accept(NodePass *p, bool *modified) { return p->RunOnNode(shared_from_base(), modified); } -// Get Dataset size -Status CsvOp::GetDatasetSize(int64_t *dataset_size) { - if (dataset_size_ > 0) { - *dataset_size = dataset_size_; - return Status::OK(); - } - int64_t num_rows, sample_size; - if (num_rows_per_shard_ <= 0) RETURN_IF_NOT_OK(CalculateNumRowsPerShard()); - sample_size = num_samples_; - num_rows = num_rows_per_shard_; - *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows; - dataset_size_ = *dataset_size; - return Status::OK(); -} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h index 154d027744c..a6075a196a0 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h @@ -318,11 +318,6 @@ class CsvOp : public ParallelOp { // @return - Status of the node visit. 
Status Accept(NodePass *p, bool *modified) override; - /// \brief Base-class override for GetDatasetSize - /// \param[out] dataset_size the size of the dataset - /// \return Status of the function - Status GetDatasetSize(int64_t *dataset_size) override; - private: // The entry point for when workers are launched. // @param worker_id - the id of the worker that is executing this function. diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc index 18734c333c6..4bf2205744b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc @@ -274,11 +274,5 @@ Status GeneratorOp::ComputeColMap() { } return Status::OK(); } -Status GeneratorOp::GetDatasetSize(int64_t *dataset_size) { // Get Dataset size - // We are returning -1 because we can't easily calculate GetDatasetSize. Returning -1 will make TreeGetters to - // iterate over the dataset and count the size - *dataset_size = dataset_size_; - return Status::OK(); -} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.h index 35d313733ac..ff451b0929c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.h @@ -136,8 +136,6 @@ class GeneratorOp : public PipelineOp { Status Init(); - Status GetDatasetSize(int64_t *dataset_size) override; - private: py::function generator_function_; std::vector column_names_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc index 975364e3dd9..437c1c2a5af 100644 --- 
a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc @@ -465,24 +465,6 @@ Status ImageFolderOp::ComputeColMap() { return Status::OK(); } -// Get Dataset size -Status ImageFolderOp::GetDatasetSize(int64_t *dataset_size) { - if (dataset_size_ > 0) { - *dataset_size = dataset_size_; - return Status::OK(); - } - int64_t sample_size, num_rows; - num_rows = num_rows_; - if (num_rows_ <= 0) { - // GetDatasetSize will not be impacted by class_index_ - RETURN_IF_NOT_OK(CountRowsAndClasses(folder_path_, extensions_, &num_rows, nullptr, {})); - } - sample_size = sampler_->CalculateNumSamples(num_rows); - *dataset_size = sample_size; - dataset_size_ = *dataset_size; - return Status::OK(); -} - // Get number of classes Status ImageFolderOp::GetNumClasses(int64_t *num_classes) { if (num_classes_ > 0) { diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.h index bd3e1d694c7..ff51cb46ebe 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.h @@ -217,11 +217,6 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp { // @return Name of the current Op std::string Name() const override { return "ImageFolderOp"; } - /// \brief Base-class override for GetDatasetSize - /// \param[out] dataset_size the size of the dataset - /// \return Status of the function - Status GetDatasetSize(int64_t *dataset_size) override; - /// \brief Base-class override for GetNumClasses /// \param[out] num_classes the number of classes /// \return Status of the function diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc index 2cf21e15553..722e9124d93 100644 --- 
a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc @@ -396,16 +396,9 @@ Status ManifestOp::CountDatasetInfo() { return Status::OK(); } -#ifdef ENABLE_PYTHON -Status ManifestOp::CountTotalRows(const std::string &file, const py::dict &dict, const std::string &usage, - int64_t *count, int64_t *numClasses) { +Status ManifestOp::CountTotalRows(const std::string &file, const std::map &map, + const std::string &usage, int64_t *count, int64_t *numClasses) { // the logic of counting the number of samples is copied from ParseManifestFile() - std::map map; - for (auto p : dict) { - (void)map.insert(std::pair(py::reinterpret_borrow(p.first), - py::reinterpret_borrow(p.second))); - } - std::shared_ptr op; *count = 0; RETURN_IF_NOT_OK(Builder().SetManifestFile(file).SetClassIndex(map).SetUsage(usage).Build(&op)); @@ -415,6 +408,7 @@ Status ManifestOp::CountTotalRows(const std::string &file, const py::dict &dict, return Status::OK(); } +#ifdef ENABLE_PYTHON Status ManifestOp::GetClassIndexing(const std::string &file, const py::dict &dict, const std::string &usage, std::map *output_class_indexing) { std::map input_class_indexing; @@ -459,23 +453,6 @@ Status ManifestOp::ComputeColMap() { return Status::OK(); } -// Get Dataset size -Status ManifestOp::GetDatasetSize(int64_t *dataset_size) { - if (dataset_size_ > 0) { - *dataset_size = dataset_size_; - return Status::OK(); - } - int64_t num_rows, sample_size; - std::shared_ptr op; - RETURN_IF_NOT_OK(Builder().SetManifestFile(file_).SetClassIndex(class_index_).SetUsage(usage_).Build(&op)); - RETURN_IF_NOT_OK(op->ParseManifestFile()); - num_rows = static_cast(op->image_labelname_.size()); - sample_size = sampler_->CalculateNumSamples(num_rows); - *dataset_size = sample_size; - dataset_size_ = *dataset_size; - return Status::OK(); -} - // Get number of classes Status ManifestOp::GetNumClasses(int64_t *num_classes) { if (num_classes_ > 
0) { diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.h index b762a586cc3..81be096d09a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.h @@ -164,10 +164,17 @@ class ManifestOp : public ParallelOp, public RandomAccessOp { // @param show_all void Print(std::ostream &out, bool show_all) const override; -#ifdef ENABLE_PYTHON - static Status CountTotalRows(const std::string &file, const py::dict &dict, const std::string &usage, int64_t *count, - int64_t *numClasses); + /// \brief Counts the total number of rows in Manifest + /// \param[in] file Dataset file path + /// \param[in] map Input map of class index + /// \param[in] usage Dataset usage + /// \param[out] count Number of rows counted + /// \param[out] numClasses Number of classes counted + /// \return Status of the function + static Status CountTotalRows(const std::string &file, const std::map &map, + const std::string &usage, int64_t *count, int64_t *numClasses); +#ifdef ENABLE_PYTHON // Get str-to-int mapping from label name to index static Status GetClassIndexing(const std::string &file, const py::dict &dict, const std::string &usage, std::map *output_class_indexing); @@ -183,11 +190,6 @@ class ManifestOp : public ParallelOp, public RandomAccessOp { // @return Name of the current Op std::string Name() const override { return "ManifestOp"; } - /// \brief Base-class override for GetDatasetSize - /// \param[out] dataset_size the size of the dataset - /// \return Status of the function - Status GetDatasetSize(int64_t *dataset_size) override; - /// \brief Base-class override for GetNumClasses /// \param[out] num_classes the number of classes /// \return Status of the function diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc
b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc index f696b9dbb8c..520d69c07b9 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc @@ -474,22 +474,5 @@ Status MindRecordOp::ComputeColMap() { return Status::OK(); } -// Get Dataset size -Status MindRecordOp::GetDatasetSize(int64_t *dataset_size) { - if (dataset_size_ > 0) { - *dataset_size = dataset_size_; - return Status::OK(); - } - int64_t num_rows = num_rows_; - if (num_rows_ <= 0) { - // The last operator is parent sampler - std::shared_ptr op = operators_.back(); - RETURN_IF_NOT_OK(CountTotalRows(dataset_file_, load_dataset_, op, &num_rows, num_padded_)); - } - *dataset_size = num_rows; - dataset_size_ = *dataset_size; - return Status::OK(); -} - } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.h index cba0d9f6254..dae29f5541a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.h @@ -212,11 +212,6 @@ class MindRecordOp : public ParallelOp { // @return Name of the current Op std::string Name() const override { return "MindRecordOp"; } - /// \brief Base-class override for GetDatasetSize - /// \param[out] dataset_size the size of the dataset - /// \return Status of the function - Status GetDatasetSize(int64_t *dataset_size) override; - private: Status GetBufferFromReader(std::unique_ptr *fetched_buffer, int64_t buffer_id, int32_t worker_id); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc index 45e99b6ac76..ee73e815063 100644 --- 
a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc @@ -471,19 +471,5 @@ Status MnistOp::ComputeColMap() { return Status::OK(); } -// Get Dataset size -Status MnistOp::GetDatasetSize(int64_t *dataset_size) { - if (dataset_size_ > 0) { - *dataset_size = dataset_size_; - return Status::OK(); - } - int64_t num_rows, sample_size; - num_rows = num_rows_; - if (num_rows_ <= 0) RETURN_IF_NOT_OK(CountTotalRows(folder_path_, usage_, &num_rows)); - sample_size = sampler_->CalculateNumSamples(num_rows); - *dataset_size = sample_size; - dataset_size_ = *dataset_size; - return Status::OK(); -} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.h index 2accd8eb8c6..bf19bc4767b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.h @@ -168,11 +168,6 @@ class MnistOp : public ParallelOp, public RandomAccessOp { // @return Name of the current Op std::string Name() const override { return "MnistOp"; } - /// \brief Base-class override for GetDatasetSize - /// \param[out] dataset_size the size of the dataset - /// \return Status of the function - Status GetDatasetSize(int64_t *dataset_size) override; - private: // Initialize Sampler, calls sampler->Init() within // @return Status - The error code return diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc index 99a67e355ba..d7801a62429 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc @@ -421,23 +421,5 @@ Status RandomDataOp::ComputeColMap() { return Status::OK(); 
} -// Get Dataset size -Status RandomDataOp::GetDatasetSize(int64_t *dataset_size) { - if (dataset_size_ > 0) { - *dataset_size = dataset_size_; - return Status::OK(); - } - int64_t num_rows; - num_rows = total_rows_ != 0 ? total_rows_ : data_schema_->num_rows(); - if (sampler_ != nullptr) { - int64_t sample_size; - sample_size = sampler_->CalculateNumSamples(num_rows); - *dataset_size = sample_size; - } else { - *dataset_size = num_rows; - } - dataset_size_ = *dataset_size; - return Status::OK(); -} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.h index b2ed68f5ad5..2b0980f3a7f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.h @@ -203,11 +203,6 @@ class RandomDataOp : public ParallelOp { // @return Name of the current Op std::string Name() const override { return "RandomDataOp"; } - /// \brief Base-class override for GetDatasetSize - /// \param[out] dataset_size the size of the dataset - /// \return Status of the function - Status GetDatasetSize(int64_t *dataset_size) override; - private: /** * The entry point code for when workers are launched diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc index 4a646aa0699..bd0d01f848a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc @@ -162,11 +162,11 @@ Status DistributedSamplerRT::ResetSampler() { } int64_t DistributedSamplerRT::CalculateNumSamples(int64_t num_rows) { - int64_t childs = num_rows; + int64_t child_num_rows = num_rows; if (!child_.empty()) { - 
childs = child_[0]->CalculateNumSamples(num_rows); + child_num_rows = child_[0]->CalculateNumSamples(num_rows); } - int64_t num_samples = (num_samples_ > 0) ? std::min(childs, num_samples_) : childs; + int64_t num_samples = (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows; return std::ceil(num_samples * 1.0 / num_devices_); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h index 288b7eeb553..8dbaada2365 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h @@ -63,6 +63,11 @@ class DistributedSamplerRT : public SamplerRT { int64_t GetDeviceNum() { return num_devices_; } + /// \brief Recursively calls this function on its children to get the actual number of samples on a tree of samplers + /// \note This is not a getter for num_samples_. For example, if num_samples_ is 0 or if it's larger than num_rows, /// then num_samples_ is not returned at all.
+ /// \param[in] num_rows The total number of rows in the dataset + /// \return int64_t Calculated number of samples int64_t CalculateNumSamples(int64_t num_rows) override; void Print(std::ostream &out, bool show_all) const override; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc index 797383f131f..818e485ba60 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc @@ -520,19 +520,5 @@ Status TextFileOp::Accept(NodePass *p, bool *modified) { return p->RunOnNode(shared_from_base(), modified); } -// Get Dataset size -Status TextFileOp::GetDatasetSize(int64_t *dataset_size) { - if (dataset_size_ > 0) { - *dataset_size = dataset_size_; - return Status::OK(); - } - int64_t num_rows, sample_size; - sample_size = total_rows_; - if (num_rows_per_shard_ <= 0) RETURN_IF_NOT_OK(CalculateNumRowsPerShard()); - num_rows = num_rows_per_shard_; - *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows; - dataset_size_ = *dataset_size; - return Status::OK(); -} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.h index 7084eae332b..3891ebc4159 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.h @@ -198,11 +198,6 @@ class TextFileOp : public ParallelOp { // @return - Status of the node visit. 
Status Accept(NodePass *p, bool *modified) override; - /// \brief Base-class override for GetDatasetSize - /// \param[out] dataset_size the size of the dataset - /// \return Status of the function - Status GetDatasetSize(int64_t *dataset_size) override; - private: // The entry point for when workers are launched. // @param worker_id - the id of the worker that is executing this function. diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc index ba7b123d9ef..ce89d176f5e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc @@ -1067,41 +1067,5 @@ Status TFReaderOp::PrepareNodePostAction() { return Status::OK(); } -// Get the file list of the specific shard ID -Status TFReaderOp::GetShardFileList(std::vector *shard_filenames) { - if (!shard_filenames->empty()) { - RETURN_STATUS_UNEXPECTED("The initial file list must be empty.\n"); - } - for (int index = 0; index < dataset_files_list_.size(); index++) { - if (index % num_devices_ == device_id_) { - shard_filenames->push_back(dataset_files_list_.at(index)); - } - } - return Status::OK(); -} - -// Get Dataset size -Status TFReaderOp::GetDatasetSize(int64_t *dataset_size) { - if (dataset_size_ > 0) { - *dataset_size = dataset_size_; - return Status::OK(); - } - int64_t num_rows, sample_size; - num_rows = num_rows_; - if (num_rows_ <= 0) { - if (equal_rows_per_shard_) { - RETURN_IF_NOT_OK(CalculateNumRowsPerShard()); - num_rows = num_rows_per_shard_; - } else { - std::vector shard_file_list; - RETURN_IF_NOT_OK(GetShardFileList(&shard_file_list)); - RETURN_IF_NOT_OK(CountTotalRows(&num_rows, shard_file_list)); - } - } - sample_size = total_rows_ != 0 ? total_rows_ : data_schema_->num_rows(); - *dataset_size = sample_size > 0 ? 
std::min(num_rows, sample_size) : num_rows; - dataset_size_ = *dataset_size; - return Status::OK(); -} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.h index f2f79916d1b..08cdfd88f33 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.h @@ -257,11 +257,6 @@ class TFReaderOp : public ParallelOp { // before providing their own implementations. Status PrepareNodePostAction() override; - /// \brief Base-class override for GetDatasetSize - /// \param[out] dataset_size the size of the dataset - /// \return Status of the function - Status GetDatasetSize(int64_t *dataset_size) override; - static bool ValidateFirstRowCrc(const std::string &filename); private: @@ -400,11 +395,6 @@ class TFReaderOp : public ParallelOp { // @return - Status Status ComputeColMap() override; - // Private function for computing the file list of the specific shard ID. This is because in distributed scenario, - // data will be divided into shards by row when equal_rows_per_shard is true, but by file in the opposite case. - // @return - Status - the status code returned. 
- Status GetShardFileList(std::vector *shard_filenames); - int32_t device_id_; int32_t num_devices_; int64_t rows_per_buffer_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc index 0f10b716440..b5db193f1d0 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc @@ -447,16 +447,9 @@ Status VOCOp::ReadAnnotationToTensor(const std::string &path, TensorRow *row) { return Status::OK(); } -#ifdef ENABLE_PYTHON Status VOCOp::CountTotalRows(const std::string &dir, const std::string &task_type, const std::string &task_mode, - const py::dict &dict, int64_t *count) { + const std::map &input_class_indexing, int64_t *count) { if (task_type == "Detection") { - std::map input_class_indexing; - for (auto p : dict) { - (void)input_class_indexing.insert(std::pair(py::reinterpret_borrow(p.first), - py::reinterpret_borrow(p.second))); - } - std::shared_ptr op; RETURN_IF_NOT_OK( Builder().SetDir(dir).SetTask(task_type).SetUsage(task_mode).SetClassIndex(input_class_indexing).Build(&op)); @@ -473,6 +466,7 @@ Status VOCOp::CountTotalRows(const std::string &dir, const std::string &task_typ return Status::OK(); } +#ifdef ENABLE_PYTHON Status VOCOp::GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode, const py::dict &dict, std::map *output_class_indexing) { std::map input_class_indexing; @@ -516,36 +510,6 @@ Status VOCOp::ComputeColMap() { return Status::OK(); } -// Get Dataset size -Status VOCOp::GetDatasetSize(int64_t *dataset_size) { - if (dataset_size_ > 0) { - *dataset_size = dataset_size_; - return Status::OK(); - } - int64_t num_rows = 0, sample_size; - if (image_ids_.size() == 0) { - if (task_type_ == TaskType::Detection) { - std::shared_ptr op; - RETURN_IF_NOT_OK( - 
Builder().SetDir(folder_path_).SetTask("Detection").SetUsage(usage_).SetClassIndex(class_index_).Build(&op)); - RETURN_IF_NOT_OK(op->ParseImageIds()); - RETURN_IF_NOT_OK(op->ParseAnnotationIds()); - num_rows = static_cast(op->image_ids_.size()); - } else if (task_type_ == TaskType::Segmentation) { - std::shared_ptr op; - RETURN_IF_NOT_OK(Builder().SetDir(folder_path_).SetTask("Segmentation").SetUsage(usage_).Build(&op)); - RETURN_IF_NOT_OK(op->ParseImageIds()); - num_rows = static_cast(op->image_ids_.size()); - } - } else { - num_rows = image_ids_.size(); - } - sample_size = sampler_->CalculateNumSamples(num_rows); - *dataset_size = sample_size; - dataset_size_ = *dataset_size; - return Status::OK(); -} - Status VOCOp::GetClassIndexing(std::vector>> *output_class_indexing) { if ((*output_class_indexing).empty()) { if (task_type_ != TaskType::Detection) { diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.h index 03256e2c67c..9d3d75410d9 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.h @@ -187,15 +187,15 @@ class VOCOp : public ParallelOp, public RandomAccessOp { // @param show_all void Print(std::ostream &out, bool show_all) const override; -#ifdef ENABLE_PYTHON // @param const std::string &dir - VOC dir path // @param const std::string &task_type - task type of reading voc job // @param const std::string &task_mode - task mode of reading voc job - // @param const py::dict &dict - input dict of class index + // @param const std::map input_class_indexing - input map of class index // @param int64_t *count - output rows number of VOCDataset static Status CountTotalRows(const std::string &dir, const std::string &task_type, const std::string &task_mode, - const py::dict &dict, int64_t *count); + const std::map &input_class_indexing, int64_t *count); +#ifdef ENABLE_PYTHON 
// @param const std::string &dir - VOC dir path // @param const std::string &task_type - task type of reading voc job // @param const std::string &task_mode - task mode of reading voc job @@ -216,11 +216,6 @@ class VOCOp : public ParallelOp, public RandomAccessOp { // @return Name of the current Op std::string Name() const override { return "VOCOp"; } - /// \brief Base-class override for GetDatasetSize - /// \param[out] dataset_size the size of the dataset - /// \return Status of the function - Status GetDatasetSize(int64_t *dataset_size) override; - // /// \brief Gets the class indexing // /// \return Status - The status code return Status GetClassIndexing(std::vector>> *output_class_indexing) override; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.cc index e5169ac891a..b64f3b3357c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.cc @@ -139,17 +139,5 @@ Status TakeOp::PreAccept(NodePass *p, bool *modified) { return p->PreRunOnNode(shared_from_base(), modified); } -// Get Dataset size -Status TakeOp::GetDatasetSize(int64_t *dataset_size) { - if (dataset_size_ > 0) { - *dataset_size = dataset_size_; - return Status::OK(); - } - int64_t num_rows; - RETURN_IF_NOT_OK(child_[0]->GetDatasetSize(&num_rows)); - *dataset_size = std::min(static_cast(max_takes_), num_rows); - dataset_size_ = *dataset_size; - return Status::OK(); -} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.h index 85a9213b526..08d5494d5f3 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.h @@ -94,11 +94,6 @@ class TakeOp : public PipelineOp { // @return Name of the current Op std::string Name() const 
override { return kTakeOp; } - /// \brief Base-class override for GetDatasetSize - /// \param[out] dataset_size the size of the dataset - /// \return Status of the function - Status GetDatasetSize(int64_t *dataset_size) override; - private: int32_t max_takes_; // The number of takes that the user requested int32_t take_count_; // A counter for the current number of executed takes diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.cc index 8367b5b37af..d8822cef5ed 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.cc @@ -248,24 +248,6 @@ Status ZipOp::Accept(NodePass *p, bool *modified) { return p->RunOnNode(shared_from_base(), modified); } -// Get Dataset size -Status ZipOp::GetDatasetSize(int64_t *dataset_size) { - if (dataset_size_ > 0) { - *dataset_size = dataset_size_; - return Status::OK(); - } - std::vector dataset_sizes; - int64_t child_dataset_size; - for (auto child : child_) { - RETURN_IF_NOT_OK(child->GetDatasetSize(&child_dataset_size)); - dataset_sizes.push_back(child_dataset_size); - } - - *dataset_size = *std::min_element(dataset_sizes.begin(), dataset_sizes.end()); - dataset_size_ = *dataset_size; - return Status::OK(); -} - Status ZipOp::ComputeColMap() { if (column_name_id_map_.empty()) { column_name_id_map_ = {}; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.h index 2aace463ac9..f2cc2823997 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.h @@ -120,11 +120,6 @@ class ZipOp : public PipelineOp { // @return Name of the current Op std::string Name() const override { return kZipOp; } - /// \brief Base-class override for GetDatasetSize - /// \param[out] dataset_size the size of the dataset - /// \return Status of 
the function - Status GetDatasetSize(int64_t *dataset_size) override; - private: // Handles preprocessing of the main loop, used when starting new epoch Status prepare(TensorQTable *const table); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.cc index ee67b21ca75..3c0a4e6407e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.cc @@ -114,5 +114,33 @@ std::vector> BatchNode::Build() { return node_ops; } +// Get Dataset size +Status BatchNode::GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } +#ifdef ENABLE_PYTHON + if (batch_size_func_) { + RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), dataset_size)); + dataset_size_ = *dataset_size; + return Status::OK(); + } +#endif + int64_t num_rows; + RETURN_IF_NOT_OK(children_[0]->GetDatasetSize(size_getter, estimate, &num_rows)); + if (num_rows > 0 && batch_size_ > 0) { + if (drop_remainder_) { + num_rows = static_cast(floor(num_rows / (1.0 * batch_size_))); + } else { + num_rows = static_cast(ceil(num_rows / (1.0 * batch_size_))); + } + } + *dataset_size = num_rows; + dataset_size_ = num_rows; + return Status::OK(); +} + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.h index 9bb1802a4e7..e369a32121e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.h @@ -64,6 +64,15 @@ class BatchNode : public DatasetNode { /// \return Status Status::OK() if all the parameters are valid Status ValidateParams() override; + /// \brief Base-class override for 
GetDatasetSize + /// \param[in] size_getter Shared pointer to DatasetSizeGetter + /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting + /// dataset size at the expense of accuracy. + /// \param[out] dataset_size the size of the dataset + /// \return Status of the function + Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) override; + private: int32_t batch_size_; bool drop_remainder_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.cc index 1cdb6cbd262..8c52bd2a7b5 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.cc @@ -127,6 +127,5 @@ Status BucketBatchByLengthNode::ValidateParams() { return Status::OK(); } - } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h index 1cdf46cd6ea..2805c54fdd5 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h @@ -60,6 +60,8 @@ class BucketBatchByLengthNode : public DatasetNode { /// \return Status Status::OK() if all the parameters are valid Status ValidateParams() override; + bool IsSizeDefined() override { return false; }; + private: std::vector column_names_; std::vector bucket_boundaries_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.h index c542e46e2b5..85246e75000 100644 --- 
a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.h @@ -58,6 +58,8 @@ class ConcatNode : public DatasetNode { /// \return Status Status::OK() if all the parameters are valid Status ValidateParams() override; + bool IsSizeDefined() override { return false; } + private: std::shared_ptr sampler_; std::vector> children_flag_and_nums_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc index e92b3b3642c..993c987660e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc @@ -342,6 +342,31 @@ Status DatasetNode::GetShardId(int32_t *shard_id) { RETURN_STATUS_SYNTAX_ERROR("Get Shard Id failed at source node: " + Name() + "\n"); } } + +// Gets the dataset size +Status DatasetNode::GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } + if (!IsSizeDefined()) { + RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), dataset_size)); + dataset_size_ = *dataset_size; + return Status::OK(); + } + if (children_.size() == 1) { + return children_[0]->GetDatasetSize(size_getter, estimate, dataset_size); + } else if (children_.size() > 1) { + // It is okay for dataset to have more than 1 child, GetDatasetSize shouldn't fail in this case. + // This is done mostly for cache, which injects cache lookup/merge operators. Cache path will + // always be in front of the child_ structure, so we get the dataset size from the last child. 
+ return children_[children_.size() - 1]->GetDatasetSize(size_getter, estimate, dataset_size); + } else { + RETURN_STATUS_UNEXPECTED("Trying to get dataset size from leaf node, missing override"); + } +} + // Visitor accepting method for NodePass Status SourceNode::Accept(NodePass *p, bool *modified) { // Downcast shared pointer then call visitor diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h index 57ed1de1717..11c432228c0 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h @@ -25,6 +25,7 @@ #include #include "minddata/dataset/include/datasets.h" +#include "minddata/dataset/engine/consumers/tree_consumer.h" namespace mindspore { namespace dataset { @@ -32,6 +33,7 @@ namespace dataset { class Dataset; class SamplerObj; class NodePass; +class DatasetSizeGetter; #define RETURN_EMPTY_IF_ERROR(_s) \ do { \ @@ -169,6 +171,14 @@ class DatasetNode : public std::enable_shared_from_this { /// \return Status Status::OK() if get shard id successfully virtual Status GetShardId(int32_t *shard_id); + /// \brief Gets the dataset size + /// \param[in] size_getter Shared pointer to DatasetSizeGetter + /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting + /// dataset size at the expense of accuracy. 
+ /// \return Status - The status code return + virtual Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size); + /// \brief Getter function for child nodes /// \return Child nodes const std::vector> Children() const { return children_; } @@ -219,10 +229,13 @@ class DatasetNode : public std::enable_shared_from_this { /// \notes Remove me after changing return val of Build() Status BuildStatus() { return build_status; } + virtual bool IsSizeDefined() { return true; } + protected: std::vector> children_; DatasetNode *parent_; std::shared_ptr cache_; + int64_t dataset_size_ = -1; int32_t num_workers_; int32_t rows_per_buffer_; int32_t connector_que_size_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/filter_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/filter_node.h index 7e66168c61c..30a98588d95 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/filter_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/filter_node.h @@ -55,6 +55,8 @@ class FilterNode : public DatasetNode { /// \return Status Status::OK() if all the parameters are valid Status ValidateParams() override; + bool IsSizeDefined() override { return false; }; + /// \brief Base-class override for accepting NodePass visitor /// \param[in] p The node to visit /// \param[out] modified Indicator if the node was modified diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.cc index 65ef8bacea8..cf21c36a335 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.cc @@ -56,6 +56,23 @@ Status RepeatNode::ValidateParams() { return Status::OK(); } +// Get Dataset size +Status RepeatNode::GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) { + if (dataset_size_ > 0) { + 
*dataset_size = dataset_size_; + return Status::OK(); + } + int64_t num_rows; + RETURN_IF_NOT_OK(children_[0]->GetDatasetSize(size_getter, estimate, &num_rows)); + if (num_rows > 0 && repeat_count_ > 0) { + num_rows = num_rows * repeat_count_; + } + *dataset_size = num_rows; + dataset_size_ = num_rows; + return Status::OK(); +} + // Visitor accepting method for NodePass Status RepeatNode::Accept(NodePass *p, bool *modified) { // Downcast shared pointer then call visitor diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.h index 318a7dda3e3..200ad3d17f6 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.h @@ -56,6 +56,15 @@ class RepeatNode : public DatasetNode { /// \return Status Status::OK() if all the parameters are valid Status ValidateParams() override; + /// \brief Base-class override for GetDatasetSize + /// \param[in] size_getter Shared pointer to DatasetSizeGetter + /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting + /// dataset size at the expense of accuracy. 
+ /// \param[out] dataset_size the size of the dataset + /// \return Status of the function + Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) override; + /// \brief Base-class override for accepting NodePass visitor /// \param[in] p The node to visit /// \param[out] modified Indicator if the node was modified diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.cc index b2e4eae2523..3ee509ec979 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.cc @@ -56,5 +56,22 @@ Status SkipNode::ValidateParams() { return Status::OK(); } +// Get Dataset size +Status SkipNode::GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } + int64_t num_rows; + RETURN_IF_NOT_OK(children_[0]->GetDatasetSize(size_getter, estimate, &num_rows)); + *dataset_size = 0; + if (skip_count_ >= 0 && skip_count_ < num_rows) { + *dataset_size = num_rows - skip_count_; + } + dataset_size_ = *dataset_size; + return Status::OK(); +} + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.h index 376ea968694..f1d154b6b3f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.h @@ -54,6 +54,15 @@ class SkipNode : public DatasetNode { /// \return Status Status::OK() if all the parameters are valid Status ValidateParams() override; + /// \brief Base-class override for GetDatasetSize + /// \param[in] size_getter Shared pointer to DatasetSizeGetter + /// \param[in] estimate This is only supported by some of the ops and 
it's used to speed up the process of getting + /// dataset size at the expense of accuracy. + /// \param[out] dataset_size the size of the dataset + /// \return Status of the function + Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) override; + private: int32_t skip_count_; }; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc index 9b888eb544b..502497e14bb 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc @@ -21,6 +21,7 @@ #include #include #include +#include #include "minddata/dataset/engine/datasetops/source/celeba_op.h" #include "minddata/dataset/util/status.h" @@ -87,5 +88,66 @@ Status CelebANode::GetShardId(int32_t *shard_id) { return Status::OK(); } +// Get Dataset size +Status CelebANode::GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) { + int64_t num_rows, sample_size; + std::ifstream partition_file; + std::string line; + Path folder_path(dataset_dir_); + std::ifstream attr_file((folder_path / "list_attr_celeba.txt").toString()); + if (!attr_file.is_open()) { + std::string attr_file_name = (folder_path / "list_attr_celeba.txt").toString(); + RETURN_STATUS_UNEXPECTED("Invalid file, failed to open Celeba attr file: " + attr_file_name); + } + + std::string rows_num; + (void)getline(attr_file, rows_num); + try { + num_rows = static_cast(std::stoul(rows_num)); // First line is rows number in attr file + } catch (std::invalid_argument &e) { + RETURN_STATUS_UNEXPECTED( + "Invalid data, failed to convert rows_num from attr_file to unsigned long, invalid argument: " + rows_num); + } catch (std::out_of_range &e) { + RETURN_STATUS_UNEXPECTED( + "Invalid data, failed to convert rows_num from attr_file to unsigned long, out of range: " + 
rows_num); + } + if (usage_ != "all") { + int64_t partition_num = 0; + char usage_type; + if (usage_ == "train") { + usage_type = '0'; + } else { + if (usage_ == "valid") { + usage_type = '1'; + } else { + if (usage_ == "test") + usage_type = '2'; + else + RETURN_STATUS_UNEXPECTED("Invalid usage."); + } + } + if (!partition_file.is_open()) { + partition_file.open((folder_path / "list_eval_partition.txt").toString()); + } + if (partition_file.is_open()) { + while (getline(partition_file, line)) { + int start = line.find(' '); + if (line.at(start + 1) == usage_type) { + partition_num++; + } + } + } else { + std::string partition_file_name = "list_eval_partition.txt"; + RETURN_STATUS_UNEXPECTED("Invalid file, failed to open CelebA partition file: " + partition_file_name); + } + num_rows = std::min(num_rows, partition_num); + } + + sample_size = sampler_->Build()->CalculateNumSamples(num_rows); + *dataset_size = sample_size; + return Status::OK(); +} + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.h index a2518d52205..fbe2060ad1f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.h @@ -61,6 +61,15 @@ class CelebANode : public MappableSourceNode { /// \return Status Status::OK() if get shard id successfully Status GetShardId(int32_t *shard_id) override; + /// \brief Base-class override for GetDatasetSize + /// \param[in] size_getter Shared pointer to DatasetSizeGetter + /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting + /// dataset size at the expense of accuracy. 
+ /// \param[out] dataset_size the size of the dataset + /// \return Status of the function + Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) override; + private: std::string dataset_dir_; std::string usage_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc index f24cc21bfab..c48ebf4f760 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc @@ -83,5 +83,20 @@ Status Cifar100Node::GetShardId(int32_t *shard_id) { return Status::OK(); } +// Get Dataset size +Status Cifar100Node::GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } + int64_t num_rows, sample_size; + RETURN_IF_NOT_OK(CifarOp::CountTotalRows(dataset_dir_, usage_, false, &num_rows)); + sample_size = sampler_->Build()->CalculateNumSamples(num_rows); + *dataset_size = sample_size; + dataset_size_ = *dataset_size; + return Status::OK(); +} + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.h index 25e63e2b916..af9b3ce8394 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.h @@ -59,6 +59,15 @@ class Cifar100Node : public MappableSourceNode { /// \return Status Status::OK() if get shard id successfully Status GetShardId(int32_t *shard_id) override; + /// \brief Base-class override for GetDatasetSize + /// \param[in] size_getter Shared pointer to DatasetSizeGetter + /// \param[in] estimate This is only supported by 
some of the ops and it's used to speed up the process of getting + /// dataset size at the expense of accuracy. + /// \param[out] dataset_size the size of the dataset + /// \return Status of the function + Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) override; + private: std::string dataset_dir_; std::string usage_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc index 0b3eec3cdfd..2127a7ab417 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc @@ -81,5 +81,20 @@ Status Cifar10Node::GetShardId(int32_t *shard_id) { return Status::OK(); } +// Get Dataset size +Status Cifar10Node::GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } + int64_t num_rows, sample_size; + RETURN_IF_NOT_OK(CifarOp::CountTotalRows(dataset_dir_, usage_, true, &num_rows)); + sample_size = sampler_->Build()->CalculateNumSamples(num_rows); + *dataset_size = sample_size; + dataset_size_ = *dataset_size; + return Status::OK(); +} + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.h index 63140f4b673..102bc17c364 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.h @@ -59,6 +59,15 @@ class Cifar10Node : public MappableSourceNode { /// \return Status Status::OK() if get shard id successfully Status GetShardId(int32_t *shard_id) override; + /// \brief Base-class override for GetDatasetSize + /// \param[in] 
size_getter Shared pointer to DatasetSizeGetter + /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting + /// dataset size at the expense of accuracy. + /// \param[out] dataset_size the size of the dataset + /// \return Status of the function + Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) override; + private: std::string dataset_dir_; std::string usage_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc index f1e8ad82a92..32a38e8ddea 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc @@ -241,5 +241,21 @@ Status CLUENode::GetShardId(int32_t *shard_id) { return Status::OK(); } +// Get Dataset size +Status CLUENode::GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } + int64_t num_rows, sample_size; + RETURN_IF_NOT_OK(ClueOp::CountAllFileRows(dataset_files_, &num_rows)); + sample_size = num_samples_; + num_rows = static_cast(ceil(num_rows / (1.0 * num_shards_))); + *dataset_size = sample_size > 0 ? 
std::min(num_rows, sample_size) : num_rows; + dataset_size_ = *dataset_size; + return Status::OK(); +} + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.h index c315ab8386e..c149f86e93d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.h @@ -61,6 +61,15 @@ class CLUENode : public NonMappableSourceNode { /// \return Status Status::OK() if get shard id successfully Status GetShardId(int32_t *shard_id) override; + /// \brief Base-class override for GetDatasetSize + /// \param[in] size_getter Shared pointer to DatasetSizeGetter + /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting + /// dataset size at the expense of accuracy. + /// \param[out] dataset_size the size of the dataset + /// \return Status of the function + Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) override; + private: /// \brief Split string based on a character delimiter /// \return A string vector diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc index 40ce5e808dd..c893fe2a78f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc @@ -134,5 +134,20 @@ Status CocoNode::GetShardId(int32_t *shard_id) { return Status::OK(); } +// Get Dataset size +Status CocoNode::GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } + int64_t num_rows = 0, sample_size; + 
RETURN_IF_NOT_OK(CocoOp::CountTotalRows(dataset_dir_, annotation_file_, task_, &num_rows)); + sample_size = sampler_->Build()->CalculateNumSamples(num_rows); + *dataset_size = sample_size; + dataset_size_ = *dataset_size; + return Status::OK(); +} + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.h index d4c3db57ab4..d0576701534 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.h @@ -59,6 +59,15 @@ class CocoNode : public MappableSourceNode { /// \return Status Status::OK() if get shard id successfully Status GetShardId(int32_t *shard_id) override; + /// \brief Base-class override for GetDatasetSize + /// \param[in] size_getter Shared pointer to DatasetSizeGetter + /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting + /// dataset size at the expense of accuracy. 
+ /// \param[out] dataset_size the size of the dataset + /// \return Status of the function + Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) override; + private: std::string dataset_dir_; std::string annotation_file_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc index 8aacf0de429..d198bb1e240 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc @@ -153,5 +153,21 @@ Status CSVNode::GetShardId(int32_t *shard_id) { return Status::OK(); } +// Get Dataset size +Status CSVNode::GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } + int64_t num_rows, sample_size; + RETURN_IF_NOT_OK(CsvOp::CountAllFileRows(dataset_files_, column_names_.empty(), &num_rows)); + sample_size = num_samples_; + num_rows = static_cast(ceil(num_rows / (1.0 * num_shards_))); + *dataset_size = sample_size > 0 ? 
std::min(num_rows, sample_size) : num_rows; + dataset_size_ = *dataset_size; + return Status::OK(); +} + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.h index 88cb488d089..a4619dbfa85 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.h @@ -82,6 +82,15 @@ class CSVNode : public NonMappableSourceNode { /// \return Status Status::OK() if get shard id successfully Status GetShardId(int32_t *shard_id) override; + /// \brief Base-class override for GetDatasetSize + /// \param[in] size_getter Shared pointer to DatasetSizeGetter + /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting + /// dataset size at the expense of accuracy. + /// \param[out] dataset_size the size of the dataset + /// \return Status of the function + Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) override; + private: std::vector dataset_files_; char field_delim_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc index b9918183333..1eb44a3dbd7 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc @@ -93,6 +93,5 @@ Status GeneratorNode::GetShardId(int32_t *shard_id) { *shard_id = 0; return Status::OK(); } - } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.h index d06dedc25c9..ec25383492d 100644 --- 
a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.h @@ -64,6 +64,13 @@ class GeneratorNode : public MappableSourceNode { /// \return Status Status::OK() if get shard id successfully Status GetShardId(int32_t *shard_id) override; + /// \brief Setter for DatasetSize in GeneratorNode + /// \param[in] sz dataset size to set + /// \return void + void SetGeneratorDatasetSize(int64_t sz) { dataset_size_ = sz; } + + bool IsSizeDefined() override { return false; } + private: py::function generator_function_; std::vector column_names_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc index 2dc8fa57b77..2c8c0309dcd 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc @@ -89,5 +89,20 @@ Status ImageFolderNode::GetShardId(int32_t *shard_id) { return Status::OK(); } +// Get Dataset size +Status ImageFolderNode::GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } + int64_t sample_size, num_rows; + RETURN_IF_NOT_OK(ImageFolderOp::CountRowsAndClasses(dataset_dir_, exts_, &num_rows, nullptr, {})); + sample_size = sampler_->Build()->CalculateNumSamples(num_rows); + *dataset_size = sample_size; + dataset_size_ = *dataset_size; + return Status::OK(); +} + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.h index a112407d3b4..cd17b4f3133 100644 --- 
a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.h @@ -65,6 +65,15 @@ class ImageFolderNode : public MappableSourceNode { /// \return Status Status::OK() if get shard id successfully Status GetShardId(int32_t *shard_id) override; + /// \brief Base-class override for GetDatasetSize + /// \param[in] size_getter Shared pointer to DatasetSizeGetter + /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting + /// dataset size at the expense of accuracy. + /// \param[out] dataset_size the size of the dataset + /// \return Status of the function + Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) override; + private: std::string dataset_dir_; bool decode_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc index b9fd2ff9f6d..044129bfdcc 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc @@ -111,5 +111,21 @@ Status ManifestNode::GetShardId(int32_t *shard_id) { return Status::OK(); } +// Get Dataset size +Status ManifestNode::GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } + int64_t num_rows, sample_size; + int64_t num_classes; // dummy variable + RETURN_IF_NOT_OK(ManifestOp::CountTotalRows(dataset_file_, class_index_, usage_, &num_rows, &num_classes)); + sample_size = sampler_->Build()->CalculateNumSamples(num_rows); + *dataset_size = sample_size; + dataset_size_ = *dataset_size; + return Status::OK(); +} + } // namespace dataset } // namespace mindspore diff --git 
a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.h index eb11bb1003c..361a4b699f0 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.h @@ -60,6 +60,15 @@ class ManifestNode : public MappableSourceNode { /// \return Status Status::OK() if get shard id successfully Status GetShardId(int32_t *shard_id) override; + /// \brief Base-class override for GetDatasetSize + /// \param[in] size_getter Shared pointer to DatasetSizeGetter + /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting + /// dataset size at the expense of accuracy. + /// \param[out] dataset_size the size of the dataset + /// \return Status of the function + Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) override; + private: std::string dataset_file_; std::string usage_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc index 5662b80d0cd..9ce036d4827 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc @@ -152,7 +152,6 @@ std::vector> MindDataNode::Build() { // A vector containing shared pointer to the Dataset Ops that this object will create std::vector> node_ops; - std::vector> operators_; build_status = BuildMindDatasetSamplerChain(sampler_, &operators_, num_padded_); RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build() @@ -184,5 +183,28 @@ Status MindDataNode::GetShardId(int32_t *shard_id) { return Status::OK(); } +// Get Dataset size +Status MindDataNode::GetDatasetSize(const std::shared_ptr 
&size_getter, bool estimate, + int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } + int64_t num_rows; + std::vector> operators; + RETURN_IF_NOT_OK(BuildMindDatasetSamplerChain(sampler_, &operators, num_padded_)); + + if (search_for_pattern_) { + dataset_files_ = {dataset_file_}; + } + + // The last operator is parent sampler + std::shared_ptr op = operators.back(); + RETURN_IF_NOT_OK(MindRecordOp::CountTotalRows(dataset_files_, search_for_pattern_, op, &num_rows, num_padded_)); + *dataset_size = num_rows; + dataset_size_ = *dataset_size; + return Status::OK(); +} + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.h index 4078125a071..421afce5810 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.h @@ -74,6 +74,15 @@ class MindDataNode : public MappableSourceNode { /// \note Pybind will use this function to set sample_bytes into MindDataNode void SetSampleBytes(std::map *sample_bytes); + /// \brief Base-class override for GetDatasetSize + /// \param[in] size_getter Shared pointer to DatasetSizeGetter + /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting + /// dataset size at the expense of accuracy. 
+ /// \param[out] dataset_size the size of the dataset + /// \return Status of the function + Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) override; + private: std::string dataset_file_; // search_for_pattern_ will be true in this mode std::vector dataset_files_; // search_for_pattern_ will be false in this mode @@ -83,6 +92,7 @@ class MindDataNode : public MappableSourceNode { nlohmann::json padded_sample_; std::map sample_bytes_; // enable in python int64_t num_padded_; + std::vector> operators_; }; } // namespace dataset diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc index 57371b17cae..40764e7c418 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc @@ -75,5 +75,20 @@ Status MnistNode::GetShardId(int32_t *shard_id) { return Status::OK(); } +// Get Dataset size +Status MnistNode::GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } + int64_t num_rows, sample_size; + RETURN_IF_NOT_OK(MnistOp::CountTotalRows(dataset_dir_, usage_, &num_rows)); + sample_size = sampler_->Build()->CalculateNumSamples(num_rows); + *dataset_size = sample_size; + dataset_size_ = *dataset_size; + return Status::OK(); +} + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.h index 93868496311..4e56c1010b0 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.h @@ -59,6 +59,15 @@ class MnistNode : public 
MappableSourceNode { /// \return Status Status::OK() if get shard id successfully Status GetShardId(int32_t *shard_id) override; + /// \brief Base-class override for GetDatasetSize + /// \param[in] size_getter Shared pointer to DatasetSizeGetter + /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting + /// dataset size at the expense of accuracy. + /// \param[out] dataset_size the size of the dataset + /// \return Status of the function + Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) override; + private: std::string dataset_dir_; std::string usage_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.cc index 5e1f75ee7ca..7c6abb3c28f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.cc @@ -86,17 +86,16 @@ std::vector> RandomNode::Build() { schema_file_path = schema_path_; } - std::unique_ptr data_schema; std::vector columns_to_load; if (columns_list_.size() > 0) { columns_to_load = columns_list_; } if (!schema_file_path.empty() || !schema_json_string.empty()) { - data_schema = std::make_unique(); + data_schema_ = std::make_unique(); if (!schema_file_path.empty()) { - data_schema->LoadSchemaFile(schema_file_path, columns_to_load); + data_schema_->LoadSchemaFile(schema_file_path, columns_to_load); } else if (!schema_json_string.empty()) { - data_schema->LoadSchemaString(schema_json_string, columns_to_load); + data_schema_->LoadSchemaString(schema_json_string, columns_to_load); } } @@ -109,7 +108,7 @@ std::vector> RandomNode::Build() { std::shared_ptr op; op = std::make_shared(num_workers_, connector_que_size_, rows_per_buffer_, total_rows_, - std::move(data_schema), std::move(sampler_->Build())); + std::move(data_schema_), 
std::move(sampler_->Build())); build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build() RETURN_EMPTY_IF_ERROR(build_status); @@ -125,5 +124,24 @@ Status RandomNode::GetShardId(int32_t *shard_id) { return Status::OK(); } +// Get Dataset size +Status RandomNode::GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } + int64_t num_rows; + num_rows = total_rows_ != 0 ? total_rows_ : data_schema_->num_rows(); + if (sampler_ != nullptr) { + int64_t sample_size; + sample_size = sampler_->Build()->CalculateNumSamples(num_rows); + *dataset_size = sample_size; + } else { + *dataset_size = num_rows; + } + dataset_size_ = *dataset_size; + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.h index ea1e1f6346f..33ef9c0e9bf 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.h @@ -79,6 +79,15 @@ class RandomNode : public NonMappableSourceNode { /// \return Status Status::OK() if get shard id successfully Status GetShardId(int32_t *shard_id) override; + /// \brief Base-class override for GetDatasetSize + /// \param[in] size_getter Shared pointer to DatasetSizeGetter + /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting + /// dataset size at the expense of accuracy. 
+ /// \param[out] dataset_size the size of the dataset + /// \return Status of the function + Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) override; + private: /// \brief A quick inline for producing a random number between (and including) min/max /// \param[in] min minimum number that can be generated. @@ -92,6 +101,7 @@ class RandomNode : public NonMappableSourceNode { std::vector columns_list_; std::shared_ptr sampler_; std::mt19937 rand_gen_; + std::unique_ptr data_schema_; }; } // namespace dataset diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc index f5745fc29cb..69145ff4e87 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc @@ -122,5 +122,20 @@ Status TextFileNode::GetShardId(int32_t *shard_id) { return Status::OK(); } +// Get Dataset size +Status TextFileNode::GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } + int64_t num_rows, sample_size = num_samples_; + RETURN_IF_NOT_OK(TextFileOp::CountAllFileRows(dataset_files_, &num_rows)); + num_rows = static_cast(ceil(num_rows / (1.0 * num_shards_))); + *dataset_size = sample_size > 0 ? 
std::min(num_rows, sample_size) : num_rows; + dataset_size_ = *dataset_size; + return Status::OK(); +} + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.h index cdad8eaadcc..c4872556e38 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.h @@ -61,6 +61,15 @@ class TextFileNode : public NonMappableSourceNode { /// \return Status Status::OK() if get shard id successfully Status GetShardId(int32_t *shard_id) override; + /// \brief Base-class override for GetDatasetSize + /// \param[in] size_getter Shared pointer to DatasetSizeGetter + /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting + /// dataset size at the expense of accuracy. + /// \param[out] dataset_size the size of the dataset + /// \return Status of the function + Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) override; + private: std::vector dataset_files_; int32_t num_samples_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc index a7a8a176581..b0aef35e1eb 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc @@ -169,5 +169,41 @@ Status TFRecordNode::GetShardId(int32_t *shard_id) { return Status::OK(); } +// Get Dataset size +Status TFRecordNode::GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } + int64_t num_rows; + if (!shard_equal_rows_) 
{ + // Data will be sharded by file + std::vector shard_file_list; + RETURN_IF_NOT_OK(GetShardFileList(&shard_file_list)); + RETURN_IF_NOT_OK(TFReaderOp::CountTotalRows(&num_rows, shard_file_list, 8, estimate)); + } else { + // Data will be sharded by row + RETURN_IF_NOT_OK(TFReaderOp::CountTotalRows(&num_rows, dataset_files_, 8, estimate)); + num_rows = static_cast(ceil(num_rows / (num_shards_ * 1.0))); + } + *dataset_size = num_samples_ > 0 ? std::min(num_rows, num_samples_) : num_rows; + dataset_size_ = *dataset_size; + return Status::OK(); +} + +// Get the file list of the specific shard ID +Status TFRecordNode::GetShardFileList(std::vector *shard_filenames) { + if (!shard_filenames->empty()) { + RETURN_STATUS_UNEXPECTED("The initial file list must be empty."); + } + for (int index = 0; index < dataset_files_.size(); index++) { + if (index % num_shards_ == shard_id_) { + shard_filenames->push_back(dataset_files_.at(index)); + } + } + return Status::OK(); +} + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.h index 2c2a14be702..6aa4d37cf5e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.h @@ -88,6 +88,20 @@ class TFRecordNode : public NonMappableSourceNode { /// \return Status Status::OK() if get shard id successfully Status GetShardId(int32_t *shard_id) override; + /// \brief Base-class override for GetDatasetSize + /// \param[in] size_getter Shared pointer to DatasetSizeGetter + /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting + /// dataset size at the expense of accuracy. 
+ /// \param[out] dataset_size the size of the dataset + /// \return Status of the function + Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) override; + + /// \brief Get the file list of the specific shard ID + /// \param[out] shard_filenames the list of filenames for that specific shard ID + /// \return Status of the function + Status GetShardFileList(std::vector *shard_filenames); + private: std::vector dataset_files_; std::string schema_path_; // schema_path_ path to schema file. It is set when type of schema parameter is string diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc index 9b36424d7c7..8cf9ece43ea 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc @@ -128,5 +128,20 @@ Status VOCNode::GetShardId(int32_t *shard_id) { return Status::OK(); } +// Get Dataset size +Status VOCNode::GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } + int64_t num_rows = 0, sample_size; + RETURN_IF_NOT_OK(VOCOp::CountTotalRows(dataset_dir_, task_, usage_, class_index_, &num_rows)); + sample_size = sampler_->Build()->CalculateNumSamples(num_rows); + *dataset_size = sample_size; + dataset_size_ = *dataset_size; + return Status::OK(); +} + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.h index 55750fa40f5..777ac2e3766 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.h @@ -61,6 +61,15 @@ class VOCNode : public 
MappableSourceNode { /// \return Status Status::OK() if get shard id successfully Status GetShardId(int32_t *shard_id) override; + /// \brief Base-class override for GetDatasetSize + /// \param[in] size_getter Shared pointer to DatasetSizeGetter + /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting + /// dataset size at the expense of accuracy. + /// \param[out] dataset_size the size of the dataset + /// \return Status of the function + Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) override; + private: const std::string kColumnImage = "image"; const std::string kColumnTarget = "target"; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.cc index 509b4a7bfdd..4a11b448c4e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.cc @@ -19,6 +19,7 @@ #include #include #include +#include #include "minddata/dataset/engine/datasetops/take_op.h" #include "minddata/dataset/util/status.h" @@ -56,5 +57,19 @@ Status TakeNode::ValidateParams() { return Status::OK(); } +// Get Dataset size +Status TakeNode::GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } + int64_t num_rows; + RETURN_IF_NOT_OK(children_[0]->GetDatasetSize(size_getter, estimate, &num_rows)); + *dataset_size = std::min(static_cast(take_count_), num_rows); + dataset_size_ = *dataset_size; + return Status::OK(); +} + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.h index 6d3b1e3d8a8..012a647dcf9 100644 --- 
a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.h @@ -54,6 +54,15 @@ class TakeNode : public DatasetNode { /// \return Status Status::OK() if all the parameters are valid Status ValidateParams() override; + /// \brief Base-class override for GetDatasetSize + /// \param[in] size_getter Shared pointer to DatasetSizeGetter + /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting + /// dataset size at the expense of accuracy. + /// \param[out] dataset_size the size of the dataset + /// \return Status of the function + Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) override; + private: int32_t take_count_; }; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.cc index 9cf11421e1c..c66a973eafb 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.cc @@ -19,6 +19,7 @@ #include #include #include +#include #include "minddata/dataset/engine/datasetops/zip_op.h" #include "minddata/dataset/engine/opt/pass.h" @@ -62,6 +63,25 @@ std::vector> ZipNode::Build() { return node_ops; } +// Get Dataset size +Status ZipNode::GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } + std::vector dataset_sizes; + int64_t child_dataset_size; + for (auto child : children_) { + RETURN_IF_NOT_OK(child->GetDatasetSize(size_getter, estimate, &child_dataset_size)); + dataset_sizes.push_back(child_dataset_size); + } + + *dataset_size = *std::min_element(dataset_sizes.begin(), dataset_sizes.end()); + dataset_size_ = *dataset_size; + return Status::OK(); +} + // Visitor accepting method for 
NodePass Status ZipNode::Accept(NodePass *p, bool *modified) { // Downcast shared pointer then call visitor diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.h index 86ec73e65c8..855d30193b2 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.h @@ -54,6 +54,17 @@ class ZipNode : public DatasetNode { /// \return Status Status::OK() if all the parameters are valid Status ValidateParams() override; + /// \brief Base-class override for GetDatasetSize + /// \param[in] size_getter Shared pointer to DatasetSizeGetter + /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting + /// dataset size at the expense of accuracy. + /// \param[out] dataset_size the size of the dataset + /// \return Status of the function + Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) override; + + private: + std::vector> datasets_; /// \brief Base-class override for accepting NodePass visitor /// \param[in] p The node to visit /// \param[out] modified Indicator if the node was modified diff --git a/mindspore/ccsrc/minddata/dataset/include/datasets.h b/mindspore/ccsrc/minddata/dataset/include/datasets.h index d974671a7e4..47988fb80ae 100644 --- a/mindspore/ccsrc/minddata/dataset/include/datasets.h +++ b/mindspore/ccsrc/minddata/dataset/include/datasets.h @@ -100,8 +100,10 @@ class Dataset : public std::enable_shared_from_this { ~Dataset() = default; /// \brief Gets the dataset size + /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting + /// dataset size at the expense of accuracy. /// \return dataset size. 
If failed, return -1 - int64_t GetDatasetSize(); + int64_t GetDatasetSize(bool estimate = false); /// \brief Gets the output type /// \return a vector of DataType. If failed, return an empty vector diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index a6267b41917..ab8ea415b8b 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -1370,6 +1370,19 @@ class Dataset: runtime_context.AssignConsumer(getter) return getter, runtime_context, api_tree + def _init_size_getter(self): + """ + Get pipeline information. + """ + ir_tree, api_tree = self.create_ir_tree() + + runtime_context = cde.PythonRuntimeContext() + runtime_context.Init() + getter = cde.DatasetSizeGetters() + getter.Init(ir_tree) + runtime_context.AssignConsumer(getter) + return getter, runtime_context, api_tree + def get_col_names(self): """ Get names of the columns in the dataset @@ -1413,8 +1426,8 @@ class Dataset: Number, number of batches. """ if self.dataset_size is None: - runtime_getter = self._init_tree_getters() - self.dataset_size = runtime_getter[0].GetDatasetSize() + runtime_getter = self._init_size_getter() + self.dataset_size = runtime_getter[0].GetDatasetSize(False) return self.dataset_size def num_classes(self): @@ -2783,8 +2796,8 @@ class TransferDataset(Dataset): new_op.num_parallel_workers = self.num_parallel_workers new_op.queue_name = self.queue_name new_op.device_type = self.device_type - new_op._send_epoch_end = self._send_epoch_end # pylint: disable=W0212 - new_op._create_data_info_queue = self._create_data_info_queue # pylint: disable=W0212 + new_op._send_epoch_end = self._send_epoch_end # pylint: disable=W0212 + new_op._create_data_info_queue = self._create_data_info_queue # pylint: disable=W0212 return new_op @@ -3737,13 +3750,25 @@ class GeneratorDataset(MappableDataset): return self.sampler.is_sharded() def parse(self, children=None): + dataset_size = -1 + if hasattr(self.source, "__len__"): + if 
not self.num_shards: + dataset_size = len(self.source) + else: + dataset_size = math.ceil(len(self.source) / self.num_shards) + + rows_from_sampler = self._get_sampler_dataset_size() + if rows_from_sampler is not None and rows_from_sampler < dataset_size: + dataset_size = rows_from_sampler if self.schema is None: - return cde.GeneratorNode(self.source, self.column_names, self.column_types) \ + return cde.GeneratorNode(self.source, self.column_names, self.column_types).SetGeneratorDatasetSize( + dataset_size) \ .SetNumWorkers(self.num_parallel_workers) schema = self.schema if isinstance(schema, Schema): schema = self.schema.cpp_schema - return cde.GeneratorNode(self.source, schema).SetNumWorkers(self.num_parallel_workers) + return cde.GeneratorNode(self.source, schema).SetGeneratorDatasetSize(dataset_size).SetNumWorkers( + self.num_parallel_workers) class TFRecordDataset(SourceDataset): diff --git a/tests/ut/cpp/dataset/c_api_dataset_minddata_test.cc b/tests/ut/cpp/dataset/c_api_dataset_minddata_test.cc index 79cec15936b..54c5097f642 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_minddata_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_minddata_test.cc @@ -341,7 +341,7 @@ TEST_F(MindDataTestPipeline, TestMindDataSuccess8) { EXPECT_EQ(shapes.size(), 2); EXPECT_EQ(shapes[0].ToString(), "<>"); EXPECT_EQ(shapes[1].ToString(), "<>"); - EXPECT_EQ(ds->GetDatasetSize(), 9); + EXPECT_EQ(ds->GetDatasetSize(), 5); EXPECT_EQ(ds->GetRepeatCount(), 1); EXPECT_EQ(ds->GetColumnNames(), column_names); diff --git a/tests/ut/python/dataset/test_datasets_get_dataset_size.py b/tests/ut/python/dataset/test_datasets_get_dataset_size.py index 26ee236ed9b..f0d3503cfb4 100644 --- a/tests/ut/python/dataset/test_datasets_get_dataset_size.py +++ b/tests/ut/python/dataset/test_datasets_get_dataset_size.py @@ -54,22 +54,26 @@ def test_imagenet_tf_file_dataset_size(): ds_total = ds.TFRecordDataset(IMAGENET_TFFILE_DIR) assert ds_total.get_dataset_size() == 12 - ds_shard_1_0 = 
ds.TFRecordDataset(IMAGENET_TFFILE_DIR, num_shards=1, shard_id=0) + ds_shard_1_0 = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, num_shards=1, shard_id=0, shard_equal_rows=True) assert ds_shard_1_0.get_dataset_size() == 12 - ds_shard_2_0 = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, num_shards=2, shard_id=0) + ds_shard_2_0 = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, num_shards=2, shard_id=0, shard_equal_rows=True) assert ds_shard_2_0.get_dataset_size() == 6 - # FIXME: dataset_size == 6 looks wrong but seem it aims to match the current code. - # Correct answer should be 12/3=4, the code issue should be addressed. - ds_shard_3_0 = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, num_shards=3, shard_id=0) - assert ds_shard_3_0.get_dataset_size() == 6 + ds_shard_3_0 = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, num_shards=3, shard_id=0, shard_equal_rows=True) + assert ds_shard_3_0.get_dataset_size() == 4 count = 0 for _ in ds_shard_3_0.create_dict_iterator(): count += 1 assert ds_shard_3_0.get_dataset_size() == count + # shard_equal_rows is set to False therefore, get_dataset_size must return count + ds_shard_4_0 = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, num_shards=4, shard_id=0) + count = 0 + for _ in ds_shard_4_0.create_dict_iterator(): + count += 1 + assert ds_shard_4_0.get_dataset_size() == count def test_mnist_dataset_size(): ds_total = ds.MnistDataset(MNIST_DATA_DIR)