!9249 Optimizing GetDatasetSize

From: @mahdirahmanihanzaki
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2020-12-02 08:26:34 +08:00 committed by Gitee
commit df44e1339e
106 changed files with 798 additions and 647 deletions

View File

@ -199,12 +199,13 @@ bool Dataset::Save(std::string dataset_path, int32_t num_files, std::string data
// Constructor
Dataset::Dataset() { tree_getters_ = std::make_shared<TreeGetters>(); }
int64_t Dataset::GetDatasetSize() {
int64_t Dataset::GetDatasetSize(bool estimate) {
int64_t dataset_size;
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), -1);
RETURN_SECOND_IF_ERROR(tree_getters_->GetDatasetSize(&dataset_size), -1);
std::shared_ptr<DatasetSizeGetter> size_getter = std::make_shared<DatasetSizeGetter>();
RETURN_SECOND_IF_ERROR(size_getter->Init(this->IRNode()), -1);
RETURN_SECOND_IF_ERROR(size_getter->GetDatasetSize(&dataset_size, estimate), -1);
return dataset_size;
}

View File

@ -106,19 +106,7 @@ PYBIND_REGISTER(ImageFolderOp, 1, ([](const py::module *m) {
}));
PYBIND_REGISTER(ManifestOp, 1, ([](const py::module *m) {
(void)py::class_<ManifestOp, DatasetOp, std::shared_ptr<ManifestOp>>(*m, "ManifestOp")
.def_static("get_num_rows_and_classes",
[](const std::string &file, const py::dict &dict, const std::string &usage) {
int64_t count = 0, num_classes = 0;
THROW_IF_ERROR(ManifestOp::CountTotalRows(file, dict, usage, &count, &num_classes));
return py::make_tuple(count, num_classes);
})
.def_static("get_class_indexing", [](const std::string &file, const py::dict &dict,
const std::string &usage) {
std::map<std::string, int32_t> output_class_indexing;
THROW_IF_ERROR(ManifestOp::GetClassIndexing(file, dict, usage, &output_class_indexing));
return output_class_indexing;
});
(void)py::class_<ManifestOp, DatasetOp, std::shared_ptr<ManifestOp>>(*m, "ManifestOp");
}));
PYBIND_REGISTER(MindRecordOp, 1, ([](const py::module *m) {
(void)py::class_<MindRecordOp, DatasetOp, std::shared_ptr<MindRecordOp>>(*m, "MindRecordOp")
@ -173,13 +161,6 @@ PYBIND_REGISTER(TFReaderOp, 1, ([](const py::module *m) {
PYBIND_REGISTER(VOCOp, 1, ([](const py::module *m) {
(void)py::class_<VOCOp, DatasetOp, std::shared_ptr<VOCOp>>(*m, "VOCOp")
.def_static("get_num_rows",
[](const std::string &dir, const std::string &task_type, const std::string &task_mode,
const py::dict &dict, int64_t numSamples) {
int64_t count = 0;
THROW_IF_ERROR(VOCOp::CountTotalRows(dir, task_type, task_mode, dict, &count));
return count;
})
.def_static("get_class_indexing", [](const std::string &dir, const std::string &task_type,
const std::string &task_mode, const py::dict &dict) {
std::map<std::string, int32_t> output_class_indexing;

View File

@ -184,7 +184,11 @@ PYBIND_REGISTER(GeneratorNode, 2, ([](const py::module *m) {
auto gen = std::make_shared<GeneratorNode>(generator_function, schema);
THROW_IF_ERROR(gen->ValidateParams());
return gen;
}));
}))
.def("SetGeneratorDatasetSize", [](std::shared_ptr<GeneratorNode> self, int64_t sz) {
self->SetGeneratorDatasetSize(sz);
return self;
});
}));
PYBIND_REGISTER(ImageFolderNode, 2, ([](const py::module *m) {

View File

@ -93,12 +93,6 @@ PYBIND_REGISTER(TreeGetters, 1, ([](const py::module *m) {
THROW_IF_ERROR(self.GetClassIndexing(&output_class_indexing));
return output_class_indexing;
})
.def("GetDatasetSize",
[](PythonTreeGetters &self) {
int64_t dataset_size;
THROW_IF_ERROR(self.GetDatasetSize(&dataset_size));
return dataset_size;
})
.def("__deepcopy__", [](py::object &tree_getter, py::dict memo) { return tree_getter; });
}));
@ -164,5 +158,18 @@ PYBIND_REGISTER(PythonSaveToDisk, 1, ([](const py::module *m) {
.def("Save", [](PythonSaveToDisk &self) { THROW_IF_ERROR(self.Save()); });
}));
PYBIND_REGISTER(PythonDatasetSizeGetter, 1, ([](const py::module *m) {
(void)py::class_<PythonDatasetSizeGetter, TreeConsumer, std::shared_ptr<PythonDatasetSizeGetter>>(
*m, "DatasetSizeGetters")
.def(py::init<>())
.def("Init", [](PythonDatasetSizeGetter &self,
std::shared_ptr<DatasetNode> d) { THROW_IF_ERROR(self.Init(d)); })
.def("GetDatasetSize", [](PythonDatasetSizeGetter &self, bool estimate) {
int64_t size;
THROW_IF_ERROR(self.GetDatasetSize(&size, estimate));
return size;
});
}));
} // namespace dataset
} // namespace mindspore

View File

@ -65,4 +65,8 @@ Status PythonTreeGetters::GetRow(TensorRow *r) {
py::gil_scoped_release gil_release;
return TreeGetters::GetRow(r);
}
Status PythonDatasetSizeGetter::GetRow(const std::shared_ptr<TreeAdapter> &tree_adapter, TensorRow *r) {
py::gil_scoped_release gil_release;
return DatasetSizeGetter::GetRow(tree_adapter, r);
}
} // namespace mindspore::dataset

View File

@ -60,5 +60,9 @@ class PythonTreeGetters : public TreeGetters {
public:
Status GetRow(TensorRow *r) override;
};
class PythonDatasetSizeGetter : public DatasetSizeGetter {
public:
Status GetRow(const std::shared_ptr<TreeAdapter> &tree_adapter, TensorRow *r) override;
};
} // namespace mindspore::dataset
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMERS_PYTHON_TREE_CONSUMER_H_

View File

@ -451,29 +451,6 @@ Status TreeGetters::Init(std::shared_ptr<DatasetNode> d) {
Status TreeGetters::GetRow(TensorRow *row) { return tree_adapter_->GetNext(row); }
Status TreeGetters::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ == -1) {
RETURN_IF_NOT_OK(InternalInit(static_cast<int8_t>(GetterPass::kDatasetSize)));
std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
RETURN_UNEXPECTED_IF_NULL(root);
RETURN_IF_NOT_OK(root->GetDatasetSize(dataset_size));
if (*dataset_size == -1) { // run through the tree and get everything
TensorRow row;
RETURN_IF_NOT_OK(GetRow(&row));
int64_t row_cnt = 0;
while (!row.empty()) {
++row_cnt;
RETURN_IF_NOT_OK(GetRow(&row));
}
*dataset_size = row_cnt;
}
dataset_size_ = *dataset_size; // save the previous result
}
*dataset_size = dataset_size_;
return Status::OK();
}
Status TreeGetters::GetOutputTypes(std::vector<DataType> *types) {
RETURN_IF_NOT_OK(GetFirstRowShapeAndType());
*types = first_row_type_;
@ -573,5 +550,46 @@ Status BuildVocabConsumer::Start() {
CHECK_FAIL_RETURN_UNEXPECTED(row.empty(), "The fetched row from BuildVocab should be an EOE.");
return Status::OK();
}
Status DatasetSizeGetter::GetDatasetSize(int64_t *size, bool estimate) {
if (dataset_size_ == -1) {
RETURN_IF_NOT_OK(root_->GetDatasetSize(shared_from_this(), estimate, size));
dataset_size_ = *size; // save the previous result
}
*size = dataset_size_;
return Status::OK();
}
Status DatasetSizeGetter::Init(std::shared_ptr<DatasetNode> d) {
root_ = std::move(d);
return Status::OK();
}
Status DatasetSizeGetter::DryRun(std::shared_ptr<DatasetNode> ir_node, int64_t *dataset_size) {
std::shared_ptr<TreeAdapter> tree_adapter = std::make_shared<TreeAdapter>();
tree_adapters_.push_back(tree_adapter);
tree_adapter->SetPrePassOverride([](OptPass pre) {
pre.push_back(
std::make_unique<GetterPass>(static_cast<GetterPass::GetterType>(GetterPass::GetterType::kDatasetSize)));
return pre;
});
RETURN_IF_NOT_OK(tree_adapter->Compile(std::move(ir_node), 1));
TensorRow row;
RETURN_IF_NOT_OK(GetRow(tree_adapter, &row));
int64_t row_cnt = 0;
while (!row.empty()) {
++row_cnt;
RETURN_IF_NOT_OK(GetRow(tree_adapter, &row));
}
*dataset_size = row_cnt;
return Status::OK();
}
Status DatasetSizeGetter::GetRow(const std::shared_ptr<TreeAdapter> &tree_adapter, TensorRow *row) {
return tree_adapter->GetNext(row);
}
Status DatasetSizeGetter::Terminate() {
for (const auto &tree : tree_adapters_) {
RETURN_IF_NOT_OK(tree->AllTasks()->ServiceStop());
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -177,7 +177,6 @@ class TreeGetters : public TreeConsumer {
~TreeGetters() = default;
Status Init(std::shared_ptr<DatasetNode> d) override;
Status GetDatasetSize(int64_t *size);
Status GetOutputTypes(std::vector<DataType> *types);
Status GetOutputShapes(std::vector<TensorShape> *shapes);
Status GetBatchSize(int64_t *batch_size);
@ -186,7 +185,7 @@ class TreeGetters : public TreeConsumer {
Status GetColumnNames(std::vector<std::string> *output);
Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing);
std::string Name() override { return "TreeGetters"; }
virtual Status GetRow(TensorRow *r);
virtual Status GetRow(TensorRow *row);
private:
Status GetFirstRowShapeAndType();
@ -202,6 +201,35 @@ class TreeGetters : public TreeConsumer {
Status InternalInit();
};
/// Consumer that is used to get some pipeline information
class DatasetSizeGetter : public TreeConsumer, public std::enable_shared_from_this<DatasetSizeGetter> {
public:
DatasetSizeGetter() : dataset_size_(-1) {}
~DatasetSizeGetter() = default;
Status Init(std::shared_ptr<DatasetNode> d) override;
Status Terminate() override;
/// \brief Function to get the dataset size
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *size, bool estimate = false);
virtual Status GetRow(const std::shared_ptr<TreeAdapter> &tree_adapter, TensorRow *row);
std::string Name() override { return "DatasetSizeGetter"; }
/// \brief Gets the dataset size by iterating over the entire dataset on a sub tree starting from ir_node
/// param[in] ir_node The node that marks the top most of the sub tree on which we want to iterate
/// \return Status - The status code return
Status DryRun(std::shared_ptr<DatasetNode> ir_node, int64_t *dataset_size);
private:
std::shared_ptr<DatasetNode> root_;
std::vector<std::shared_ptr<TreeAdapter>> tree_adapters_;
int64_t dataset_size_;
};
class BuildVocabConsumer : public TreeConsumer {
public:
/// BuildVocabConsumer Constructor which will call the base class default constructor.

View File

@ -537,30 +537,6 @@ Status BatchOp::ComputeColMap() {
return Status::OK();
}
Status BatchOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
#ifdef ENABLE_PYTHON
if (batch_size_func_) {
*dataset_size = -1;
return Status::OK();
}
#endif
int64_t num_rows;
RETURN_IF_NOT_OK(child_[0]->GetDatasetSize(&num_rows));
if (num_rows > 0 && start_batch_size_ > 0) {
if (drop_) {
num_rows = static_cast<int64_t>(floor(num_rows / (1.0 * start_batch_size_)));
} else {
num_rows = static_cast<int64_t>(ceil(num_rows / (1.0 * start_batch_size_)));
}
}
*dataset_size = num_rows;
dataset_size_ = num_rows;
return Status::OK();
}
int64_t BatchOp::GetTreeBatchSize() {
#ifdef ENABLE_PYTHON
if (batch_size_func_) {

View File

@ -225,11 +225,6 @@ class BatchOp : public ParallelOp {
static Status PadColumns(std::unique_ptr<TensorQTable> *table, const PadInfo &pad_info,
const std::unordered_map<std::string, int32_t> &column_name_id_map);
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
int64_t GetTreeBatchSize() override;
protected:

View File

@ -232,12 +232,5 @@ Status BucketBatchByLengthOp::ComputeColMap() {
return Status::OK();
}
// Get Dataset size
Status BucketBatchByLengthOp::GetDatasetSize(int64_t *dataset_size) {
// We are returning -1 because we can't easily calculate GetDatasetSize. Returning -1 will make TreeGetters to
// iterate over the dataset and count the size
*dataset_size = dataset_size_;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -112,11 +112,6 @@ class BucketBatchByLengthOp : public PipelineOp {
std::string Name() const override { return kBucketBatchByLengthOp; }
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
// << Stream output operator overload
// @notes This allows you to write the debug print info using stream operators
// @param out - reference to the output stream being overloaded

View File

@ -196,12 +196,5 @@ Status ConcatOp::PreAccept(NodePass *p, bool *modified) {
return p->PreRunOnNode(shared_from_base<ConcatOp>(), modified);
}
// Get Dataset size
Status ConcatOp::GetDatasetSize(int64_t *dataset_size) {
// We are returning -1 because we can't easily calculate GetDatasetSize. Returning -1 will make TreeGetters to
// iterate over the dataset and count the size
*dataset_size = dataset_size_;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -111,11 +111,6 @@ class ConcatOp : public PipelineOp {
/// \return Status of the node visit
Status PreAccept(NodePass *p, bool *modified) override;
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
private:
Status Verify(int32_t id, const std::unique_ptr<DataBuffer> &buf);

View File

@ -294,24 +294,6 @@ Status DatasetOp::GetNextInput(std::unique_ptr<DataBuffer> *p_buffer, int32_t wo
return Status::OK();
}
// Gets the dataset size
Status DatasetOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
if (child_.size() == 1) {
return child_[0]->GetDatasetSize(dataset_size);
} else if (child_.size() > 1) {
// It is okay for dataset to have more than 1 child, GetDatasetSize shouldn't fail in this case.
// This is done mostly for cache, which injects cache lookup/merge operators. Cache path will
// always be in front of the child_ structure, so we get the dataset size from the last child.
return child_[child_.size() - 1]->GetDatasetSize(dataset_size);
} else {
RETURN_STATUS_UNEXPECTED("Trying to get dataset size from leaf node, missing override");
}
}
// Gets the number of classes
Status DatasetOp::GetNumClasses(int64_t *num_classes) {
if (child_.size() == 1) {

View File

@ -180,10 +180,6 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// \return Status - The error code return
Status GetNextInput(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id = 0, int32_t child_index = 0);
/// \brief Gets the dataset size
/// \return Status - The status code return
virtual Status GetDatasetSize(int64_t *dataset_size);
/// \brief Gets the batch size
/// \return Status - The status code return
virtual int64_t GetTreeBatchSize();

View File

@ -258,13 +258,5 @@ Status FilterOp::PreAccept(NodePass *p, bool *modified) {
return p->PreRunOnNode(shared_from_base<FilterOp>(), modified);
}
// Get Dataset size
Status FilterOp::GetDatasetSize(int64_t *dataset_size) {
// We are returning -1 because we can't easily calculate GetDatasetSize. Returning -1 will make TreeGetters to
// iterate over the dataset and count the size
*dataset_size = dataset_size_;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -137,11 +137,6 @@ class FilterOp : public ParallelOp {
// @return Name of the current Op
std::string Name() const override { return kFilterOp; }
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
private:
// predicate_func python callable which returns a boolean value.
std::shared_ptr<TensorOp> predicate_func_;

View File

@ -187,21 +187,6 @@ Status RepeatOp::Accept(NodePass *p, bool *modified) {
return p->RunOnNode(shared_from_base<RepeatOp>(), modified);
}
// Get Dataset size
Status RepeatOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows;
RETURN_IF_NOT_OK(child_[0]->GetDatasetSize(&num_rows));
if (num_rows > 0 && num_repeats_ > 0) {
num_rows = num_rows * num_repeats_;
}
*dataset_size = num_rows;
dataset_size_ = num_rows;
return Status::OK();
}
int64_t RepeatOp::GetTreeRepeatCount() { return num_repeats_; }
} // namespace dataset
} // namespace mindspore

View File

@ -133,11 +133,6 @@ class RepeatOp : public PipelineOp {
/// \@return Status - The error code return
Status Reset() override;
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
int64_t GetTreeRepeatCount() override;
// \brief Adds an operator to the repeat ops list of tracked leaf/eoe nodes

View File

@ -136,20 +136,5 @@ Status SkipOp::PreAccept(NodePass *p, bool *modified) {
return p->PreRunOnNode(shared_from_base<SkipOp>(), modified);
}
// Get Dataset size
Status SkipOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows;
RETURN_IF_NOT_OK(child_[0]->GetDatasetSize(&num_rows));
*dataset_size = 0;
if (max_skips_ >= 0 && max_skips_ < num_rows) {
*dataset_size = num_rows - max_skips_;
}
dataset_size_ = *dataset_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -86,11 +86,6 @@ class SkipOp : public PipelineOp {
/// \return Status of the node visit
Status PreAccept(NodePass *p, bool *modified) override;
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
// Op name getter
// @return Name of the current Op
std::string Name() const override { return kSkipOp; }

View File

@ -452,63 +452,5 @@ Status CelebAOp::ComputeColMap() {
return Status::OK();
}
// Get Dataset size
Status CelebAOp::GetDatasetSize(int64_t *dataset_size) {
int64_t num_rows, sample_size;
std::string line;
Path folder_path(folder_path_);
std::ifstream attr_file((folder_path / "list_attr_celeba.txt").toString());
if (!attr_file.is_open()) {
std::string attr_file_name = (folder_path / "list_attr_celeba.txt").toString();
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open Celeba attr file: " + attr_file_name);
}
std::string rows_num;
(void)getline(attr_file, rows_num);
try {
num_rows = static_cast<int64_t>(std::stoul(rows_num)); // First line is rows number in attr file
} catch (std::invalid_argument &e) {
RETURN_STATUS_UNEXPECTED(
"Invalid data, failed to convert rows_num from attr_file to unsigned long, invalid argument: " + rows_num);
} catch (std::out_of_range &e) {
RETURN_STATUS_UNEXPECTED(
"Invalid data, failed to convert rows_num from attr_file to unsigned long, out of range: " + rows_num);
}
if (usage_ != "all") {
int64_t partition_num = 0;
char usage_type;
if (usage_ == "train") {
usage_type = '0';
} else {
if (usage_ == "valid") {
usage_type = '1';
} else {
if (usage_ == "test")
usage_type = '2';
else
RETURN_STATUS_UNEXPECTED("Invalid usage.");
}
}
if (!partition_file_.is_open()) {
partition_file_.open((folder_path / "list_eval_partition.txt").toString());
}
if (partition_file_.is_open()) {
while (getline(partition_file_, line)) {
int start = line.find(' ');
if (line.at(start + 1) == usage_type) {
partition_num++;
}
}
} else {
std::string partition_file_name = "list_eval_partition.txt";
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open Celeba partition file: " + partition_file_name);
}
num_rows = std::min(num_rows, partition_num);
}
sample_size = sampler_->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -179,11 +179,6 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
// @return Name of the current Op
std::string Name() const override { return "CelebAOp"; }
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
private:
// Called first when function is called
// @return

View File

@ -508,20 +508,5 @@ Status CifarOp::ComputeColMap() {
return Status::OK();
}
// Get Dataset size
Status CifarOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size;
num_rows = num_rows_;
if (num_rows_ <= 0)
RETURN_IF_NOT_OK(CountTotalRows(folder_path_, usage_, cifar_type_ == CifarType::kCifar10, &num_rows));
sample_size = sampler_->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -175,11 +175,6 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
// @return Name of the current Op
std::string Name() const override { return "CifarOp"; }
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
private:
// Initialize Sampler, calls sampler->Init() within
// @return Status - The error code return

View File

@ -565,19 +565,5 @@ Status ClueOp::Accept(NodePass *p, bool *modified) {
return p->RunOnNode(shared_from_base<ClueOp>(), modified);
}
// Get Dataset size
Status ClueOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size;
if (num_rows_per_shard_ <= 0) RETURN_IF_NOT_OK(CalculateNumRowsPerShard());
sample_size = num_samples_;
num_rows = num_rows_per_shard_;
*dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
dataset_size_ = *dataset_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -197,11 +197,6 @@ class ClueOp : public ParallelOp {
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
private:
// The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function.

View File

@ -681,39 +681,6 @@ Status CocoOp::ComputeColMap() {
return Status::OK();
}
// Get Dataset size
Status CocoOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows = 0, sample_size;
std::string task_type;
switch (task_type_) {
case TaskType::Detection:
task_type = "Detection";
break;
case TaskType::Keypoint:
task_type = "Keypoint";
break;
case TaskType::Panoptic:
task_type = "Panoptic";
break;
case TaskType::Stuff:
task_type = "Stuff";
break;
}
if (image_ids_.size() == 0) {
RETURN_IF_NOT_OK(CountTotalRows(image_folder_path_, annotation_path_, task_type, &num_rows));
} else {
num_rows = image_ids_.size();
}
sample_size = sampler_->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();
}
Status CocoOp::GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) {
if ((*output_class_indexing).empty()) {
if ((task_type_ != TaskType::Detection) && (task_type_ != TaskType::Panoptic)) {

View File

@ -213,11 +213,6 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
// @return Name of the current Op
std::string Name() const override { return "CocoOp"; }
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
/// \brief Gets the class indexing
/// \return Status - The status code return
Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) override;

View File

@ -916,19 +916,5 @@ Status CsvOp::Accept(NodePass *p, bool *modified) {
return p->RunOnNode(shared_from_base<CsvOp>(), modified);
}
// Get Dataset size
Status CsvOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size;
if (num_rows_per_shard_ <= 0) RETURN_IF_NOT_OK(CalculateNumRowsPerShard());
sample_size = num_samples_;
num_rows = num_rows_per_shard_;
*dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
dataset_size_ = *dataset_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -318,11 +318,6 @@ class CsvOp : public ParallelOp {
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
private:
// The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function.

View File

@ -274,11 +274,5 @@ Status GeneratorOp::ComputeColMap() {
}
return Status::OK();
}
Status GeneratorOp::GetDatasetSize(int64_t *dataset_size) { // Get Dataset size
// We are returning -1 because we can't easily calculate GetDatasetSize. Returning -1 will make TreeGetters to
// iterate over the dataset and count the size
*dataset_size = dataset_size_;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -136,8 +136,6 @@ class GeneratorOp : public PipelineOp {
Status Init();
Status GetDatasetSize(int64_t *dataset_size) override;
private:
py::function generator_function_;
std::vector<std::string> column_names_;

View File

@ -465,24 +465,6 @@ Status ImageFolderOp::ComputeColMap() {
return Status::OK();
}
// Get Dataset size
Status ImageFolderOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t sample_size, num_rows;
num_rows = num_rows_;
if (num_rows_ <= 0) {
// GetDatasetSize will not be impacted by class_index_
RETURN_IF_NOT_OK(CountRowsAndClasses(folder_path_, extensions_, &num_rows, nullptr, {}));
}
sample_size = sampler_->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();
}
// Get number of classes
Status ImageFolderOp::GetNumClasses(int64_t *num_classes) {
if (num_classes_ > 0) {

View File

@ -217,11 +217,6 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
// @return Name of the current Op
std::string Name() const override { return "ImageFolderOp"; }
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
/// \brief Base-class override for GetNumClasses
/// \param[out] num_classes the number of classes
/// \return Status of the function

View File

@ -396,16 +396,9 @@ Status ManifestOp::CountDatasetInfo() {
return Status::OK();
}
#ifdef ENABLE_PYTHON
Status ManifestOp::CountTotalRows(const std::string &file, const py::dict &dict, const std::string &usage,
int64_t *count, int64_t *numClasses) {
Status ManifestOp::CountTotalRows(const std::string &file, const std::map<std::string, int32_t> &map,
const std::string &usage, int64_t *count, int64_t *numClasses) {
// the logic of counting the number of samples is copied from ParseManifestFile()
std::map<std::string, int32_t> map;
for (auto p : dict) {
(void)map.insert(std::pair<std::string, int32_t>(py::reinterpret_borrow<py::str>(p.first),
py::reinterpret_borrow<py::int_>(p.second)));
}
std::shared_ptr<ManifestOp> op;
*count = 0;
RETURN_IF_NOT_OK(Builder().SetManifestFile(file).SetClassIndex(map).SetUsage(usage).Build(&op));
@ -415,6 +408,7 @@ Status ManifestOp::CountTotalRows(const std::string &file, const py::dict &dict,
return Status::OK();
}
#ifdef ENABLE_PYTHON
Status ManifestOp::GetClassIndexing(const std::string &file, const py::dict &dict, const std::string &usage,
std::map<std::string, int32_t> *output_class_indexing) {
std::map<std::string, int32_t> input_class_indexing;
@ -459,23 +453,6 @@ Status ManifestOp::ComputeColMap() {
return Status::OK();
}
// Get Dataset size
Status ManifestOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size;
std::shared_ptr<ManifestOp> op;
RETURN_IF_NOT_OK(Builder().SetManifestFile(file_).SetClassIndex(class_index_).SetUsage(usage_).Build(&op));
RETURN_IF_NOT_OK(op->ParseManifestFile());
num_rows = static_cast<int64_t>(op->image_labelname_.size());
sample_size = sampler_->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();
}
// Get number of classes
Status ManifestOp::GetNumClasses(int64_t *num_classes) {
if (num_classes_ > 0) {

View File

@ -164,10 +164,17 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
// @param show_all
void Print(std::ostream &out, bool show_all) const override;
#ifdef ENABLE_PYTHON
static Status CountTotalRows(const std::string &file, const py::dict &dict, const std::string &usage, int64_t *count,
int64_t *numClasses);
/// \brief Counts the total number of rows in Manifest
/// \param[in] file Dataset file path
/// \param[in] input_class_indexing Input map of class index
/// \param[in] usage Dataset usage
/// \param[out] count Number of rows counted
/// \param[out] numClasses Number of classes counted
/// \return Status of the function
static Status CountTotalRows(const std::string &file, const std::map<std::string, int32_t> &map,
const std::string &usage, int64_t *count, int64_t *numClasses);
#ifdef ENABLE_PYTHON
// Get str-to-int mapping from label name to index
static Status GetClassIndexing(const std::string &file, const py::dict &dict, const std::string &usage,
std::map<std::string, int32_t> *output_class_indexing);
@ -183,11 +190,6 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
// @return Name of the current Op
std::string Name() const override { return "ManifestOp"; }
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
/// \brief Base-class override for GetNumClasses
/// \param[out] num_classes the number of classes
/// \return Status of the function

View File

@ -474,22 +474,5 @@ Status MindRecordOp::ComputeColMap() {
return Status::OK();
}
// Get Dataset size
Status MindRecordOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows = num_rows_;
if (num_rows_ <= 0) {
// The last operator is parent sampler
std::shared_ptr<ShardOperator> op = operators_.back();
RETURN_IF_NOT_OK(CountTotalRows(dataset_file_, load_dataset_, op, &num_rows, num_padded_));
}
*dataset_size = num_rows;
dataset_size_ = *dataset_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -212,11 +212,6 @@ class MindRecordOp : public ParallelOp {
// @return Name of the current Op
std::string Name() const override { return "MindRecordOp"; }
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
private:
Status GetBufferFromReader(std::unique_ptr<DataBuffer> *fetched_buffer, int64_t buffer_id, int32_t worker_id);

View File

@ -471,19 +471,5 @@ Status MnistOp::ComputeColMap() {
return Status::OK();
}
// Get Dataset size
Status MnistOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size;
num_rows = num_rows_;
if (num_rows_ <= 0) RETURN_IF_NOT_OK(CountTotalRows(folder_path_, usage_, &num_rows));
sample_size = sampler_->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -168,11 +168,6 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
// @return Name of the current Op
std::string Name() const override { return "MnistOp"; }
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
private:
// Initialize Sampler, calls sampler->Init() within
// @return Status - The error code return

View File

@ -421,23 +421,5 @@ Status RandomDataOp::ComputeColMap() {
return Status::OK();
}
// Get Dataset size
Status RandomDataOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows;
num_rows = total_rows_ != 0 ? total_rows_ : data_schema_->num_rows();
if (sampler_ != nullptr) {
int64_t sample_size;
sample_size = sampler_->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
} else {
*dataset_size = num_rows;
}
dataset_size_ = *dataset_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -203,11 +203,6 @@ class RandomDataOp : public ParallelOp {
// @return Name of the current Op
std::string Name() const override { return "RandomDataOp"; }
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
private:
/**
* The entry point code for when workers are launched

View File

@ -162,11 +162,11 @@ Status DistributedSamplerRT::ResetSampler() {
}
int64_t DistributedSamplerRT::CalculateNumSamples(int64_t num_rows) {
int64_t childs = num_rows;
int64_t child_num_rows = num_rows;
if (!child_.empty()) {
childs = child_[0]->CalculateNumSamples(num_rows);
child_num_rows = child_[0]->CalculateNumSamples(num_rows);
}
int64_t num_samples = (num_samples_ > 0) ? std::min(childs, num_samples_) : childs;
int64_t num_samples = (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows;
return std::ceil(num_samples * 1.0 / num_devices_);
}

View File

@ -63,6 +63,11 @@ class DistributedSamplerRT : public SamplerRT {
int64_t GetDeviceNum() { return num_devices_; }
/// \brief Recursively calls this function on its children to get the actual number of samples on a tree of samplers
/// \note This is not a getter for num_samples_. For example, if num_samples_ is 0 or if it's smaller than num_rows,
/// then num_samples_ is not returned at all.
/// \param[in] num_rows The total number of rows in the dataset
/// \return int64_t Calculated number of samples
int64_t CalculateNumSamples(int64_t num_rows) override;
void Print(std::ostream &out, bool show_all) const override;

View File

@ -520,19 +520,5 @@ Status TextFileOp::Accept(NodePass *p, bool *modified) {
return p->RunOnNode(shared_from_base<TextFileOp>(), modified);
}
// Get Dataset size
Status TextFileOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size;
sample_size = total_rows_;
if (num_rows_per_shard_ <= 0) RETURN_IF_NOT_OK(CalculateNumRowsPerShard());
num_rows = num_rows_per_shard_;
*dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
dataset_size_ = *dataset_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -198,11 +198,6 @@ class TextFileOp : public ParallelOp {
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
private:
// The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function.

View File

@ -1067,41 +1067,5 @@ Status TFReaderOp::PrepareNodePostAction() {
return Status::OK();
}
// Get the file list of the specific shard ID
Status TFReaderOp::GetShardFileList(std::vector<std::string> *shard_filenames) {
if (!shard_filenames->empty()) {
RETURN_STATUS_UNEXPECTED("The initial file list must be empty.\n");
}
for (int index = 0; index < dataset_files_list_.size(); index++) {
if (index % num_devices_ == device_id_) {
shard_filenames->push_back(dataset_files_list_.at(index));
}
}
return Status::OK();
}
// Get Dataset size
Status TFReaderOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size;
num_rows = num_rows_;
if (num_rows_ <= 0) {
if (equal_rows_per_shard_) {
RETURN_IF_NOT_OK(CalculateNumRowsPerShard());
num_rows = num_rows_per_shard_;
} else {
std::vector<std::string> shard_file_list;
RETURN_IF_NOT_OK(GetShardFileList(&shard_file_list));
RETURN_IF_NOT_OK(CountTotalRows(&num_rows, shard_file_list));
}
}
sample_size = total_rows_ != 0 ? total_rows_ : data_schema_->num_rows();
*dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
dataset_size_ = *dataset_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -257,11 +257,6 @@ class TFReaderOp : public ParallelOp {
// before providing their own implementations.
Status PrepareNodePostAction() override;
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
static bool ValidateFirstRowCrc(const std::string &filename);
private:
@ -400,11 +395,6 @@ class TFReaderOp : public ParallelOp {
// @return - Status
Status ComputeColMap() override;
// Private function for computing the file list of the specific shard ID. This is because in distributed scenario,
// data will be divided into shards by row when equal_rows_per_shard is true, but by file in the opposite case.
// @return - Status - the status code returned.
Status GetShardFileList(std::vector<std::string> *shard_filenames);
int32_t device_id_;
int32_t num_devices_;
int64_t rows_per_buffer_;

View File

@ -447,16 +447,9 @@ Status VOCOp::ReadAnnotationToTensor(const std::string &path, TensorRow *row) {
return Status::OK();
}
#ifdef ENABLE_PYTHON
Status VOCOp::CountTotalRows(const std::string &dir, const std::string &task_type, const std::string &task_mode,
const py::dict &dict, int64_t *count) {
const std::map<std::string, int32_t> &input_class_indexing, int64_t *count) {
if (task_type == "Detection") {
std::map<std::string, int32_t> input_class_indexing;
for (auto p : dict) {
(void)input_class_indexing.insert(std::pair<std::string, int32_t>(py::reinterpret_borrow<py::str>(p.first),
py::reinterpret_borrow<py::int_>(p.second)));
}
std::shared_ptr<VOCOp> op;
RETURN_IF_NOT_OK(
Builder().SetDir(dir).SetTask(task_type).SetUsage(task_mode).SetClassIndex(input_class_indexing).Build(&op));
@ -473,6 +466,7 @@ Status VOCOp::CountTotalRows(const std::string &dir, const std::string &task_typ
return Status::OK();
}
#ifdef ENABLE_PYTHON
Status VOCOp::GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode,
const py::dict &dict, std::map<std::string, int32_t> *output_class_indexing) {
std::map<std::string, int32_t> input_class_indexing;
@ -516,36 +510,6 @@ Status VOCOp::ComputeColMap() {
return Status::OK();
}
// Get Dataset size
Status VOCOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows = 0, sample_size;
if (image_ids_.size() == 0) {
if (task_type_ == TaskType::Detection) {
std::shared_ptr<VOCOp> op;
RETURN_IF_NOT_OK(
Builder().SetDir(folder_path_).SetTask("Detection").SetUsage(usage_).SetClassIndex(class_index_).Build(&op));
RETURN_IF_NOT_OK(op->ParseImageIds());
RETURN_IF_NOT_OK(op->ParseAnnotationIds());
num_rows = static_cast<int64_t>(op->image_ids_.size());
} else if (task_type_ == TaskType::Segmentation) {
std::shared_ptr<VOCOp> op;
RETURN_IF_NOT_OK(Builder().SetDir(folder_path_).SetTask("Segmentation").SetUsage(usage_).Build(&op));
RETURN_IF_NOT_OK(op->ParseImageIds());
num_rows = static_cast<int64_t>(op->image_ids_.size());
}
} else {
num_rows = image_ids_.size();
}
sample_size = sampler_->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();
}
Status VOCOp::GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) {
if ((*output_class_indexing).empty()) {
if (task_type_ != TaskType::Detection) {

View File

@ -187,15 +187,15 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
// @param show_all
void Print(std::ostream &out, bool show_all) const override;
#ifdef ENABLE_PYTHON
// @param const std::string &dir - VOC dir path
// @param const std::string &task_type - task type of reading voc job
// @param const std::string &task_mode - task mode of reading voc job
// @param const py::dict &dict - input dict of class index
// @param const std::map<std::string, int32_t> input_class_indexing - input map of class index
// @param int64_t *count - output rows number of VOCDataset
static Status CountTotalRows(const std::string &dir, const std::string &task_type, const std::string &task_mode,
const py::dict &dict, int64_t *count);
const std::map<std::string, int32_t> &input_class_indexing, int64_t *count);
#ifdef ENABLE_PYTHON
// @param const std::string &dir - VOC dir path
// @param const std::string &task_type - task type of reading voc job
// @param const std::string &task_mode - task mode of reading voc job
@ -216,11 +216,6 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
// @return Name of the current Op
std::string Name() const override { return "VOCOp"; }
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
// /// \brief Gets the class indexing
// /// \return Status - The status code return
Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) override;

View File

@ -139,17 +139,5 @@ Status TakeOp::PreAccept(NodePass *p, bool *modified) {
return p->PreRunOnNode(shared_from_base<TakeOp>(), modified);
}
// Get Dataset size
Status TakeOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows;
RETURN_IF_NOT_OK(child_[0]->GetDatasetSize(&num_rows));
*dataset_size = std::min(static_cast<int64_t>(max_takes_), num_rows);
dataset_size_ = *dataset_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -94,11 +94,6 @@ class TakeOp : public PipelineOp {
// @return Name of the current Op
std::string Name() const override { return kTakeOp; }
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
private:
int32_t max_takes_; // The number of takes that the user requested
int32_t take_count_; // A counter for the current number of executed takes

View File

@ -248,24 +248,6 @@ Status ZipOp::Accept(NodePass *p, bool *modified) {
return p->RunOnNode(shared_from_base<ZipOp>(), modified);
}
// Get Dataset size
Status ZipOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
std::vector<int32_t> dataset_sizes;
int64_t child_dataset_size;
for (auto child : child_) {
RETURN_IF_NOT_OK(child->GetDatasetSize(&child_dataset_size));
dataset_sizes.push_back(child_dataset_size);
}
*dataset_size = *std::min_element(dataset_sizes.begin(), dataset_sizes.end());
dataset_size_ = *dataset_size;
return Status::OK();
}
Status ZipOp::ComputeColMap() {
if (column_name_id_map_.empty()) {
column_name_id_map_ = {};

View File

@ -120,11 +120,6 @@ class ZipOp : public PipelineOp {
// @return Name of the current Op
std::string Name() const override { return kZipOp; }
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
private:
// Handles preprocessing of the main loop, used when starting new epoch
Status prepare(TensorQTable *const table);

View File

@ -114,5 +114,33 @@ std::vector<std::shared_ptr<DatasetOp>> BatchNode::Build() {
return node_ops;
}
// Get Dataset size
Status BatchNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
#ifdef ENABLE_PYTHON
if (batch_size_func_) {
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), dataset_size));
dataset_size_ = *dataset_size;
return Status::OK();
}
#endif
int64_t num_rows;
RETURN_IF_NOT_OK(children_[0]->GetDatasetSize(size_getter, estimate, &num_rows));
if (num_rows > 0 && batch_size_ > 0) {
if (drop_remainder_) {
num_rows = static_cast<int64_t>(floor(num_rows / (1.0 * batch_size_)));
} else {
num_rows = static_cast<int64_t>(ceil(num_rows / (1.0 * batch_size_)));
}
}
*dataset_size = num_rows;
dataset_size_ = num_rows;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -64,6 +64,15 @@ class BatchNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Base-class override for GetDatasetSize
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;
private:
int32_t batch_size_;
bool drop_remainder_;

View File

@ -127,6 +127,5 @@ Status BucketBatchByLengthNode::ValidateParams() {
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -60,6 +60,8 @@ class BucketBatchByLengthNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
bool IsSizeDefined() override { return false; };
private:
std::vector<std::string> column_names_;
std::vector<int32_t> bucket_boundaries_;

View File

@ -58,6 +58,8 @@ class ConcatNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
bool IsSizeDefined() override { return false; }
private:
std::shared_ptr<SamplerObj> sampler_;
std::vector<std::pair<int, int>> children_flag_and_nums_;

View File

@ -342,6 +342,31 @@ Status DatasetNode::GetShardId(int32_t *shard_id) {
RETURN_STATUS_SYNTAX_ERROR("Get Shard Id failed at source node: " + Name() + "\n");
}
}
// Gets the dataset size
Status DatasetNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
if (!IsSizeDefined()) {
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), dataset_size));
dataset_size_ = *dataset_size;
return Status::OK();
}
if (children_.size() == 1) {
return children_[0]->GetDatasetSize(size_getter, estimate, dataset_size);
} else if (children_.size() > 1) {
// It is okay for dataset to have more than 1 child, GetDatasetSize shouldn't fail in this case.
// This is done mostly for cache, which injects cache lookup/merge operators. Cache path will
// always be in front of the child_ structure, so we get the dataset size from the last child.
return children_[children_.size() - 1]->GetDatasetSize(size_getter, estimate, dataset_size);
} else {
RETURN_STATUS_UNEXPECTED("Trying to get dataset size from leaf node, missing override");
}
}
// Visitor accepting method for NodePass
Status SourceNode::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor

View File

@ -25,6 +25,7 @@
#include <vector>
#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/consumers/tree_consumer.h"
namespace mindspore {
namespace dataset {
@ -32,6 +33,7 @@ namespace dataset {
class Dataset;
class SamplerObj;
class NodePass;
class DatasetSizeGetter;
#define RETURN_EMPTY_IF_ERROR(_s) \
do { \
@ -169,6 +171,14 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
/// \return Status Status::OK() if get shard id successfully
virtual Status GetShardId(int32_t *shard_id);
/// \brief Gets the dataset size
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \return Status - The status code return
virtual Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size);
/// \brief Getter function for child nodes
/// \return Child nodes
const std::vector<std::shared_ptr<DatasetNode>> Children() const { return children_; }
@ -219,10 +229,13 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
/// \notes Remove me after changing return val of Build()
Status BuildStatus() { return build_status; }
virtual bool IsSizeDefined() { return true; }
protected:
std::vector<std::shared_ptr<DatasetNode>> children_;
DatasetNode *parent_;
std::shared_ptr<DatasetCache> cache_;
int64_t dataset_size_ = -1;
int32_t num_workers_;
int32_t rows_per_buffer_;
int32_t connector_que_size_;

View File

@ -55,6 +55,8 @@ class FilterNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
bool IsSizeDefined() override { return false; };
/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified

View File

@ -56,6 +56,23 @@ Status RepeatNode::ValidateParams() {
return Status::OK();
}
// Get Dataset size
Status RepeatNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows;
RETURN_IF_NOT_OK(children_[0]->GetDatasetSize(size_getter, estimate, &num_rows));
if (num_rows > 0 && repeat_count_ > 0) {
num_rows = num_rows * repeat_count_;
}
*dataset_size = num_rows;
dataset_size_ = num_rows;
return Status::OK();
}
// Visitor accepting method for NodePass
Status RepeatNode::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor

View File

@ -56,6 +56,15 @@ class RepeatNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Base-class override for GetDatasetSize
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;
/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified

View File

@ -56,5 +56,22 @@ Status SkipNode::ValidateParams() {
return Status::OK();
}
// Get Dataset size
Status SkipNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows;
RETURN_IF_NOT_OK(children_[0]->GetDatasetSize(size_getter, estimate, &num_rows));
*dataset_size = 0;
if (skip_count_ >= 0 && skip_count_ < num_rows) {
*dataset_size = num_rows - skip_count_;
}
dataset_size_ = *dataset_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -54,6 +54,15 @@ class SkipNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Base-class override for GetDatasetSize
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;
private:
int32_t skip_count_;
};

View File

@ -21,6 +21,7 @@
#include <string>
#include <utility>
#include <vector>
#include <algorithm>
#include "minddata/dataset/engine/datasetops/source/celeba_op.h"
#include "minddata/dataset/util/status.h"
@ -87,5 +88,66 @@ Status CelebANode::GetShardId(int32_t *shard_id) {
return Status::OK();
}
// Get Dataset size
Status CelebANode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
int64_t num_rows, sample_size;
std::ifstream partition_file;
std::string line;
Path folder_path(dataset_dir_);
std::ifstream attr_file((folder_path / "list_attr_celeba.txt").toString());
if (!attr_file.is_open()) {
std::string attr_file_name = (folder_path / "list_attr_celeba.txt").toString();
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open Celeba attr file: " + attr_file_name);
}
std::string rows_num;
(void)getline(attr_file, rows_num);
try {
num_rows = static_cast<int64_t>(std::stoul(rows_num)); // First line is rows number in attr file
} catch (std::invalid_argument &e) {
RETURN_STATUS_UNEXPECTED(
"Invalid data, failed to convert rows_num from attr_file to unsigned long, invalid argument: " + rows_num);
} catch (std::out_of_range &e) {
RETURN_STATUS_UNEXPECTED(
"Invalid data, failed to convert rows_num from attr_file to unsigned long, out of range: " + rows_num);
}
if (usage_ != "all") {
int64_t partition_num = 0;
char usage_type;
if (usage_ == "train") {
usage_type = '0';
} else {
if (usage_ == "valid") {
usage_type = '1';
} else {
if (usage_ == "test")
usage_type = '2';
else
RETURN_STATUS_UNEXPECTED("Invalid usage.");
}
}
if (!partition_file.is_open()) {
partition_file.open((folder_path / "list_eval_partition.txt").toString());
}
if (partition_file.is_open()) {
while (getline(partition_file, line)) {
int start = line.find(' ');
if (line.at(start + 1) == usage_type) {
partition_num++;
}
}
} else {
std::string partition_file_name = "list_eval_partition.txt";
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open CelebA partition file: " + partition_file_name);
}
num_rows = std::min(num_rows, partition_num);
}
sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -61,6 +61,15 @@ class CelebANode : public MappableSourceNode {
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;
/// \brief Base-class override for GetDatasetSize
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;
private:
std::string dataset_dir_;
std::string usage_;

View File

@ -83,5 +83,20 @@ Status Cifar100Node::GetShardId(int32_t *shard_id) {
return Status::OK();
}
// Get Dataset size
Status Cifar100Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size;
RETURN_IF_NOT_OK(CifarOp::CountTotalRows(dataset_dir_, usage_, false, &num_rows));
sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -59,6 +59,15 @@ class Cifar100Node : public MappableSourceNode {
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;
/// \brief Base-class override for GetDatasetSize
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;
private:
std::string dataset_dir_;
std::string usage_;

View File

@ -81,5 +81,20 @@ Status Cifar10Node::GetShardId(int32_t *shard_id) {
return Status::OK();
}
// Get Dataset size
Status Cifar10Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size;
RETURN_IF_NOT_OK(CifarOp::CountTotalRows(dataset_dir_, usage_, true, &num_rows));
sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -59,6 +59,15 @@ class Cifar10Node : public MappableSourceNode {
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;
/// \brief Base-class override for GetDatasetSize
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;
private:
std::string dataset_dir_;
std::string usage_;

View File

@ -241,5 +241,21 @@ Status CLUENode::GetShardId(int32_t *shard_id) {
return Status::OK();
}
// Get Dataset size
Status CLUENode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size;
RETURN_IF_NOT_OK(ClueOp::CountAllFileRows(dataset_files_, &num_rows));
sample_size = num_samples_;
num_rows = static_cast<int64_t>(ceil(num_rows / (1.0 * num_shards_)));
*dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
dataset_size_ = *dataset_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -61,6 +61,15 @@ class CLUENode : public NonMappableSourceNode {
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;
/// \brief Base-class override for GetDatasetSize
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;
private:
/// \brief Split string based on a character delimiter
/// \return A string vector

View File

@ -134,5 +134,20 @@ Status CocoNode::GetShardId(int32_t *shard_id) {
return Status::OK();
}
// Get Dataset size
Status CocoNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows = 0, sample_size;
RETURN_IF_NOT_OK(CocoOp::CountTotalRows(dataset_dir_, annotation_file_, task_, &num_rows));
sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -59,6 +59,15 @@ class CocoNode : public MappableSourceNode {
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;
/// \brief Base-class override for GetDatasetSize
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;
private:
std::string dataset_dir_;
std::string annotation_file_;

View File

@ -153,5 +153,21 @@ Status CSVNode::GetShardId(int32_t *shard_id) {
return Status::OK();
}
// Get Dataset size
Status CSVNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size;
RETURN_IF_NOT_OK(CsvOp::CountAllFileRows(dataset_files_, column_names_.empty(), &num_rows));
sample_size = num_samples_;
num_rows = static_cast<int64_t>(ceil(num_rows / (1.0 * num_shards_)));
*dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
dataset_size_ = *dataset_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -82,6 +82,15 @@ class CSVNode : public NonMappableSourceNode {
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;
/// \brief Base-class override for GetDatasetSize
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;
private:
std::vector<std::string> dataset_files_;
char field_delim_;

View File

@ -93,6 +93,5 @@ Status GeneratorNode::GetShardId(int32_t *shard_id) {
*shard_id = 0;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -64,6 +64,13 @@ class GeneratorNode : public MappableSourceNode {
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;
/// \brief Setter for DatasetSize in GeneratorNode
/// \param[in] sz dataset size to set
/// \return void
void SetGeneratorDatasetSize(int64_t sz) { dataset_size_ = sz; }
bool IsSizeDefined() override { return false; }
private:
py::function generator_function_;
std::vector<std::string> column_names_;

View File

@ -89,5 +89,20 @@ Status ImageFolderNode::GetShardId(int32_t *shard_id) {
return Status::OK();
}
// Get Dataset size
Status ImageFolderNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t sample_size, num_rows;
RETURN_IF_NOT_OK(ImageFolderOp::CountRowsAndClasses(dataset_dir_, exts_, &num_rows, nullptr, {}));
sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -65,6 +65,15 @@ class ImageFolderNode : public MappableSourceNode {
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;
/// \brief Base-class override for GetDatasetSize
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;
private:
std::string dataset_dir_;
bool decode_;

View File

@ -111,5 +111,21 @@ Status ManifestNode::GetShardId(int32_t *shard_id) {
return Status::OK();
}
// Get Dataset size
Status ManifestNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size;
int64_t num_classes; // dummy variable
RETURN_IF_NOT_OK(ManifestOp::CountTotalRows(dataset_file_, class_index_, usage_, &num_rows, &num_classes));
sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -60,6 +60,15 @@ class ManifestNode : public MappableSourceNode {
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;
/// \brief Base-class override for GetDatasetSize
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;
private:
std::string dataset_file_;
std::string usage_;

View File

@ -152,7 +152,6 @@ std::vector<std::shared_ptr<DatasetOp>> MindDataNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
std::vector<std::shared_ptr<ShardOperator>> operators_;
build_status = BuildMindDatasetSamplerChain(sampler_, &operators_, num_padded_);
RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build()
@ -184,5 +183,28 @@ Status MindDataNode::GetShardId(int32_t *shard_id) {
return Status::OK();
}
// Get Dataset size
Status MindDataNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows;
std::vector<std::shared_ptr<ShardOperator>> operators;
RETURN_IF_NOT_OK(BuildMindDatasetSamplerChain(sampler_, &operators, num_padded_));
if (search_for_pattern_) {
dataset_files_ = {dataset_file_};
}
// The last operator is parent sampler
std::shared_ptr<ShardOperator> op = operators.back();
RETURN_IF_NOT_OK(MindRecordOp::CountTotalRows(dataset_files_, search_for_pattern_, op, &num_rows, num_padded_));
*dataset_size = num_rows;
dataset_size_ = *dataset_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -74,6 +74,15 @@ class MindDataNode : public MappableSourceNode {
/// \note Pybind will use this function to set sample_bytes into MindDataNode
void SetSampleBytes(std::map<std::string, std::string> *sample_bytes);
/// \brief Base-class override for GetDatasetSize
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;
private:
std::string dataset_file_; // search_for_pattern_ will be true in this mode
std::vector<std::string> dataset_files_; // search_for_pattern_ will be false in this mode
@ -83,6 +92,7 @@ class MindDataNode : public MappableSourceNode {
nlohmann::json padded_sample_;
std::map<std::string, std::string> sample_bytes_; // enable in python
int64_t num_padded_;
std::vector<std::shared_ptr<ShardOperator>> operators_;
};
} // namespace dataset

View File

@ -75,5 +75,20 @@ Status MnistNode::GetShardId(int32_t *shard_id) {
return Status::OK();
}
// Get Dataset size
Status MnistNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size;
RETURN_IF_NOT_OK(MnistOp::CountTotalRows(dataset_dir_, usage_, &num_rows));
sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -59,6 +59,15 @@ class MnistNode : public MappableSourceNode {
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;
/// \brief Base-class override for GetDatasetSize
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;
private:
std::string dataset_dir_;
std::string usage_;

View File

@ -86,17 +86,16 @@ std::vector<std::shared_ptr<DatasetOp>> RandomNode::Build() {
schema_file_path = schema_path_;
}
std::unique_ptr<DataSchema> data_schema;
std::vector<std::string> columns_to_load;
if (columns_list_.size() > 0) {
columns_to_load = columns_list_;
}
if (!schema_file_path.empty() || !schema_json_string.empty()) {
data_schema = std::make_unique<DataSchema>();
data_schema_ = std::make_unique<DataSchema>();
if (!schema_file_path.empty()) {
data_schema->LoadSchemaFile(schema_file_path, columns_to_load);
data_schema_->LoadSchemaFile(schema_file_path, columns_to_load);
} else if (!schema_json_string.empty()) {
data_schema->LoadSchemaString(schema_json_string, columns_to_load);
data_schema_->LoadSchemaString(schema_json_string, columns_to_load);
}
}
@ -109,7 +108,7 @@ std::vector<std::shared_ptr<DatasetOp>> RandomNode::Build() {
std::shared_ptr<RandomDataOp> op;
op = std::make_shared<RandomDataOp>(num_workers_, connector_que_size_, rows_per_buffer_, total_rows_,
std::move(data_schema), std::move(sampler_->Build()));
std::move(data_schema_), std::move(sampler_->Build()));
build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
RETURN_EMPTY_IF_ERROR(build_status);
@ -125,5 +124,24 @@ Status RandomNode::GetShardId(int32_t *shard_id) {
return Status::OK();
}
// Get Dataset size
Status RandomNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows;
num_rows = total_rows_ != 0 ? total_rows_ : data_schema_->num_rows();
if (sampler_ != nullptr) {
int64_t sample_size;
sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
} else {
*dataset_size = num_rows;
}
dataset_size_ = *dataset_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -79,6 +79,15 @@ class RandomNode : public NonMappableSourceNode {
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;
/// \brief Base-class override for GetDatasetSize
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;
private:
/// \brief A quick inline for producing a random number between (and including) min/max
/// \param[in] min minimum number that can be generated.
@ -92,6 +101,7 @@ class RandomNode : public NonMappableSourceNode {
std::vector<std::string> columns_list_;
std::shared_ptr<SamplerObj> sampler_;
std::mt19937 rand_gen_;
std::unique_ptr<DataSchema> data_schema_;
};
} // namespace dataset

View File

@ -122,5 +122,20 @@ Status TextFileNode::GetShardId(int32_t *shard_id) {
return Status::OK();
}
// Get Dataset size
Status TextFileNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size = num_samples_;
RETURN_IF_NOT_OK(TextFileOp::CountAllFileRows(dataset_files_, &num_rows));
num_rows = static_cast<int64_t>(ceil(num_rows / (1.0 * num_shards_)));
*dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
dataset_size_ = *dataset_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -61,6 +61,15 @@ class TextFileNode : public NonMappableSourceNode {
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;
/// \brief Base-class override for GetDatasetSize
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;
private:
std::vector<std::string> dataset_files_;
int32_t num_samples_;

View File

@ -169,5 +169,41 @@ Status TFRecordNode::GetShardId(int32_t *shard_id) {
return Status::OK();
}
// Get Dataset size
Status TFRecordNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows;
if (!shard_equal_rows_) {
// Data will be sharded by file
std::vector<std::string> shard_file_list;
RETURN_IF_NOT_OK(GetShardFileList(&shard_file_list));
RETURN_IF_NOT_OK(TFReaderOp::CountTotalRows(&num_rows, shard_file_list, 8, estimate));
} else {
// Data will be sharded by row
RETURN_IF_NOT_OK(TFReaderOp::CountTotalRows(&num_rows, dataset_files_, 8, estimate));
num_rows = static_cast<int64_t>(ceil(num_rows / (num_shards_ * 1.0)));
}
*dataset_size = num_samples_ > 0 ? std::min(num_rows, num_samples_) : num_rows;
dataset_size_ = *dataset_size;
return Status::OK();
}
// Get the file list of the specific shard ID
Status TFRecordNode::GetShardFileList(std::vector<std::string> *shard_filenames) {
if (!shard_filenames->empty()) {
RETURN_STATUS_UNEXPECTED("The initial file list must be empty.");
}
for (int index = 0; index < dataset_files_.size(); index++) {
if (index % num_shards_ == shard_id_) {
shard_filenames->push_back(dataset_files_.at(index));
}
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -88,6 +88,20 @@ class TFRecordNode : public NonMappableSourceNode {
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;
/// \brief Base-class override for GetDatasetSize
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;
/// \brief Get the file list of the specific shard ID
/// \param[out] shard_filenames the list of filenames for that specific shard ID
/// \return Status of the function
Status GetShardFileList(std::vector<std::string> *shard_filenames);
private:
std::vector<std::string> dataset_files_;
std::string schema_path_; // schema_path_ path to schema file. It is set when type of schema parameter is string

View File

@ -128,5 +128,20 @@ Status VOCNode::GetShardId(int32_t *shard_id) {
return Status::OK();
}
// Get Dataset size
Status VOCNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows = 0, sample_size;
RETURN_IF_NOT_OK(VOCOp::CountTotalRows(dataset_dir_, task_, usage_, class_index_, &num_rows));
sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -61,6 +61,15 @@ class VOCNode : public MappableSourceNode {
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;
/// \brief Base-class override for GetDatasetSize
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;
private:
const std::string kColumnImage = "image";
const std::string kColumnTarget = "target";

View File

@ -19,6 +19,7 @@
#include <memory>
#include <string>
#include <vector>
#include <algorithm>
#include "minddata/dataset/engine/datasetops/take_op.h"
#include "minddata/dataset/util/status.h"
@ -56,5 +57,19 @@ Status TakeNode::ValidateParams() {
return Status::OK();
}
// Get Dataset size
Status TakeNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows;
RETURN_IF_NOT_OK(children_[0]->GetDatasetSize(size_getter, estimate, &num_rows));
*dataset_size = std::min(static_cast<int64_t>(take_count_), num_rows);
dataset_size_ = *dataset_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -54,6 +54,15 @@ class TakeNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Base-class override for GetDatasetSize
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;
private:
int32_t take_count_;
};

Some files were not shown because too many files have changed in this diff Show More