!9249 Optimizing GetDatasetSize
From: @mahdirahmanihanzaki
commit df44e1339e
@@ -199,12 +199,13 @@ bool Dataset::Save(std::string dataset_path, int32_t num_files, std::string data
 // Constructor
 Dataset::Dataset() { tree_getters_ = std::make_shared<TreeGetters>(); }

-int64_t Dataset::GetDatasetSize() {
+int64_t Dataset::GetDatasetSize(bool estimate) {
   int64_t dataset_size;
   std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
   RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
-  RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), -1);
-  RETURN_SECOND_IF_ERROR(tree_getters_->GetDatasetSize(&dataset_size), -1);
+  std::shared_ptr<DatasetSizeGetter> size_getter = std::make_shared<DatasetSizeGetter>();
+  RETURN_SECOND_IF_ERROR(size_getter->Init(this->IRNode()), -1);
+  RETURN_SECOND_IF_ERROR(size_getter->GetDatasetSize(&dataset_size, estimate), -1);
   return dataset_size;
 }
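Taken on its own, this first hunk is the whole API-level story: size queries now route through a dedicated DatasetSizeGetter consumer (defined later in this diff) instead of the general-purpose TreeGetters, and callers gain an estimate flag. A standalone toy mirroring the call pattern, not the MindSpore API itself (the default of false is an assumption that matches the DatasetSizeGetter declaration further down):

    #include <cstdint>
    #include <iostream>

    // Toy sketch: one entry point serves exact and estimated queries, and -1
    // signals failure, matching the RETURN_SECOND_IF_ERROR(..., -1) convention.
    int64_t GetDatasetSize(bool estimate = false) {
      if (estimate) return 1000;  // fast path: some leaf formats can guess
      return 998;                 // slow path: exact count
    }

    int main() {
      std::cout << GetDatasetSize() << "\n";      // exact
      std::cout << GetDatasetSize(true) << "\n";  // estimated, faster
    }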
@@ -106,19 +106,7 @@ PYBIND_REGISTER(ImageFolderOp, 1, ([](const py::module *m) {
                 }));

 PYBIND_REGISTER(ManifestOp, 1, ([](const py::module *m) {
-                  (void)py::class_<ManifestOp, DatasetOp, std::shared_ptr<ManifestOp>>(*m, "ManifestOp")
-                    .def_static("get_num_rows_and_classes",
-                                [](const std::string &file, const py::dict &dict, const std::string &usage) {
-                                  int64_t count = 0, num_classes = 0;
-                                  THROW_IF_ERROR(ManifestOp::CountTotalRows(file, dict, usage, &count, &num_classes));
-                                  return py::make_tuple(count, num_classes);
-                                })
-                    .def_static("get_class_indexing", [](const std::string &file, const py::dict &dict,
-                                                         const std::string &usage) {
-                      std::map<std::string, int32_t> output_class_indexing;
-                      THROW_IF_ERROR(ManifestOp::GetClassIndexing(file, dict, usage, &output_class_indexing));
-                      return output_class_indexing;
-                    });
+                  (void)py::class_<ManifestOp, DatasetOp, std::shared_ptr<ManifestOp>>(*m, "ManifestOp");
                 }));

 PYBIND_REGISTER(MindRecordOp, 1, ([](const py::module *m) {
                   (void)py::class_<MindRecordOp, DatasetOp, std::shared_ptr<MindRecordOp>>(*m, "MindRecordOp")
@@ -173,13 +161,6 @@ PYBIND_REGISTER(TFReaderOp, 1, ([](const py::module *m) {

 PYBIND_REGISTER(VOCOp, 1, ([](const py::module *m) {
                   (void)py::class_<VOCOp, DatasetOp, std::shared_ptr<VOCOp>>(*m, "VOCOp")
-                    .def_static("get_num_rows",
-                                [](const std::string &dir, const std::string &task_type, const std::string &task_mode,
-                                   const py::dict &dict, int64_t numSamples) {
-                                  int64_t count = 0;
-                                  THROW_IF_ERROR(VOCOp::CountTotalRows(dir, task_type, task_mode, dict, &count));
-                                  return count;
-                                })
                     .def_static("get_class_indexing", [](const std::string &dir, const std::string &task_type,
                                                          const std::string &task_mode, const py::dict &dict) {
                       std::map<std::string, int32_t> output_class_indexing;
@@ -184,7 +184,11 @@ PYBIND_REGISTER(GeneratorNode, 2, ([](const py::module *m) {
                     auto gen = std::make_shared<GeneratorNode>(generator_function, schema);
                     THROW_IF_ERROR(gen->ValidateParams());
                     return gen;
-                  }));
+                  }))
+                  .def("SetGeneratorDatasetSize", [](std::shared_ptr<GeneratorNode> self, int64_t sz) {
+                    self->SetGeneratorDatasetSize(sz);
+                    return self;
+                  });
                 }));

 PYBIND_REGISTER(ImageFolderNode, 2, ([](const py::module *m) {
@@ -93,12 +93,6 @@ PYBIND_REGISTER(TreeGetters, 1, ([](const py::module *m) {
                     THROW_IF_ERROR(self.GetClassIndexing(&output_class_indexing));
                     return output_class_indexing;
                   })
-                  .def("GetDatasetSize",
-                       [](PythonTreeGetters &self) {
-                         int64_t dataset_size;
-                         THROW_IF_ERROR(self.GetDatasetSize(&dataset_size));
-                         return dataset_size;
-                       })
                   .def("__deepcopy__", [](py::object &tree_getter, py::dict memo) { return tree_getter; });
                 }));
@@ -164,5 +158,18 @@ PYBIND_REGISTER(PythonSaveToDisk, 1, ([](const py::module *m) {
                     .def("Save", [](PythonSaveToDisk &self) { THROW_IF_ERROR(self.Save()); });
                 }));

+PYBIND_REGISTER(PythonDatasetSizeGetter, 1, ([](const py::module *m) {
+                  (void)py::class_<PythonDatasetSizeGetter, TreeConsumer, std::shared_ptr<PythonDatasetSizeGetter>>(
+                    *m, "DatasetSizeGetters")
+                    .def(py::init<>())
+                    .def("Init", [](PythonDatasetSizeGetter &self,
+                                    std::shared_ptr<DatasetNode> d) { THROW_IF_ERROR(self.Init(d)); })
+                    .def("GetDatasetSize", [](PythonDatasetSizeGetter &self, bool estimate) {
+                      int64_t size;
+                      THROW_IF_ERROR(self.GetDatasetSize(&size, estimate));
+                      return size;
+                    });
+                }));
+
 } // namespace dataset
 } // namespace mindspore
@@ -65,4 +65,8 @@ Status PythonTreeGetters::GetRow(TensorRow *r) {
   py::gil_scoped_release gil_release;
   return TreeGetters::GetRow(r);
 }
+Status PythonDatasetSizeGetter::GetRow(const std::shared_ptr<TreeAdapter> &tree_adapter, TensorRow *r) {
+  py::gil_scoped_release gil_release;
+  return DatasetSizeGetter::GetRow(tree_adapter, r);
+}
 } // namespace mindspore::dataset
@@ -60,5 +60,9 @@ class PythonTreeGetters : public TreeGetters {
  public:
   Status GetRow(TensorRow *r) override;
 };
+class PythonDatasetSizeGetter : public DatasetSizeGetter {
+ public:
+  Status GetRow(const std::shared_ptr<TreeAdapter> &tree_adapter, TensorRow *r) override;
+};
 } // namespace mindspore::dataset
 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMERS_PYTHON_TREE_CONSUMER_H_
@@ -451,29 +451,6 @@ Status TreeGetters::Init(std::shared_ptr<DatasetNode> d) {

 Status TreeGetters::GetRow(TensorRow *row) { return tree_adapter_->GetNext(row); }

-Status TreeGetters::GetDatasetSize(int64_t *dataset_size) {
-  if (dataset_size_ == -1) {
-    RETURN_IF_NOT_OK(InternalInit(static_cast<int8_t>(GetterPass::kDatasetSize)));
-    std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
-    RETURN_UNEXPECTED_IF_NULL(root);
-    RETURN_IF_NOT_OK(root->GetDatasetSize(dataset_size));
-    if (*dataset_size == -1) {  // run through the tree and get everything
-      TensorRow row;
-      RETURN_IF_NOT_OK(GetRow(&row));
-      int64_t row_cnt = 0;
-      while (!row.empty()) {
-        ++row_cnt;
-        RETURN_IF_NOT_OK(GetRow(&row));
-      }
-      *dataset_size = row_cnt;
-    }
-    dataset_size_ = *dataset_size;  // save the previous result
-  }
-
-  *dataset_size = dataset_size_;
-  return Status::OK();
-}
-
 Status TreeGetters::GetOutputTypes(std::vector<DataType> *types) {
   RETURN_IF_NOT_OK(GetFirstRowShapeAndType());
   *types = first_row_type_;
@@ -573,5 +550,46 @@ Status BuildVocabConsumer::Start() {
   CHECK_FAIL_RETURN_UNEXPECTED(row.empty(), "The fetched row from BuildVocab should be an EOE.");
   return Status::OK();
 }
+Status DatasetSizeGetter::GetDatasetSize(int64_t *size, bool estimate) {
+  if (dataset_size_ == -1) {
+    RETURN_IF_NOT_OK(root_->GetDatasetSize(shared_from_this(), estimate, size));
+    dataset_size_ = *size;  // save the previous result
+  }
+
+  *size = dataset_size_;
+  return Status::OK();
+}
+Status DatasetSizeGetter::Init(std::shared_ptr<DatasetNode> d) {
+  root_ = std::move(d);
+  return Status::OK();
+}
+Status DatasetSizeGetter::DryRun(std::shared_ptr<DatasetNode> ir_node, int64_t *dataset_size) {
+  std::shared_ptr<TreeAdapter> tree_adapter = std::make_shared<TreeAdapter>();
+  tree_adapters_.push_back(tree_adapter);
+  tree_adapter->SetPrePassOverride([](OptPass pre) {
+    pre.push_back(
+      std::make_unique<GetterPass>(static_cast<GetterPass::GetterType>(GetterPass::GetterType::kDatasetSize)));
+    return pre;
+  });
+  RETURN_IF_NOT_OK(tree_adapter->Compile(std::move(ir_node), 1));
+  TensorRow row;
+  RETURN_IF_NOT_OK(GetRow(tree_adapter, &row));
+  int64_t row_cnt = 0;
+  while (!row.empty()) {
+    ++row_cnt;
+    RETURN_IF_NOT_OK(GetRow(tree_adapter, &row));
+  }
+  *dataset_size = row_cnt;
+  return Status::OK();
+}
+Status DatasetSizeGetter::GetRow(const std::shared_ptr<TreeAdapter> &tree_adapter, TensorRow *row) {
+  return tree_adapter->GetNext(row);
+}
+Status DatasetSizeGetter::Terminate() {
+  for (const auto &tree : tree_adapters_) {
+    RETURN_IF_NOT_OK(tree->AllTasks()->ServiceStop());
+  }
+  return Status::OK();
+}
 } // namespace dataset
 } // namespace mindspore
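The dry run above is simply "compile a fresh execution tree, then pull rows until the empty end-of-epoch row comes back, counting as you go". A standalone sketch of that counting loop, with a stub iterator standing in for MindSpore's TreeAdapter:

    #include <cstddef>
    #include <cstdint>
    #include <iostream>
    #include <optional>
    #include <vector>

    // Stub standing in for TreeAdapter::GetNext(); empty means exhausted.
    struct StubTree {
      std::vector<int> rows{1, 2, 3, 4};
      std::size_t pos = 0;
      std::optional<int> GetNext() {
        if (pos >= rows.size()) return std::nullopt;  // models the empty EOE row
        return rows[pos++];
      }
    };

    int64_t DryRunCount(StubTree &tree) {
      int64_t row_cnt = 0;
      while (tree.GetNext().has_value()) ++row_cnt;  // count until exhaustion
      return row_cnt;
    }

    int main() {
      StubTree tree;
      std::cout << DryRunCount(tree) << "\n";  // prints 4
    }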
@@ -177,7 +177,6 @@ class TreeGetters : public TreeConsumer {
   ~TreeGetters() = default;
   Status Init(std::shared_ptr<DatasetNode> d) override;

-  Status GetDatasetSize(int64_t *size);
   Status GetOutputTypes(std::vector<DataType> *types);
   Status GetOutputShapes(std::vector<TensorShape> *shapes);
   Status GetBatchSize(int64_t *batch_size);
@@ -186,7 +185,7 @@ class TreeGetters : public TreeConsumer {
   Status GetColumnNames(std::vector<std::string> *output);
   Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing);
   std::string Name() override { return "TreeGetters"; }
-  virtual Status GetRow(TensorRow *r);
+  virtual Status GetRow(TensorRow *row);

  private:
   Status GetFirstRowShapeAndType();
@@ -202,6 +201,35 @@ class TreeGetters : public TreeConsumer {
   Status InternalInit();
 };

+/// Consumer that is used to get some pipeline information
+class DatasetSizeGetter : public TreeConsumer, public std::enable_shared_from_this<DatasetSizeGetter> {
+ public:
+  DatasetSizeGetter() : dataset_size_(-1) {}
+  ~DatasetSizeGetter() = default;
+  Status Init(std::shared_ptr<DatasetNode> d) override;
+  Status Terminate() override;
+
+  /// \brief Function to get the dataset size
+  /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
+  ///     dataset size at the expense of accuracy.
+  /// \param[out] dataset_size the size of the dataset
+  /// \return Status of the function
+  Status GetDatasetSize(int64_t *size, bool estimate = false);
+
+  virtual Status GetRow(const std::shared_ptr<TreeAdapter> &tree_adapter, TensorRow *row);
+  std::string Name() override { return "DatasetSizeGetter"; }
+
+  /// \brief Gets the dataset size by iterating over the entire dataset on a sub tree starting from ir_node
+  /// \param[in] ir_node The node that marks the topmost node of the sub tree on which we want to iterate
+  /// \return Status - The status code returned
+  Status DryRun(std::shared_ptr<DatasetNode> ir_node, int64_t *dataset_size);
+
+ private:
+  std::shared_ptr<DatasetNode> root_;
+  std::vector<std::shared_ptr<TreeAdapter>> tree_adapters_;
+  int64_t dataset_size_;
+};
+
 class BuildVocabConsumer : public TreeConsumer {
  public:
   /// BuildVocabConsumer Constructor which will call the base class default constructor.
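One detail worth noting: DatasetSizeGetter inherits std::enable_shared_from_this because its GetDatasetSize (seen earlier in tree_consumer.cc) passes shared_from_this() down into the IR tree. That only works when the getter is owned by a shared_ptr, which the pybind registration and the C++ front end both guarantee. A minimal sketch of the constraint, not MindSpore code:

    #include <iostream>
    #include <memory>

    // shared_from_this() requires an existing shared_ptr owner; otherwise
    // (since C++17) it throws std::bad_weak_ptr.
    struct Getter : std::enable_shared_from_this<Getter> {
      std::shared_ptr<Getter> Self() { return shared_from_this(); }
    };

    int main() {
      auto ok = std::make_shared<Getter>();
      std::cout << ok->Self().use_count() << "\n";  // fine: 2 owners

      Getter stack_obj;
      try {
        stack_obj.Self();  // no shared_ptr owner
      } catch (const std::bad_weak_ptr &) {
        std::cout << "bad_weak_ptr\n";
      }
    }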
@@ -537,30 +537,6 @@ Status BatchOp::ComputeColMap() {
   return Status::OK();
 }

-Status BatchOp::GetDatasetSize(int64_t *dataset_size) {
-  if (dataset_size_ > 0) {
-    *dataset_size = dataset_size_;
-    return Status::OK();
-  }
-#ifdef ENABLE_PYTHON
-  if (batch_size_func_) {
-    *dataset_size = -1;
-    return Status::OK();
-  }
-#endif
-  int64_t num_rows;
-  RETURN_IF_NOT_OK(child_[0]->GetDatasetSize(&num_rows));
-  if (num_rows > 0 && start_batch_size_ > 0) {
-    if (drop_) {
-      num_rows = static_cast<int64_t>(floor(num_rows / (1.0 * start_batch_size_)));
-    } else {
-      num_rows = static_cast<int64_t>(ceil(num_rows / (1.0 * start_batch_size_)));
-    }
-  }
-  *dataset_size = num_rows;
-  dataset_size_ = num_rows;
-  return Status::OK();
-}
 int64_t BatchOp::GetTreeBatchSize() {
 #ifdef ENABLE_PYTHON
   if (batch_size_func_) {
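The floor/ceil arithmetic removed here reappears verbatim in BatchNode::GetDatasetSize near the end of this diff. The rule is: with drop_remainder the final partial batch is discarded (floor), otherwise it still counts as a batch (ceil). A standalone sketch with worked numbers:

    #include <cstdint>
    #include <iostream>

    // Sketch of the batch-size arithmetic above (not MindSpore code).
    int64_t BatchedSize(int64_t num_rows, int64_t batch_size, bool drop_remainder) {
      if (num_rows <= 0 || batch_size <= 0) return num_rows;
      return drop_remainder ? num_rows / batch_size                 // floor
                            : (num_rows + batch_size - 1) / batch_size;  // ceil
    }

    int main() {
      std::cout << BatchedSize(10, 3, true) << "\n";   // 3: partial batch dropped
      std::cout << BatchedSize(10, 3, false) << "\n";  // 4: partial batch kept
    }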
@@ -225,11 +225,6 @@ class BatchOp : public ParallelOp {
   static Status PadColumns(std::unique_ptr<TensorQTable> *table, const PadInfo &pad_info,
                            const std::unordered_map<std::string, int32_t> &column_name_id_map);

-  /// \brief Base-class override for GetDatasetSize
-  /// \param[out] dataset_size the size of the dataset
-  /// \return Status of the function
-  Status GetDatasetSize(int64_t *dataset_size) override;
-
   int64_t GetTreeBatchSize() override;

  protected:
@@ -232,12 +232,5 @@ Status BucketBatchByLengthOp::ComputeColMap() {
   return Status::OK();
 }

-// Get Dataset size
-Status BucketBatchByLengthOp::GetDatasetSize(int64_t *dataset_size) {
-  // We are returning -1 because we can't easily calculate GetDatasetSize. Returning -1 will make TreeGetters to
-  // iterate over the dataset and count the size
-  *dataset_size = dataset_size_;
-  return Status::OK();
-}
 } // namespace dataset
 } // namespace mindspore
@@ -112,11 +112,6 @@ class BucketBatchByLengthOp : public PipelineOp {

   std::string Name() const override { return kBucketBatchByLengthOp; }

-  /// \brief Base-class override for GetDatasetSize
-  /// \param[out] dataset_size the size of the dataset
-  /// \return Status of the function
-  Status GetDatasetSize(int64_t *dataset_size) override;
-
   // << Stream output operator overload
   // @notes This allows you to write the debug print info using stream operators
   // @param out - reference to the output stream being overloaded
@@ -196,12 +196,5 @@ Status ConcatOp::PreAccept(NodePass *p, bool *modified) {
   return p->PreRunOnNode(shared_from_base<ConcatOp>(), modified);
 }

-// Get Dataset size
-Status ConcatOp::GetDatasetSize(int64_t *dataset_size) {
-  // We are returning -1 because we can't easily calculate GetDatasetSize. Returning -1 will make TreeGetters to
-  // iterate over the dataset and count the size
-  *dataset_size = dataset_size_;
-  return Status::OK();
-}
 } // namespace dataset
 } // namespace mindspore
@@ -111,11 +111,6 @@ class ConcatOp : public PipelineOp {
   /// \return Status of the node visit
   Status PreAccept(NodePass *p, bool *modified) override;

-  /// \brief Base-class override for GetDatasetSize
-  /// \param[out] dataset_size the size of the dataset
-  /// \return Status of the function
-  Status GetDatasetSize(int64_t *dataset_size) override;
-
  private:
   Status Verify(int32_t id, const std::unique_ptr<DataBuffer> &buf);
@@ -294,24 +294,6 @@ Status DatasetOp::GetNextInput(std::unique_ptr<DataBuffer> *p_buffer, int32_t wo
   return Status::OK();
 }

-// Gets the dataset size
-Status DatasetOp::GetDatasetSize(int64_t *dataset_size) {
-  if (dataset_size_ > 0) {
-    *dataset_size = dataset_size_;
-    return Status::OK();
-  }
-  if (child_.size() == 1) {
-    return child_[0]->GetDatasetSize(dataset_size);
-  } else if (child_.size() > 1) {
-    // It is okay for dataset to have more than 1 child, GetDatasetSize shouldn't fail in this case.
-    // This is done mostly for cache, which injects cache lookup/merge operators. Cache path will
-    // always be in front of the child_ structure, so we get the dataset size from the last child.
-    return child_[child_.size() - 1]->GetDatasetSize(dataset_size);
-  } else {
-    RETURN_STATUS_UNEXPECTED("Trying to get dataset size from leaf node, missing override");
-  }
-}
-
 // Gets the number of classes
 Status DatasetOp::GetNumClasses(int64_t *num_classes) {
   if (child_.size() == 1) {
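The same child-recursion strategy moves to DatasetNode::GetDatasetSize later in this diff: one child means ask it, several children (the cache-injection case) means ask the last one, and a leaf with no override is an error. A toy model of that dispatch, not MindSpore's class:

    #include <cstdint>
    #include <iostream>
    #include <memory>
    #include <stdexcept>
    #include <vector>

    struct Node {
      int64_t size = -1;  // leaves carry a real size; pipeline nodes don't
      std::vector<std::shared_ptr<Node>> children;
      int64_t GetSize() const {
        if (size > 0) return size;
        if (children.size() == 1) return children[0]->GetSize();
        if (children.size() > 1) return children.back()->GetSize();  // cache case
        throw std::runtime_error("leaf without an override");
      }
    };

    int main() {
      auto leaf = std::make_shared<Node>();
      leaf->size = 42;
      Node pipeline;
      pipeline.children = {leaf};
      std::cout << pipeline.GetSize() << "\n";  // 42
    }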
@@ -180,10 +180,6 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
   /// \return Status - The error code return
   Status GetNextInput(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id = 0, int32_t child_index = 0);

-  /// \brief Gets the dataset size
-  /// \return Status - The status code return
-  virtual Status GetDatasetSize(int64_t *dataset_size);
-
   /// \brief Gets the batch size
   /// \return Status - The status code return
   virtual int64_t GetTreeBatchSize();
@@ -258,13 +258,5 @@ Status FilterOp::PreAccept(NodePass *p, bool *modified) {
   return p->PreRunOnNode(shared_from_base<FilterOp>(), modified);
 }

-// Get Dataset size
-Status FilterOp::GetDatasetSize(int64_t *dataset_size) {
-  // We are returning -1 because we can't easily calculate GetDatasetSize. Returning -1 will make TreeGetters to
-  // iterate over the dataset and count the size
-  *dataset_size = dataset_size_;
-  return Status::OK();
-}
-
 } // namespace dataset
 } // namespace mindspore
@@ -137,11 +137,6 @@ class FilterOp : public ParallelOp {
   // @return Name of the current Op
   std::string Name() const override { return kFilterOp; }

-  /// \brief Base-class override for GetDatasetSize
-  /// \param[out] dataset_size the size of the dataset
-  /// \return Status of the function
-  Status GetDatasetSize(int64_t *dataset_size) override;
-
  private:
   // predicate_func python callable which returns a boolean value.
   std::shared_ptr<TensorOp> predicate_func_;
@@ -187,21 +187,6 @@ Status RepeatOp::Accept(NodePass *p, bool *modified) {
   return p->RunOnNode(shared_from_base<RepeatOp>(), modified);
 }

-// Get Dataset size
-Status RepeatOp::GetDatasetSize(int64_t *dataset_size) {
-  if (dataset_size_ > 0) {
-    *dataset_size = dataset_size_;
-    return Status::OK();
-  }
-  int64_t num_rows;
-  RETURN_IF_NOT_OK(child_[0]->GetDatasetSize(&num_rows));
-  if (num_rows > 0 && num_repeats_ > 0) {
-    num_rows = num_rows * num_repeats_;
-  }
-  *dataset_size = num_rows;
-  dataset_size_ = num_rows;
-  return Status::OK();
-}
 int64_t RepeatOp::GetTreeRepeatCount() { return num_repeats_; }
 } // namespace dataset
 } // namespace mindspore
@@ -133,11 +133,6 @@ class RepeatOp : public PipelineOp {
   /// \return Status - The error code return
   Status Reset() override;

-  /// \brief Base-class override for GetDatasetSize
-  /// \param[out] dataset_size the size of the dataset
-  /// \return Status of the function
-  Status GetDatasetSize(int64_t *dataset_size) override;
-
   int64_t GetTreeRepeatCount() override;

   // \brief Adds an operator to the repeat ops list of tracked leaf/eoe nodes
@@ -136,20 +136,5 @@ Status SkipOp::PreAccept(NodePass *p, bool *modified) {
   return p->PreRunOnNode(shared_from_base<SkipOp>(), modified);
 }

-// Get Dataset size
-Status SkipOp::GetDatasetSize(int64_t *dataset_size) {
-  if (dataset_size_ > 0) {
-    *dataset_size = dataset_size_;
-    return Status::OK();
-  }
-  int64_t num_rows;
-  RETURN_IF_NOT_OK(child_[0]->GetDatasetSize(&num_rows));
-  *dataset_size = 0;
-  if (max_skips_ >= 0 && max_skips_ < num_rows) {
-    *dataset_size = num_rows - max_skips_;
-  }
-  dataset_size_ = *dataset_size;
-  return Status::OK();
-}
 } // namespace dataset
 } // namespace mindspore
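The skip arithmetic being deleted here clamps at zero: skipping more rows than exist yields an empty dataset, never a negative size. In sketch form:

    #include <cstdint>
    #include <iostream>

    // Sketch of the skip-size rule above: size = max(0, num_rows - skips).
    int64_t SkippedSize(int64_t num_rows, int64_t skips) {
      return (skips >= 0 && skips < num_rows) ? num_rows - skips : 0;
    }

    int main() {
      std::cout << SkippedSize(10, 3) << "\n";   // 7
      std::cout << SkippedSize(10, 15) << "\n";  // 0: over-skipping clamps
    }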
@@ -86,11 +86,6 @@ class SkipOp : public PipelineOp {
   /// \return Status of the node visit
   Status PreAccept(NodePass *p, bool *modified) override;

-  /// \brief Base-class override for GetDatasetSize
-  /// \param[out] dataset_size the size of the dataset
-  /// \return Status of the function
-  Status GetDatasetSize(int64_t *dataset_size) override;
-
   // Op name getter
   // @return Name of the current Op
   std::string Name() const override { return kSkipOp; }
@@ -452,63 +452,5 @@ Status CelebAOp::ComputeColMap() {
   return Status::OK();
 }

-// Get Dataset size
-Status CelebAOp::GetDatasetSize(int64_t *dataset_size) {
-  int64_t num_rows, sample_size;
-  std::string line;
-  Path folder_path(folder_path_);
-  std::ifstream attr_file((folder_path / "list_attr_celeba.txt").toString());
-  if (!attr_file.is_open()) {
-    std::string attr_file_name = (folder_path / "list_attr_celeba.txt").toString();
-    RETURN_STATUS_UNEXPECTED("Invalid file, failed to open Celeba attr file: " + attr_file_name);
-  }
-
-  std::string rows_num;
-  (void)getline(attr_file, rows_num);
-  try {
-    num_rows = static_cast<int64_t>(std::stoul(rows_num));  // First line is rows number in attr file
-  } catch (std::invalid_argument &e) {
-    RETURN_STATUS_UNEXPECTED(
-      "Invalid data, failed to convert rows_num from attr_file to unsigned long, invalid argument: " + rows_num);
-  } catch (std::out_of_range &e) {
-    RETURN_STATUS_UNEXPECTED(
-      "Invalid data, failed to convert rows_num from attr_file to unsigned long, out of range: " + rows_num);
-  }
-  if (usage_ != "all") {
-    int64_t partition_num = 0;
-    char usage_type;
-    if (usage_ == "train") {
-      usage_type = '0';
-    } else {
-      if (usage_ == "valid") {
-        usage_type = '1';
-      } else {
-        if (usage_ == "test")
-          usage_type = '2';
-        else
-          RETURN_STATUS_UNEXPECTED("Invalid usage.");
-      }
-    }
-    if (!partition_file_.is_open()) {
-      partition_file_.open((folder_path / "list_eval_partition.txt").toString());
-    }
-    if (partition_file_.is_open()) {
-      while (getline(partition_file_, line)) {
-        int start = line.find(' ');
-        if (line.at(start + 1) == usage_type) {
-          partition_num++;
-        }
-      }
-    } else {
-      std::string partition_file_name = "list_eval_partition.txt";
-      RETURN_STATUS_UNEXPECTED("Invalid file, failed to open Celeba partition file: " + partition_file_name);
-    }
-    num_rows = std::min(num_rows, partition_num);
-  }
-
-  sample_size = sampler_->CalculateNumSamples(num_rows);
-  *dataset_size = sample_size;
-  return Status::OK();
-}
 } // namespace dataset
 } // namespace mindspore
@@ -179,11 +179,6 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
   // @return Name of the current Op
   std::string Name() const override { return "CelebAOp"; }

-  /// \brief Base-class override for GetDatasetSize
-  /// \param[out] dataset_size the size of the dataset
-  /// \return Status of the function
-  Status GetDatasetSize(int64_t *dataset_size) override;
-
  private:
   // Called first when function is called
   // @return
@@ -508,20 +508,5 @@ Status CifarOp::ComputeColMap() {
   return Status::OK();
 }

-// Get Dataset size
-Status CifarOp::GetDatasetSize(int64_t *dataset_size) {
-  if (dataset_size_ > 0) {
-    *dataset_size = dataset_size_;
-    return Status::OK();
-  }
-  int64_t num_rows, sample_size;
-  num_rows = num_rows_;
-  if (num_rows_ <= 0)
-    RETURN_IF_NOT_OK(CountTotalRows(folder_path_, usage_, cifar_type_ == CifarType::kCifar10, &num_rows));
-  sample_size = sampler_->CalculateNumSamples(num_rows);
-  *dataset_size = sample_size;
-  dataset_size_ = *dataset_size;
-  return Status::OK();
-}
 } // namespace dataset
 } // namespace mindspore
@@ -175,11 +175,6 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
   // @return Name of the current Op
   std::string Name() const override { return "CifarOp"; }

-  /// \brief Base-class override for GetDatasetSize
-  /// \param[out] dataset_size the size of the dataset
-  /// \return Status of the function
-  Status GetDatasetSize(int64_t *dataset_size) override;
-
  private:
   // Initialize Sampler, calls sampler->Init() within
   // @return Status - The error code return
@@ -565,19 +565,5 @@ Status ClueOp::Accept(NodePass *p, bool *modified) {
   return p->RunOnNode(shared_from_base<ClueOp>(), modified);
 }

-// Get Dataset size
-Status ClueOp::GetDatasetSize(int64_t *dataset_size) {
-  if (dataset_size_ > 0) {
-    *dataset_size = dataset_size_;
-    return Status::OK();
-  }
-  int64_t num_rows, sample_size;
-  if (num_rows_per_shard_ <= 0) RETURN_IF_NOT_OK(CalculateNumRowsPerShard());
-  sample_size = num_samples_;
-  num_rows = num_rows_per_shard_;
-  *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
-  dataset_size_ = *dataset_size;
-  return Status::OK();
-}
 } // namespace dataset
 } // namespace mindspore
@@ -197,11 +197,6 @@ class ClueOp : public ParallelOp {
   // @return - Status of the node visit.
   Status Accept(NodePass *p, bool *modified) override;

-  /// \brief Base-class override for GetDatasetSize
-  /// \param[out] dataset_size the size of the dataset
-  /// \return Status of the function
-  Status GetDatasetSize(int64_t *dataset_size) override;
-
  private:
   // The entry point for when workers are launched.
   // @param worker_id - the id of the worker that is executing this function.
@@ -681,39 +681,6 @@ Status CocoOp::ComputeColMap() {
   return Status::OK();
 }

-// Get Dataset size
-Status CocoOp::GetDatasetSize(int64_t *dataset_size) {
-  if (dataset_size_ > 0) {
-    *dataset_size = dataset_size_;
-    return Status::OK();
-  }
-  int64_t num_rows = 0, sample_size;
-  std::string task_type;
-  switch (task_type_) {
-    case TaskType::Detection:
-      task_type = "Detection";
-      break;
-    case TaskType::Keypoint:
-      task_type = "Keypoint";
-      break;
-    case TaskType::Panoptic:
-      task_type = "Panoptic";
-      break;
-    case TaskType::Stuff:
-      task_type = "Stuff";
-      break;
-  }
-  if (image_ids_.size() == 0) {
-    RETURN_IF_NOT_OK(CountTotalRows(image_folder_path_, annotation_path_, task_type, &num_rows));
-  } else {
-    num_rows = image_ids_.size();
-  }
-  sample_size = sampler_->CalculateNumSamples(num_rows);
-  *dataset_size = sample_size;
-  dataset_size_ = *dataset_size;
-  return Status::OK();
-}
-
 Status CocoOp::GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) {
   if ((*output_class_indexing).empty()) {
     if ((task_type_ != TaskType::Detection) && (task_type_ != TaskType::Panoptic)) {
@@ -213,11 +213,6 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
   // @return Name of the current Op
   std::string Name() const override { return "CocoOp"; }

-  /// \brief Base-class override for GetDatasetSize
-  /// \param[out] dataset_size the size of the dataset
-  /// \return Status of the function
-  Status GetDatasetSize(int64_t *dataset_size) override;
-
   /// \brief Gets the class indexing
   /// \return Status - The status code return
   Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) override;
@@ -916,19 +916,5 @@ Status CsvOp::Accept(NodePass *p, bool *modified) {
   return p->RunOnNode(shared_from_base<CsvOp>(), modified);
 }

-// Get Dataset size
-Status CsvOp::GetDatasetSize(int64_t *dataset_size) {
-  if (dataset_size_ > 0) {
-    *dataset_size = dataset_size_;
-    return Status::OK();
-  }
-  int64_t num_rows, sample_size;
-  if (num_rows_per_shard_ <= 0) RETURN_IF_NOT_OK(CalculateNumRowsPerShard());
-  sample_size = num_samples_;
-  num_rows = num_rows_per_shard_;
-  *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
-  dataset_size_ = *dataset_size;
-  return Status::OK();
-}
 } // namespace dataset
 } // namespace mindspore
@@ -318,11 +318,6 @@ class CsvOp : public ParallelOp {
   // @return - Status of the node visit.
   Status Accept(NodePass *p, bool *modified) override;

-  /// \brief Base-class override for GetDatasetSize
-  /// \param[out] dataset_size the size of the dataset
-  /// \return Status of the function
-  Status GetDatasetSize(int64_t *dataset_size) override;
-
  private:
   // The entry point for when workers are launched.
   // @param worker_id - the id of the worker that is executing this function.
@@ -274,11 +274,5 @@ Status GeneratorOp::ComputeColMap() {
   }
   return Status::OK();
 }
-Status GeneratorOp::GetDatasetSize(int64_t *dataset_size) {  // Get Dataset size
-  // We are returning -1 because we can't easily calculate GetDatasetSize. Returning -1 will make TreeGetters to
-  // iterate over the dataset and count the size
-  *dataset_size = dataset_size_;
-  return Status::OK();
-}
 } // namespace dataset
 } // namespace mindspore
@@ -136,8 +136,6 @@ class GeneratorOp : public PipelineOp {

   Status Init();

-  Status GetDatasetSize(int64_t *dataset_size) override;
-
  private:
   py::function generator_function_;
   std::vector<std::string> column_names_;
@@ -465,24 +465,6 @@ Status ImageFolderOp::ComputeColMap() {
   return Status::OK();
 }

-// Get Dataset size
-Status ImageFolderOp::GetDatasetSize(int64_t *dataset_size) {
-  if (dataset_size_ > 0) {
-    *dataset_size = dataset_size_;
-    return Status::OK();
-  }
-  int64_t sample_size, num_rows;
-  num_rows = num_rows_;
-  if (num_rows_ <= 0) {
-    // GetDatasetSize will not be impacted by class_index_
-    RETURN_IF_NOT_OK(CountRowsAndClasses(folder_path_, extensions_, &num_rows, nullptr, {}));
-  }
-  sample_size = sampler_->CalculateNumSamples(num_rows);
-  *dataset_size = sample_size;
-  dataset_size_ = *dataset_size;
-  return Status::OK();
-}
-
 // Get number of classes
 Status ImageFolderOp::GetNumClasses(int64_t *num_classes) {
   if (num_classes_ > 0) {
@@ -217,11 +217,6 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
   // @return Name of the current Op
   std::string Name() const override { return "ImageFolderOp"; }

-  /// \brief Base-class override for GetDatasetSize
-  /// \param[out] dataset_size the size of the dataset
-  /// \return Status of the function
-  Status GetDatasetSize(int64_t *dataset_size) override;
-
   /// \brief Base-class override for GetNumClasses
   /// \param[out] num_classes the number of classes
   /// \return Status of the function
@@ -396,16 +396,9 @@ Status ManifestOp::CountDatasetInfo() {
   return Status::OK();
 }

-#ifdef ENABLE_PYTHON
-Status ManifestOp::CountTotalRows(const std::string &file, const py::dict &dict, const std::string &usage,
-                                  int64_t *count, int64_t *numClasses) {
+Status ManifestOp::CountTotalRows(const std::string &file, const std::map<std::string, int32_t> &map,
+                                  const std::string &usage, int64_t *count, int64_t *numClasses) {
   // the logic of counting the number of samples is copied from ParseManifestFile()
-  std::map<std::string, int32_t> map;
-  for (auto p : dict) {
-    (void)map.insert(std::pair<std::string, int32_t>(py::reinterpret_borrow<py::str>(p.first),
-                                                     py::reinterpret_borrow<py::int_>(p.second)));
-  }
-
   std::shared_ptr<ManifestOp> op;
   *count = 0;
   RETURN_IF_NOT_OK(Builder().SetManifestFile(file).SetClassIndex(map).SetUsage(usage).Build(&op));
@@ -415,6 +408,7 @@ Status ManifestOp::CountTotalRows(const std::string &file, const py::dict &dict,
   return Status::OK();
 }

+#ifdef ENABLE_PYTHON
 Status ManifestOp::GetClassIndexing(const std::string &file, const py::dict &dict, const std::string &usage,
                                     std::map<std::string, int32_t> *output_class_indexing) {
   std::map<std::string, int32_t> input_class_indexing;
@@ -459,23 +453,6 @@ Status ManifestOp::ComputeColMap() {
   return Status::OK();
 }

-// Get Dataset size
-Status ManifestOp::GetDatasetSize(int64_t *dataset_size) {
-  if (dataset_size_ > 0) {
-    *dataset_size = dataset_size_;
-    return Status::OK();
-  }
-  int64_t num_rows, sample_size;
-  std::shared_ptr<ManifestOp> op;
-  RETURN_IF_NOT_OK(Builder().SetManifestFile(file_).SetClassIndex(class_index_).SetUsage(usage_).Build(&op));
-  RETURN_IF_NOT_OK(op->ParseManifestFile());
-  num_rows = static_cast<int64_t>(op->image_labelname_.size());
-  sample_size = sampler_->CalculateNumSamples(num_rows);
-  *dataset_size = sample_size;
-  dataset_size_ = *dataset_size;
-  return Status::OK();
-}
-
 // Get number of classes
 Status ManifestOp::GetNumClasses(int64_t *num_classes) {
   if (num_classes_ > 0) {
@@ -164,10 +164,17 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
   // @param show_all
   void Print(std::ostream &out, bool show_all) const override;

-#ifdef ENABLE_PYTHON
-  static Status CountTotalRows(const std::string &file, const py::dict &dict, const std::string &usage, int64_t *count,
-                               int64_t *numClasses);
+  /// \brief Counts the total number of rows in Manifest
+  /// \param[in] file Dataset file path
+  /// \param[in] input_class_indexing Input map of class index
+  /// \param[in] usage Dataset usage
+  /// \param[out] count Number of rows counted
+  /// \param[out] numClasses Number of classes counted
+  /// \return Status of the function
+  static Status CountTotalRows(const std::string &file, const std::map<std::string, int32_t> &map,
+                               const std::string &usage, int64_t *count, int64_t *numClasses);

+#ifdef ENABLE_PYTHON
   // Get str-to-int mapping from label name to index
   static Status GetClassIndexing(const std::string &file, const py::dict &dict, const std::string &usage,
                                  std::map<std::string, int32_t> *output_class_indexing);
@@ -183,11 +190,6 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
   // @return Name of the current Op
   std::string Name() const override { return "ManifestOp"; }

-  /// \brief Base-class override for GetDatasetSize
-  /// \param[out] dataset_size the size of the dataset
-  /// \return Status of the function
-  Status GetDatasetSize(int64_t *dataset_size) override;
-
   /// \brief Base-class override for GetNumClasses
   /// \param[out] num_classes the number of classes
   /// \return Status of the function
@@ -474,22 +474,5 @@ Status MindRecordOp::ComputeColMap() {
   return Status::OK();
 }

-// Get Dataset size
-Status MindRecordOp::GetDatasetSize(int64_t *dataset_size) {
-  if (dataset_size_ > 0) {
-    *dataset_size = dataset_size_;
-    return Status::OK();
-  }
-  int64_t num_rows = num_rows_;
-  if (num_rows_ <= 0) {
-    // The last operator is parent sampler
-    std::shared_ptr<ShardOperator> op = operators_.back();
-    RETURN_IF_NOT_OK(CountTotalRows(dataset_file_, load_dataset_, op, &num_rows, num_padded_));
-  }
-  *dataset_size = num_rows;
-  dataset_size_ = *dataset_size;
-  return Status::OK();
-}
-
 } // namespace dataset
 } // namespace mindspore
@@ -212,11 +212,6 @@ class MindRecordOp : public ParallelOp {
   // @return Name of the current Op
   std::string Name() const override { return "MindRecordOp"; }

-  /// \brief Base-class override for GetDatasetSize
-  /// \param[out] dataset_size the size of the dataset
-  /// \return Status of the function
-  Status GetDatasetSize(int64_t *dataset_size) override;
-
  private:
   Status GetBufferFromReader(std::unique_ptr<DataBuffer> *fetched_buffer, int64_t buffer_id, int32_t worker_id);
@@ -471,19 +471,5 @@ Status MnistOp::ComputeColMap() {
   return Status::OK();
 }

-// Get Dataset size
-Status MnistOp::GetDatasetSize(int64_t *dataset_size) {
-  if (dataset_size_ > 0) {
-    *dataset_size = dataset_size_;
-    return Status::OK();
-  }
-  int64_t num_rows, sample_size;
-  num_rows = num_rows_;
-  if (num_rows_ <= 0) RETURN_IF_NOT_OK(CountTotalRows(folder_path_, usage_, &num_rows));
-  sample_size = sampler_->CalculateNumSamples(num_rows);
-  *dataset_size = sample_size;
-  dataset_size_ = *dataset_size;
-  return Status::OK();
-}
 } // namespace dataset
 } // namespace mindspore
@@ -168,11 +168,6 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
   // @return Name of the current Op
   std::string Name() const override { return "MnistOp"; }

-  /// \brief Base-class override for GetDatasetSize
-  /// \param[out] dataset_size the size of the dataset
-  /// \return Status of the function
-  Status GetDatasetSize(int64_t *dataset_size) override;
-
  private:
   // Initialize Sampler, calls sampler->Init() within
   // @return Status - The error code return
@@ -421,23 +421,5 @@ Status RandomDataOp::ComputeColMap() {
   return Status::OK();
 }

-// Get Dataset size
-Status RandomDataOp::GetDatasetSize(int64_t *dataset_size) {
-  if (dataset_size_ > 0) {
-    *dataset_size = dataset_size_;
-    return Status::OK();
-  }
-  int64_t num_rows;
-  num_rows = total_rows_ != 0 ? total_rows_ : data_schema_->num_rows();
-  if (sampler_ != nullptr) {
-    int64_t sample_size;
-    sample_size = sampler_->CalculateNumSamples(num_rows);
-    *dataset_size = sample_size;
-  } else {
-    *dataset_size = num_rows;
-  }
-  dataset_size_ = *dataset_size;
-  return Status::OK();
-}
 } // namespace dataset
 } // namespace mindspore
@@ -203,11 +203,6 @@ class RandomDataOp : public ParallelOp {
   // @return Name of the current Op
   std::string Name() const override { return "RandomDataOp"; }

-  /// \brief Base-class override for GetDatasetSize
-  /// \param[out] dataset_size the size of the dataset
-  /// \return Status of the function
-  Status GetDatasetSize(int64_t *dataset_size) override;
-
  private:
   /**
    * The entry point code for when workers are launched
@@ -162,11 +162,11 @@ Status DistributedSamplerRT::ResetSampler() {
 }

 int64_t DistributedSamplerRT::CalculateNumSamples(int64_t num_rows) {
-  int64_t childs = num_rows;
+  int64_t child_num_rows = num_rows;
   if (!child_.empty()) {
-    childs = child_[0]->CalculateNumSamples(num_rows);
+    child_num_rows = child_[0]->CalculateNumSamples(num_rows);
   }
-  int64_t num_samples = (num_samples_ > 0) ? std::min(childs, num_samples_) : childs;
+  int64_t num_samples = (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows;
   return std::ceil(num_samples * 1.0 / num_devices_);
 }
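Beyond the rename, the final line is the interesting part: the sample count is capped by num_samples_ when set, then divided across devices rounding up, so every shard covers the data even when the split is uneven. Worked numbers in a standalone sketch:

    #include <algorithm>
    #include <cstdint>
    #include <iostream>

    // Sketch of the per-device arithmetic above (integer ceil instead of
    // the floating-point std::ceil used in the real code).
    int64_t PerDeviceSamples(int64_t child_num_rows, int64_t num_samples, int64_t num_devices) {
      int64_t n = (num_samples > 0) ? std::min(child_num_rows, num_samples) : child_num_rows;
      return (n + num_devices - 1) / num_devices;
    }

    int main() {
      std::cout << PerDeviceSamples(10, 0, 4) << "\n";  // ceil(10/4) = 3
      std::cout << PerDeviceSamples(10, 8, 4) << "\n";  // ceil(8/4) = 2
    }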
@@ -63,6 +63,11 @@ class DistributedSamplerRT : public SamplerRT {

   int64_t GetDeviceNum() { return num_devices_; }

+  /// \brief Recursively calls this function on its children to get the actual number of samples on a tree of samplers
+  /// \note This is not a getter for num_samples_. For example, if num_samples_ is 0 or if it's smaller than num_rows,
+  ///     then num_samples_ is not returned at all.
+  /// \param[in] num_rows The total number of rows in the dataset
+  /// \return int64_t Calculated number of samples
   int64_t CalculateNumSamples(int64_t num_rows) override;

   void Print(std::ostream &out, bool show_all) const override;
@@ -520,19 +520,5 @@ Status TextFileOp::Accept(NodePass *p, bool *modified) {
   return p->RunOnNode(shared_from_base<TextFileOp>(), modified);
 }

-// Get Dataset size
-Status TextFileOp::GetDatasetSize(int64_t *dataset_size) {
-  if (dataset_size_ > 0) {
-    *dataset_size = dataset_size_;
-    return Status::OK();
-  }
-  int64_t num_rows, sample_size;
-  sample_size = total_rows_;
-  if (num_rows_per_shard_ <= 0) RETURN_IF_NOT_OK(CalculateNumRowsPerShard());
-  num_rows = num_rows_per_shard_;
-  *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
-  dataset_size_ = *dataset_size;
-  return Status::OK();
-}
 } // namespace dataset
 } // namespace mindspore
@@ -198,11 +198,6 @@ class TextFileOp : public ParallelOp {
   // @return - Status of the node visit.
   Status Accept(NodePass *p, bool *modified) override;

-  /// \brief Base-class override for GetDatasetSize
-  /// \param[out] dataset_size the size of the dataset
-  /// \return Status of the function
-  Status GetDatasetSize(int64_t *dataset_size) override;
-
  private:
   // The entry point for when workers are launched.
   // @param worker_id - the id of the worker that is executing this function.
@@ -1067,41 +1067,5 @@ Status TFReaderOp::PrepareNodePostAction() {
   return Status::OK();
 }

-// Get the file list of the specific shard ID
-Status TFReaderOp::GetShardFileList(std::vector<std::string> *shard_filenames) {
-  if (!shard_filenames->empty()) {
-    RETURN_STATUS_UNEXPECTED("The initial file list must be empty.\n");
-  }
-  for (int index = 0; index < dataset_files_list_.size(); index++) {
-    if (index % num_devices_ == device_id_) {
-      shard_filenames->push_back(dataset_files_list_.at(index));
-    }
-  }
-  return Status::OK();
-}
-
-// Get Dataset size
-Status TFReaderOp::GetDatasetSize(int64_t *dataset_size) {
-  if (dataset_size_ > 0) {
-    *dataset_size = dataset_size_;
-    return Status::OK();
-  }
-  int64_t num_rows, sample_size;
-  num_rows = num_rows_;
-  if (num_rows_ <= 0) {
-    if (equal_rows_per_shard_) {
-      RETURN_IF_NOT_OK(CalculateNumRowsPerShard());
-      num_rows = num_rows_per_shard_;
-    } else {
-      std::vector<std::string> shard_file_list;
-      RETURN_IF_NOT_OK(GetShardFileList(&shard_file_list));
-      RETURN_IF_NOT_OK(CountTotalRows(&num_rows, shard_file_list));
-    }
-  }
-  sample_size = total_rows_ != 0 ? total_rows_ : data_schema_->num_rows();
-  *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
-  dataset_size_ = *dataset_size;
-  return Status::OK();
-}
 } // namespace dataset
 } // namespace mindspore
@@ -257,11 +257,6 @@ class TFReaderOp : public ParallelOp {
   // before providing their own implementations.
   Status PrepareNodePostAction() override;

-  /// \brief Base-class override for GetDatasetSize
-  /// \param[out] dataset_size the size of the dataset
-  /// \return Status of the function
-  Status GetDatasetSize(int64_t *dataset_size) override;
-
   static bool ValidateFirstRowCrc(const std::string &filename);

  private:
@@ -400,11 +395,6 @@ class TFReaderOp : public ParallelOp {
   // @return - Status
   Status ComputeColMap() override;

-  // Private function for computing the file list of the specific shard ID. This is because in distributed scenario,
-  // data will be divided into shards by row when equal_rows_per_shard is true, but by file in the opposite case.
-  // @return - Status - the status code returned.
-  Status GetShardFileList(std::vector<std::string> *shard_filenames);
-
   int32_t device_id_;
   int32_t num_devices_;
   int64_t rows_per_buffer_;
@@ -447,16 +447,9 @@ Status VOCOp::ReadAnnotationToTensor(const std::string &path, TensorRow *row) {
   return Status::OK();
 }

-#ifdef ENABLE_PYTHON
 Status VOCOp::CountTotalRows(const std::string &dir, const std::string &task_type, const std::string &task_mode,
-                             const py::dict &dict, int64_t *count) {
+                             const std::map<std::string, int32_t> &input_class_indexing, int64_t *count) {
   if (task_type == "Detection") {
-    std::map<std::string, int32_t> input_class_indexing;
-    for (auto p : dict) {
-      (void)input_class_indexing.insert(std::pair<std::string, int32_t>(py::reinterpret_borrow<py::str>(p.first),
-                                                                        py::reinterpret_borrow<py::int_>(p.second)));
-    }
-
     std::shared_ptr<VOCOp> op;
     RETURN_IF_NOT_OK(
       Builder().SetDir(dir).SetTask(task_type).SetUsage(task_mode).SetClassIndex(input_class_indexing).Build(&op));
@@ -473,6 +466,7 @@ Status VOCOp::CountTotalRows(const std::string &dir, const std::string &task_typ
   return Status::OK();
 }

+#ifdef ENABLE_PYTHON
 Status VOCOp::GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode,
                                const py::dict &dict, std::map<std::string, int32_t> *output_class_indexing) {
   std::map<std::string, int32_t> input_class_indexing;
@@ -516,36 +510,6 @@ Status VOCOp::ComputeColMap() {
   return Status::OK();
 }

-// Get Dataset size
-Status VOCOp::GetDatasetSize(int64_t *dataset_size) {
-  if (dataset_size_ > 0) {
-    *dataset_size = dataset_size_;
-    return Status::OK();
-  }
-  int64_t num_rows = 0, sample_size;
-  if (image_ids_.size() == 0) {
-    if (task_type_ == TaskType::Detection) {
-      std::shared_ptr<VOCOp> op;
-      RETURN_IF_NOT_OK(
-        Builder().SetDir(folder_path_).SetTask("Detection").SetUsage(usage_).SetClassIndex(class_index_).Build(&op));
-      RETURN_IF_NOT_OK(op->ParseImageIds());
-      RETURN_IF_NOT_OK(op->ParseAnnotationIds());
-      num_rows = static_cast<int64_t>(op->image_ids_.size());
-    } else if (task_type_ == TaskType::Segmentation) {
-      std::shared_ptr<VOCOp> op;
-      RETURN_IF_NOT_OK(Builder().SetDir(folder_path_).SetTask("Segmentation").SetUsage(usage_).Build(&op));
-      RETURN_IF_NOT_OK(op->ParseImageIds());
-      num_rows = static_cast<int64_t>(op->image_ids_.size());
-    }
-  } else {
-    num_rows = image_ids_.size();
-  }
-  sample_size = sampler_->CalculateNumSamples(num_rows);
-  *dataset_size = sample_size;
-  dataset_size_ = *dataset_size;
-  return Status::OK();
-}
-
 Status VOCOp::GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) {
   if ((*output_class_indexing).empty()) {
     if (task_type_ != TaskType::Detection) {
@@ -187,15 +187,15 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
   // @param show_all
   void Print(std::ostream &out, bool show_all) const override;

-#ifdef ENABLE_PYTHON
   // @param const std::string &dir - VOC dir path
   // @param const std::string &task_type - task type of reading voc job
   // @param const std::string &task_mode - task mode of reading voc job
-  // @param const py::dict &dict - input dict of class index
+  // @param const std::map<std::string, int32_t> input_class_indexing - input map of class index
   // @param int64_t *count - output rows number of VOCDataset
   static Status CountTotalRows(const std::string &dir, const std::string &task_type, const std::string &task_mode,
-                               const py::dict &dict, int64_t *count);
+                               const std::map<std::string, int32_t> &input_class_indexing, int64_t *count);

+#ifdef ENABLE_PYTHON
   // @param const std::string &dir - VOC dir path
   // @param const std::string &task_type - task type of reading voc job
   // @param const std::string &task_mode - task mode of reading voc job
|
@ -216,11 +216,6 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
|
|||
// @return Name of the current Op
|
||||
std::string Name() const override { return "VOCOp"; }
|
||||
|
||||
/// \brief Base-class override for GetDatasetSize
|
||||
/// \param[out] dataset_size the size of the dataset
|
||||
/// \return Status of the function
|
||||
Status GetDatasetSize(int64_t *dataset_size) override;
|
||||
|
||||
// /// \brief Gets the class indexing
|
||||
// /// \return Status - The status code return
|
||||
Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) override;
|
||||
|
|
|
@@ -139,17 +139,5 @@ Status TakeOp::PreAccept(NodePass *p, bool *modified) {
   return p->PreRunOnNode(shared_from_base<TakeOp>(), modified);
 }

-// Get Dataset size
-Status TakeOp::GetDatasetSize(int64_t *dataset_size) {
-  if (dataset_size_ > 0) {
-    *dataset_size = dataset_size_;
-    return Status::OK();
-  }
-  int64_t num_rows;
-  RETURN_IF_NOT_OK(child_[0]->GetDatasetSize(&num_rows));
-  *dataset_size = std::min(static_cast<int64_t>(max_takes_), num_rows);
-  dataset_size_ = *dataset_size;
-  return Status::OK();
-}
 } // namespace dataset
 } // namespace mindspore
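Take is the mirror image of skip: the result is the smaller of the requested take count and what the child actually has. A two-line sketch:

    #include <algorithm>
    #include <cstdint>
    #include <iostream>

    // Sketch of the take-size rule above: size = min(max_takes, num_rows).
    int64_t TakenSize(int64_t num_rows, int64_t max_takes) { return std::min(max_takes, num_rows); }

    int main() {
      std::cout << TakenSize(10, 3) << "\n";   // 3
      std::cout << TakenSize(10, 50) << "\n";  // 10: can't take more than exists
    }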
@@ -94,11 +94,6 @@ class TakeOp : public PipelineOp {
   // @return Name of the current Op
   std::string Name() const override { return kTakeOp; }

-  /// \brief Base-class override for GetDatasetSize
-  /// \param[out] dataset_size the size of the dataset
-  /// \return Status of the function
-  Status GetDatasetSize(int64_t *dataset_size) override;
-
  private:
   int32_t max_takes_;   // The number of takes that the user requested
   int32_t take_count_;  // A counter for the current number of executed takes
@@ -248,24 +248,6 @@ Status ZipOp::Accept(NodePass *p, bool *modified) {
   return p->RunOnNode(shared_from_base<ZipOp>(), modified);
 }

-// Get Dataset size
-Status ZipOp::GetDatasetSize(int64_t *dataset_size) {
-  if (dataset_size_ > 0) {
-    *dataset_size = dataset_size_;
-    return Status::OK();
-  }
-  std::vector<int32_t> dataset_sizes;
-  int64_t child_dataset_size;
-  for (auto child : child_) {
-    RETURN_IF_NOT_OK(child->GetDatasetSize(&child_dataset_size));
-    dataset_sizes.push_back(child_dataset_size);
-  }
-
-  *dataset_size = *std::min_element(dataset_sizes.begin(), dataset_sizes.end());
-  dataset_size_ = *dataset_size;
-  return Status::OK();
-}
-
 Status ZipOp::ComputeColMap() {
   if (column_name_id_map_.empty()) {
     column_name_id_map_ = {};
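Zip stops as soon as any input runs dry, so its size is the minimum over the children. Incidentally, the removed code collected 64-bit child sizes into a vector<int32_t>, a narrowing conversion this deletion retires. The rule in sketch form:

    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Sketch of the zip-size rule above: the shortest input bounds the output.
    int64_t ZippedSize(const std::vector<int64_t> &child_sizes) {
      return *std::min_element(child_sizes.begin(), child_sizes.end());
    }

    int main() {
      std::cout << ZippedSize({10, 7, 12}) << "\n";  // 7
    }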
@@ -120,11 +120,6 @@ class ZipOp : public PipelineOp {
   // @return Name of the current Op
   std::string Name() const override { return kZipOp; }

-  /// \brief Base-class override for GetDatasetSize
-  /// \param[out] dataset_size the size of the dataset
-  /// \return Status of the function
-  Status GetDatasetSize(int64_t *dataset_size) override;
-
  private:
   // Handles preprocessing of the main loop, used when starting new epoch
   Status prepare(TensorQTable *const table);
@@ -114,5 +114,33 @@ std::vector<std::shared_ptr<DatasetOp>> BatchNode::Build() {
   return node_ops;
 }

+// Get Dataset size
+Status BatchNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
+                                 int64_t *dataset_size) {
+  if (dataset_size_ > 0) {
+    *dataset_size = dataset_size_;
+    return Status::OK();
+  }
+#ifdef ENABLE_PYTHON
+  if (batch_size_func_) {
+    RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), dataset_size));
+    dataset_size_ = *dataset_size;
+    return Status::OK();
+  }
+#endif
+  int64_t num_rows;
+  RETURN_IF_NOT_OK(children_[0]->GetDatasetSize(size_getter, estimate, &num_rows));
+  if (num_rows > 0 && batch_size_ > 0) {
+    if (drop_remainder_) {
+      num_rows = static_cast<int64_t>(floor(num_rows / (1.0 * batch_size_)));
+    } else {
+      num_rows = static_cast<int64_t>(ceil(num_rows / (1.0 * batch_size_)));
+    }
+  }
+  *dataset_size = num_rows;
+  dataset_size_ = num_rows;
+  return Status::OK();
+}
+
 } // namespace dataset
 } // namespace mindspore
@@ -64,6 +64,15 @@ class BatchNode : public DatasetNode {
   /// \return Status Status::OK() if all the parameters are valid
   Status ValidateParams() override;

+  /// \brief Base-class override for GetDatasetSize
+  /// \param[in] size_getter Shared pointer to DatasetSizeGetter
+  /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
+  ///     dataset size at the expense of accuracy.
+  /// \param[out] dataset_size the size of the dataset
+  /// \return Status of the function
+  Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
+                        int64_t *dataset_size) override;
+
  private:
   int32_t batch_size_;
   bool drop_remainder_;
@@ -127,6 +127,5 @@ Status BucketBatchByLengthNode::ValidateParams() {

   return Status::OK();
 }
-
 } // namespace dataset
 } // namespace mindspore
@@ -60,6 +60,8 @@ class BucketBatchByLengthNode : public DatasetNode {
   /// \return Status Status::OK() if all the parameters are valid
   Status ValidateParams() override;

+  bool IsSizeDefined() override { return false; };
+
  private:
   std::vector<std::string> column_names_;
   std::vector<int32_t> bucket_boundaries_;
@@ -58,6 +58,8 @@ class ConcatNode : public DatasetNode {
   /// \return Status Status::OK() if all the parameters are valid
   Status ValidateParams() override;

+  bool IsSizeDefined() override { return false; }
+
  private:
   std::shared_ptr<SamplerObj> sampler_;
   std::vector<std::pair<int, int>> children_flag_and_nums_;
@@ -342,6 +342,31 @@ Status DatasetNode::GetShardId(int32_t *shard_id) {
    RETURN_STATUS_SYNTAX_ERROR("Get Shard Id failed at source node: " + Name() + "\n");
  }
}

// Gets the dataset size
Status DatasetNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                                   int64_t *dataset_size) {
  if (dataset_size_ > 0) {
    *dataset_size = dataset_size_;
    return Status::OK();
  }
  if (!IsSizeDefined()) {
    RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), dataset_size));
    dataset_size_ = *dataset_size;
    return Status::OK();
  }
  if (children_.size() == 1) {
    return children_[0]->GetDatasetSize(size_getter, estimate, dataset_size);
  } else if (children_.size() > 1) {
    // It is okay for a dataset to have more than one child; GetDatasetSize shouldn't fail in this case.
    // This is done mostly for cache, which injects cache lookup/merge operators. The cache path will
    // always be in front of the children_ structure, so we get the dataset size from the last child.
    return children_[children_.size() - 1]->GetDatasetSize(size_getter, estimate, dataset_size);
  } else {
    RETURN_STATUS_UNEXPECTED("Trying to get dataset size from a leaf node; the override is missing.");
  }
}

// Visitor accepting method for NodePass
Status SourceNode::Accept(NodePass *p, bool *modified) {
  // Downcast shared pointer then call visitor
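The base-class logic above resolves a size request in a fixed order: cached value first, dry run for size-undefined ops, otherwise delegate to a child. A compressed stand-in (class and member names are invented for this sketch, not MindSpore's):

#include <cstdint>
#include <memory>
#include <vector>

// Invented stand-in for DatasetNode, showing only the size-resolution order.
struct Node {
  int64_t cached_size = -1;  // plays the role of dataset_size_
  bool size_defined = true;  // plays the role of IsSizeDefined()
  std::vector<std::shared_ptr<Node>> children;

  virtual int64_t DryRun() { return 0; }  // placeholder for DatasetSizeGetter::DryRun
  virtual ~Node() = default;

  int64_t GetSize() {
    if (cached_size > 0) return cached_size;                    // 1. cached answer
    if (!size_defined) return cached_size = DryRun();           // 2. execute the tree and count
    if (!children.empty()) return children.back()->GetSize();   // 3. ask the last child (skips cache ops)
    return -1;                                                  // 4. a leaf must override
  }
};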
@@ -25,6 +25,7 @@
#include <vector>

#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/consumers/tree_consumer.h"

namespace mindspore {
namespace dataset {

@@ -32,6 +33,7 @@ namespace dataset {
class Dataset;
class SamplerObj;
class NodePass;
class DatasetSizeGetter;

#define RETURN_EMPTY_IF_ERROR(_s) \
  do {                            \
@@ -169,6 +171,14 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
  /// \return Status Status::OK() if get shard id successfully
  virtual Status GetShardId(int32_t *shard_id);

  /// \brief Gets the dataset size
  /// \param[in] size_getter Shared pointer to DatasetSizeGetter
  /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
  ///     dataset size at the expense of accuracy.
  /// \param[out] dataset_size the size of the dataset
  /// \return Status The status code returned
  virtual Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                                int64_t *dataset_size);

  /// \brief Getter function for child nodes
  /// \return Child nodes
  const std::vector<std::shared_ptr<DatasetNode>> Children() const { return children_; }

@@ -219,10 +229,13 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
  /// \notes Remove me after changing return val of Build()
  Status BuildStatus() { return build_status; }

  virtual bool IsSizeDefined() { return true; }

 protected:
  std::vector<std::shared_ptr<DatasetNode>> children_;
  DatasetNode *parent_;
  std::shared_ptr<DatasetCache> cache_;
  int64_t dataset_size_ = -1;
  int32_t num_workers_;
  int32_t rows_per_buffer_;
  int32_t connector_que_size_;
@@ -55,6 +55,8 @@ class FilterNode : public DatasetNode {
  /// \return Status Status::OK() if all the parameters are valid
  Status ValidateParams() override;

  bool IsSizeDefined() override { return false; }

  /// \brief Base-class override for accepting NodePass visitor
  /// \param[in] p The node to visit
  /// \param[out] modified Indicator if the node was modified
@@ -56,6 +56,23 @@ Status RepeatNode::ValidateParams() {
  return Status::OK();
}

// Get Dataset size
Status RepeatNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                                  int64_t *dataset_size) {
  if (dataset_size_ > 0) {
    *dataset_size = dataset_size_;
    return Status::OK();
  }
  int64_t num_rows;
  RETURN_IF_NOT_OK(children_[0]->GetDatasetSize(size_getter, estimate, &num_rows));
  if (num_rows > 0 && repeat_count_ > 0) {
    num_rows = num_rows * repeat_count_;
  }
  *dataset_size = num_rows;
  dataset_size_ = num_rows;
  return Status::OK();
}

// Visitor accepting method for NodePass
Status RepeatNode::Accept(NodePass *p, bool *modified) {
  // Downcast shared pointer then call visitor
@@ -56,6 +56,15 @@ class RepeatNode : public DatasetNode {
  /// \return Status Status::OK() if all the parameters are valid
  Status ValidateParams() override;

  /// \brief Base-class override for GetDatasetSize
  /// \param[in] size_getter Shared pointer to DatasetSizeGetter
  /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
  ///     dataset size at the expense of accuracy.
  /// \param[out] dataset_size the size of the dataset
  /// \return Status of the function
  Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                        int64_t *dataset_size) override;

  /// \brief Base-class override for accepting NodePass visitor
  /// \param[in] p The node to visit
  /// \param[out] modified Indicator if the node was modified
@@ -56,5 +56,22 @@ Status SkipNode::ValidateParams() {
  return Status::OK();
}

// Get Dataset size
Status SkipNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                                int64_t *dataset_size) {
  if (dataset_size_ > 0) {
    *dataset_size = dataset_size_;
    return Status::OK();
  }
  int64_t num_rows;
  RETURN_IF_NOT_OK(children_[0]->GetDatasetSize(size_getter, estimate, &num_rows));
  *dataset_size = 0;
  if (skip_count_ >= 0 && skip_count_ < num_rows) {
    *dataset_size = num_rows - skip_count_;
  }
  dataset_size_ = *dataset_size;
  return Status::OK();
}

}  // namespace dataset
}  // namespace mindspore
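The repeat and skip sizes above (and the take size later in this diff) reduce to three one-liners; restated here as standalone helpers with invented names:

#include <algorithm>
#include <cstdint>

int64_t RepeatRows(int64_t rows, int32_t repeat_count) {
  return (rows > 0 && repeat_count > 0) ? rows * repeat_count : rows;
}
int64_t SkipRows(int64_t rows, int32_t skip_count) {
  // Skipping past the end yields an empty dataset, never a negative size.
  return (skip_count >= 0 && skip_count < rows) ? rows - skip_count : 0;
}
int64_t TakeRows(int64_t rows, int32_t take_count) {
  return std::min(static_cast<int64_t>(take_count), rows);
}
// e.g. 4 rows repeated 3 times, then skip 2, then take 5:
// TakeRows(SkipRows(RepeatRows(4, 3), 2), 5) == 5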
@@ -54,6 +54,15 @@ class SkipNode : public DatasetNode {
  /// \return Status Status::OK() if all the parameters are valid
  Status ValidateParams() override;

  /// \brief Base-class override for GetDatasetSize
  /// \param[in] size_getter Shared pointer to DatasetSizeGetter
  /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
  ///     dataset size at the expense of accuracy.
  /// \param[out] dataset_size the size of the dataset
  /// \return Status of the function
  Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                        int64_t *dataset_size) override;

 private:
  int32_t skip_count_;
};
@@ -21,6 +21,7 @@
#include <string>
#include <utility>
#include <vector>
#include <algorithm>

#include "minddata/dataset/engine/datasetops/source/celeba_op.h"
#include "minddata/dataset/util/status.h"
@@ -87,5 +88,66 @@ Status CelebANode::GetShardId(int32_t *shard_id) {
  return Status::OK();
}

// Get Dataset size
Status CelebANode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                                  int64_t *dataset_size) {
  int64_t num_rows, sample_size;
  std::ifstream partition_file;
  std::string line;
  Path folder_path(dataset_dir_);
  std::ifstream attr_file((folder_path / "list_attr_celeba.txt").toString());
  if (!attr_file.is_open()) {
    std::string attr_file_name = (folder_path / "list_attr_celeba.txt").toString();
    RETURN_STATUS_UNEXPECTED("Invalid file, failed to open CelebA attr file: " + attr_file_name);
  }

  std::string rows_num;
  (void)getline(attr_file, rows_num);
  try {
    num_rows = static_cast<int64_t>(std::stoul(rows_num));  // First line of the attr file is the number of rows
  } catch (std::invalid_argument &e) {
    RETURN_STATUS_UNEXPECTED(
      "Invalid data, failed to convert rows_num from attr_file to unsigned long, invalid argument: " + rows_num);
  } catch (std::out_of_range &e) {
    RETURN_STATUS_UNEXPECTED(
      "Invalid data, failed to convert rows_num from attr_file to unsigned long, out of range: " + rows_num);
  }
  if (usage_ != "all") {
    int64_t partition_num = 0;
    char usage_type;
    if (usage_ == "train") {
      usage_type = '0';
    } else if (usage_ == "valid") {
      usage_type = '1';
    } else if (usage_ == "test") {
      usage_type = '2';
    } else {
      RETURN_STATUS_UNEXPECTED("Invalid usage.");
    }
    if (!partition_file.is_open()) {
      partition_file.open((folder_path / "list_eval_partition.txt").toString());
    }
    if (partition_file.is_open()) {
      while (getline(partition_file, line)) {
        int start = line.find(' ');
        if (line.at(start + 1) == usage_type) {
          partition_num++;
        }
      }
    } else {
      std::string partition_file_name = "list_eval_partition.txt";
      RETURN_STATUS_UNEXPECTED("Invalid file, failed to open CelebA partition file: " + partition_file_name);
    }
    num_rows = std::min(num_rows, partition_num);
  }

  sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
  *dataset_size = sample_size;
  return Status::OK();
}

}  // namespace dataset
}  // namespace mindspore
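CelebA is the one source here that sizes itself from metadata files rather than a CountTotalRows helper: the first line of list_attr_celeba.txt carries the row count, and list_eval_partition.txt flags each image 0/1/2 for train/valid/test. A hedged sketch of the partition count (the path, function name, and the extra bounds check are assumptions of the sketch):

#include <cstdint>
#include <fstream>
#include <string>

int64_t CountPartition(const std::string &path, char usage_type) {
  std::ifstream partition_file(path);
  std::string line;
  int64_t partition_num = 0;
  while (std::getline(partition_file, line)) {
    // Each line looks like "000001.jpg 0"; count lines whose flag matches.
    std::size_t space = line.find(' ');
    if (space != std::string::npos && space + 1 < line.size() && line[space + 1] == usage_type) {
      ++partition_num;
    }
  }
  return partition_num;
}
// CountPartition("list_eval_partition.txt", '0') would count the training rows.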
@@ -61,6 +61,15 @@ class CelebANode : public MappableSourceNode {
  /// \return Status Status::OK() if get shard id successfully
  Status GetShardId(int32_t *shard_id) override;

  /// \brief Base-class override for GetDatasetSize
  /// \param[in] size_getter Shared pointer to DatasetSizeGetter
  /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
  ///     dataset size at the expense of accuracy.
  /// \param[out] dataset_size the size of the dataset
  /// \return Status of the function
  Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                        int64_t *dataset_size) override;

 private:
  std::string dataset_dir_;
  std::string usage_;
@@ -83,5 +83,20 @@ Status Cifar100Node::GetShardId(int32_t *shard_id) {
  return Status::OK();
}

// Get Dataset size
Status Cifar100Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                                    int64_t *dataset_size) {
  if (dataset_size_ > 0) {
    *dataset_size = dataset_size_;
    return Status::OK();
  }
  int64_t num_rows, sample_size;
  RETURN_IF_NOT_OK(CifarOp::CountTotalRows(dataset_dir_, usage_, false, &num_rows));
  sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
  *dataset_size = sample_size;
  dataset_size_ = *dataset_size;
  return Status::OK();
}

}  // namespace dataset
}  // namespace mindspore
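This is the pattern shared by the mappable sources in this diff (Cifar, Mnist, Coco, ImageFolder, Manifest, VOC): count the rows on storage, then let the sampler decide how many will actually be produced. CalculateNumSamples is sketched below as a plain clamp; the real samplers can do more (e.g. distributed sampling), so treat this as an approximation:

#include <cstdint>

// Approximate stand-in for the runtime sampler's CalculateNumSamples.
int64_t CalculateNumSamples(int64_t num_rows, int64_t requested) {
  // With no explicit request (<= 0), every row is a sample; otherwise clamp.
  return (requested > 0 && requested < num_rows) ? requested : num_rows;
}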
@@ -59,6 +59,15 @@ class Cifar100Node : public MappableSourceNode {
  /// \return Status Status::OK() if get shard id successfully
  Status GetShardId(int32_t *shard_id) override;

  /// \brief Base-class override for GetDatasetSize
  /// \param[in] size_getter Shared pointer to DatasetSizeGetter
  /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
  ///     dataset size at the expense of accuracy.
  /// \param[out] dataset_size the size of the dataset
  /// \return Status of the function
  Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                        int64_t *dataset_size) override;

 private:
  std::string dataset_dir_;
  std::string usage_;
@@ -81,5 +81,20 @@ Status Cifar10Node::GetShardId(int32_t *shard_id) {
  return Status::OK();
}

// Get Dataset size
Status Cifar10Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                                   int64_t *dataset_size) {
  if (dataset_size_ > 0) {
    *dataset_size = dataset_size_;
    return Status::OK();
  }
  int64_t num_rows, sample_size;
  RETURN_IF_NOT_OK(CifarOp::CountTotalRows(dataset_dir_, usage_, true, &num_rows));
  sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
  *dataset_size = sample_size;
  dataset_size_ = *dataset_size;
  return Status::OK();
}

}  // namespace dataset
}  // namespace mindspore
@@ -59,6 +59,15 @@ class Cifar10Node : public MappableSourceNode {
  /// \return Status Status::OK() if get shard id successfully
  Status GetShardId(int32_t *shard_id) override;

  /// \brief Base-class override for GetDatasetSize
  /// \param[in] size_getter Shared pointer to DatasetSizeGetter
  /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
  ///     dataset size at the expense of accuracy.
  /// \param[out] dataset_size the size of the dataset
  /// \return Status of the function
  Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                        int64_t *dataset_size) override;

 private:
  std::string dataset_dir_;
  std::string usage_;
@@ -241,5 +241,21 @@ Status CLUENode::GetShardId(int32_t *shard_id) {
  return Status::OK();
}

// Get Dataset size
Status CLUENode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                                int64_t *dataset_size) {
  if (dataset_size_ > 0) {
    *dataset_size = dataset_size_;
    return Status::OK();
  }
  int64_t num_rows, sample_size;
  RETURN_IF_NOT_OK(ClueOp::CountAllFileRows(dataset_files_, &num_rows));
  sample_size = num_samples_;
  num_rows = static_cast<int64_t>(ceil(num_rows / (1.0 * num_shards_)));
  *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
  dataset_size_ = *dataset_size;
  return Status::OK();
}

}  // namespace dataset
}  // namespace mindspore
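CLUE, CSV, and TextFile all share the same non-mappable arithmetic: split the total evenly across shards (rounding up), then cap by num_samples if one was given. As a standalone helper (names invented for the sketch):

#include <algorithm>
#include <cmath>
#include <cstdint>

int64_t ShardedRows(int64_t total_rows, int32_t num_shards, int64_t num_samples) {
  // Every shard reports ceil(total / num_shards) rows, capped by num_samples.
  int64_t per_shard = static_cast<int64_t>(std::ceil(total_rows / (1.0 * num_shards)));
  return num_samples > 0 ? std::min(per_shard, num_samples) : per_shard;
}
// e.g. 100 rows over 3 shards: ceil(100 / 3) = 34 per shard; with num_samples = 20, the answer is 20.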
@@ -61,6 +61,15 @@ class CLUENode : public NonMappableSourceNode {
  /// \return Status Status::OK() if get shard id successfully
  Status GetShardId(int32_t *shard_id) override;

  /// \brief Base-class override for GetDatasetSize
  /// \param[in] size_getter Shared pointer to DatasetSizeGetter
  /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
  ///     dataset size at the expense of accuracy.
  /// \param[out] dataset_size the size of the dataset
  /// \return Status of the function
  Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                        int64_t *dataset_size) override;

 private:
  /// \brief Split string based on a character delimiter
  /// \return A string vector
@@ -134,5 +134,20 @@ Status CocoNode::GetShardId(int32_t *shard_id) {
  return Status::OK();
}

// Get Dataset size
Status CocoNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                                int64_t *dataset_size) {
  if (dataset_size_ > 0) {
    *dataset_size = dataset_size_;
    return Status::OK();
  }
  int64_t num_rows = 0, sample_size;
  RETURN_IF_NOT_OK(CocoOp::CountTotalRows(dataset_dir_, annotation_file_, task_, &num_rows));
  sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
  *dataset_size = sample_size;
  dataset_size_ = *dataset_size;
  return Status::OK();
}

}  // namespace dataset
}  // namespace mindspore
@@ -59,6 +59,15 @@ class CocoNode : public MappableSourceNode {
  /// \return Status Status::OK() if get shard id successfully
  Status GetShardId(int32_t *shard_id) override;

  /// \brief Base-class override for GetDatasetSize
  /// \param[in] size_getter Shared pointer to DatasetSizeGetter
  /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
  ///     dataset size at the expense of accuracy.
  /// \param[out] dataset_size the size of the dataset
  /// \return Status of the function
  Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                        int64_t *dataset_size) override;

 private:
  std::string dataset_dir_;
  std::string annotation_file_;
@@ -153,5 +153,21 @@ Status CSVNode::GetShardId(int32_t *shard_id) {
  return Status::OK();
}

// Get Dataset size
Status CSVNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                               int64_t *dataset_size) {
  if (dataset_size_ > 0) {
    *dataset_size = dataset_size_;
    return Status::OK();
  }
  int64_t num_rows, sample_size;
  RETURN_IF_NOT_OK(CsvOp::CountAllFileRows(dataset_files_, column_names_.empty(), &num_rows));
  sample_size = num_samples_;
  num_rows = static_cast<int64_t>(ceil(num_rows / (1.0 * num_shards_)));
  *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
  dataset_size_ = *dataset_size;
  return Status::OK();
}

}  // namespace dataset
}  // namespace mindspore
@@ -82,6 +82,15 @@ class CSVNode : public NonMappableSourceNode {
  /// \return Status Status::OK() if get shard id successfully
  Status GetShardId(int32_t *shard_id) override;

  /// \brief Base-class override for GetDatasetSize
  /// \param[in] size_getter Shared pointer to DatasetSizeGetter
  /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
  ///     dataset size at the expense of accuracy.
  /// \param[out] dataset_size the size of the dataset
  /// \return Status of the function
  Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                        int64_t *dataset_size) override;

 private:
  std::vector<std::string> dataset_files_;
  char field_delim_;
@@ -93,6 +93,5 @@ Status GeneratorNode::GetShardId(int32_t *shard_id) {
  *shard_id = 0;
  return Status::OK();
}

}  // namespace dataset
}  // namespace mindspore
@@ -64,6 +64,13 @@ class GeneratorNode : public MappableSourceNode {
  /// \return Status Status::OK() if get shard id successfully
  Status GetShardId(int32_t *shard_id) override;

  /// \brief Setter for DatasetSize in GeneratorNode
  /// \param[in] sz dataset size to set
  /// \return void
  void SetGeneratorDatasetSize(int64_t sz) { dataset_size_ = sz; }

  bool IsSizeDefined() override { return false; }

 private:
  py::function generator_function_;
  std::vector<std::string> column_names_;
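Because GeneratorNode reports IsSizeDefined() as false, the base class falls back to an expensive dry run unless Python injects a known size through SetGeneratorDatasetSize (the pybind hook earlier in this diff). A compressed illustration of that interaction (stand-in type with an invented name):

#include <cstdint>

struct GeneratorLike {
  int64_t dataset_size_ = -1;
  // Once a size is cached here, the base-class GetDatasetSize returns it
  // immediately and never executes the generator.
  void SetGeneratorDatasetSize(int64_t sz) { dataset_size_ = sz; }
  bool IsSizeDefined() { return false; }  // without a user-supplied size, a dry run is required
};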
@@ -89,5 +89,20 @@ Status ImageFolderNode::GetShardId(int32_t *shard_id) {
  return Status::OK();
}

// Get Dataset size
Status ImageFolderNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                                       int64_t *dataset_size) {
  if (dataset_size_ > 0) {
    *dataset_size = dataset_size_;
    return Status::OK();
  }
  int64_t sample_size, num_rows;
  RETURN_IF_NOT_OK(ImageFolderOp::CountRowsAndClasses(dataset_dir_, exts_, &num_rows, nullptr, {}));
  sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
  *dataset_size = sample_size;
  dataset_size_ = *dataset_size;
  return Status::OK();
}

}  // namespace dataset
}  // namespace mindspore
@@ -65,6 +65,15 @@ class ImageFolderNode : public MappableSourceNode {
  /// \return Status Status::OK() if get shard id successfully
  Status GetShardId(int32_t *shard_id) override;

  /// \brief Base-class override for GetDatasetSize
  /// \param[in] size_getter Shared pointer to DatasetSizeGetter
  /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
  ///     dataset size at the expense of accuracy.
  /// \param[out] dataset_size the size of the dataset
  /// \return Status of the function
  Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                        int64_t *dataset_size) override;

 private:
  std::string dataset_dir_;
  bool decode_;
@@ -111,5 +111,21 @@ Status ManifestNode::GetShardId(int32_t *shard_id) {
  return Status::OK();
}

// Get Dataset size
Status ManifestNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                                    int64_t *dataset_size) {
  if (dataset_size_ > 0) {
    *dataset_size = dataset_size_;
    return Status::OK();
  }
  int64_t num_rows, sample_size;
  int64_t num_classes;  // dummy variable
  RETURN_IF_NOT_OK(ManifestOp::CountTotalRows(dataset_file_, class_index_, usage_, &num_rows, &num_classes));
  sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
  *dataset_size = sample_size;
  dataset_size_ = *dataset_size;
  return Status::OK();
}

}  // namespace dataset
}  // namespace mindspore
@@ -60,6 +60,15 @@ class ManifestNode : public MappableSourceNode {
  /// \return Status Status::OK() if get shard id successfully
  Status GetShardId(int32_t *shard_id) override;

  /// \brief Base-class override for GetDatasetSize
  /// \param[in] size_getter Shared pointer to DatasetSizeGetter
  /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
  ///     dataset size at the expense of accuracy.
  /// \param[out] dataset_size the size of the dataset
  /// \return Status of the function
  Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                        int64_t *dataset_size) override;

 private:
  std::string dataset_file_;
  std::string usage_;
@@ -152,7 +152,6 @@ std::vector<std::shared_ptr<DatasetOp>> MindDataNode::Build() {
  // A vector containing shared pointer to the Dataset Ops that this object will create
  std::vector<std::shared_ptr<DatasetOp>> node_ops;

  std::vector<std::shared_ptr<ShardOperator>> operators_;
  build_status = BuildMindDatasetSamplerChain(sampler_, &operators_, num_padded_);
  RETURN_EMPTY_IF_ERROR(build_status);  // remove me after changing return val of Build()

@@ -184,5 +183,28 @@ Status MindDataNode::GetShardId(int32_t *shard_id) {
  return Status::OK();
}

// Get Dataset size
Status MindDataNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                                    int64_t *dataset_size) {
  if (dataset_size_ > 0) {
    *dataset_size = dataset_size_;
    return Status::OK();
  }
  int64_t num_rows;
  std::vector<std::shared_ptr<ShardOperator>> operators;
  RETURN_IF_NOT_OK(BuildMindDatasetSamplerChain(sampler_, &operators, num_padded_));

  if (search_for_pattern_) {
    dataset_files_ = {dataset_file_};
  }

  // The last operator is the parent sampler
  std::shared_ptr<ShardOperator> op = operators.back();
  RETURN_IF_NOT_OK(MindRecordOp::CountTotalRows(dataset_files_, search_for_pattern_, op, &num_rows, num_padded_));
  *dataset_size = num_rows;
  dataset_size_ = *dataset_size;
  return Status::OK();
}

}  // namespace dataset
}  // namespace mindspore
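Note the convention relied on above (per the comment in the source): the sampler chain built by BuildMindDatasetSamplerChain ends with the parent sampler, so operators.back() is the operator under which rows are counted. A minimal restatement, with a placeholder Op type and an invented helper name:

#include <memory>
#include <vector>

struct Op {};  // placeholder for ShardOperator

std::shared_ptr<Op> ParentSampler(const std::vector<std::shared_ptr<Op>> &chain) {
  // The parent sampler sits at the end of the chain; empty chains have none.
  return chain.empty() ? nullptr : chain.back();
}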
@@ -74,6 +74,15 @@ class MindDataNode : public MappableSourceNode {
  /// \note Pybind will use this function to set sample_bytes into MindDataNode
  void SetSampleBytes(std::map<std::string, std::string> *sample_bytes);

  /// \brief Base-class override for GetDatasetSize
  /// \param[in] size_getter Shared pointer to DatasetSizeGetter
  /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
  ///     dataset size at the expense of accuracy.
  /// \param[out] dataset_size the size of the dataset
  /// \return Status of the function
  Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                        int64_t *dataset_size) override;

 private:
  std::string dataset_file_;                // search_for_pattern_ will be true in this mode
  std::vector<std::string> dataset_files_;  // search_for_pattern_ will be false in this mode

@@ -83,6 +92,7 @@ class MindDataNode : public MappableSourceNode {
  nlohmann::json padded_sample_;
  std::map<std::string, std::string> sample_bytes_;  // enable in python
  int64_t num_padded_;
  std::vector<std::shared_ptr<ShardOperator>> operators_;
};

}  // namespace dataset
@@ -75,5 +75,20 @@ Status MnistNode::GetShardId(int32_t *shard_id) {
  return Status::OK();
}

// Get Dataset size
Status MnistNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                                 int64_t *dataset_size) {
  if (dataset_size_ > 0) {
    *dataset_size = dataset_size_;
    return Status::OK();
  }
  int64_t num_rows, sample_size;
  RETURN_IF_NOT_OK(MnistOp::CountTotalRows(dataset_dir_, usage_, &num_rows));
  sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
  *dataset_size = sample_size;
  dataset_size_ = *dataset_size;
  return Status::OK();
}

}  // namespace dataset
}  // namespace mindspore
@@ -59,6 +59,15 @@ class MnistNode : public MappableSourceNode {
  /// \return Status Status::OK() if get shard id successfully
  Status GetShardId(int32_t *shard_id) override;

  /// \brief Base-class override for GetDatasetSize
  /// \param[in] size_getter Shared pointer to DatasetSizeGetter
  /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
  ///     dataset size at the expense of accuracy.
  /// \param[out] dataset_size the size of the dataset
  /// \return Status of the function
  Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                        int64_t *dataset_size) override;

 private:
  std::string dataset_dir_;
  std::string usage_;
@@ -86,17 +86,16 @@ std::vector<std::shared_ptr<DatasetOp>> RandomNode::Build() {
    schema_file_path = schema_path_;
  }

  std::unique_ptr<DataSchema> data_schema;
  std::vector<std::string> columns_to_load;
  if (columns_list_.size() > 0) {
    columns_to_load = columns_list_;
  }
  if (!schema_file_path.empty() || !schema_json_string.empty()) {
    data_schema = std::make_unique<DataSchema>();
    data_schema_ = std::make_unique<DataSchema>();
    if (!schema_file_path.empty()) {
      data_schema->LoadSchemaFile(schema_file_path, columns_to_load);
      data_schema_->LoadSchemaFile(schema_file_path, columns_to_load);
    } else if (!schema_json_string.empty()) {
      data_schema->LoadSchemaString(schema_json_string, columns_to_load);
      data_schema_->LoadSchemaString(schema_json_string, columns_to_load);
    }
  }

@@ -109,7 +108,7 @@ std::vector<std::shared_ptr<DatasetOp>> RandomNode::Build() {
  std::shared_ptr<RandomDataOp> op;
  op = std::make_shared<RandomDataOp>(num_workers_, connector_que_size_, rows_per_buffer_, total_rows_,
                                      std::move(data_schema), std::move(sampler_->Build()));
                                      std::move(data_schema_), std::move(sampler_->Build()));
  build_status = AddCacheOp(&node_ops);  // remove me after changing return val of Build()
  RETURN_EMPTY_IF_ERROR(build_status);

@@ -125,5 +124,24 @@ Status RandomNode::GetShardId(int32_t *shard_id) {
  return Status::OK();
}

// Get Dataset size
Status RandomNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                                  int64_t *dataset_size) {
  if (dataset_size_ > 0) {
    *dataset_size = dataset_size_;
    return Status::OK();
  }
  int64_t num_rows;
  num_rows = total_rows_ != 0 ? total_rows_ : data_schema_->num_rows();
  if (sampler_ != nullptr) {
    int64_t sample_size;
    sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
    *dataset_size = sample_size;
  } else {
    *dataset_size = num_rows;
  }
  dataset_size_ = *dataset_size;
  return Status::OK();
}

}  // namespace dataset
}  // namespace mindspore
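Design note on the RandomNode changes above: the schema used to be a Build()-local unique_ptr, which left nothing for a size query to read. Promoting it to the data_schema_ member lets GetDatasetSize call data_schema_->num_rows() without constructing the op; Build() still moves the member into RandomDataOp when the tree is actually built. The gist, with Schema standing in for DataSchema and all names invented:

#include <cstdint>
#include <memory>

struct Schema {
  int64_t num_rows() const { return 100; }  // dummy value for the sketch
};

struct RandomNodeLike {
  int64_t total_rows_ = 0;
  std::unique_ptr<Schema> data_schema_ = std::make_unique<Schema>();

  int64_t RowsForSize() const {
    // An explicit total_rows_ wins; otherwise the retained schema supplies it.
    return total_rows_ != 0 ? total_rows_ : data_schema_->num_rows();
  }
};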
@@ -79,6 +79,15 @@ class RandomNode : public NonMappableSourceNode {
  /// \return Status Status::OK() if get shard id successfully
  Status GetShardId(int32_t *shard_id) override;

  /// \brief Base-class override for GetDatasetSize
  /// \param[in] size_getter Shared pointer to DatasetSizeGetter
  /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
  ///     dataset size at the expense of accuracy.
  /// \param[out] dataset_size the size of the dataset
  /// \return Status of the function
  Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                        int64_t *dataset_size) override;

 private:
  /// \brief A quick inline for producing a random number between (and including) min/max
  /// \param[in] min minimum number that can be generated.

@@ -92,6 +101,7 @@ class RandomNode : public NonMappableSourceNode {
  std::vector<std::string> columns_list_;
  std::shared_ptr<SamplerObj> sampler_;
  std::mt19937 rand_gen_;
  std::unique_ptr<DataSchema> data_schema_;
};

}  // namespace dataset
@@ -122,5 +122,20 @@ Status TextFileNode::GetShardId(int32_t *shard_id) {
  return Status::OK();
}

// Get Dataset size
Status TextFileNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                                    int64_t *dataset_size) {
  if (dataset_size_ > 0) {
    *dataset_size = dataset_size_;
    return Status::OK();
  }
  int64_t num_rows, sample_size = num_samples_;
  RETURN_IF_NOT_OK(TextFileOp::CountAllFileRows(dataset_files_, &num_rows));
  num_rows = static_cast<int64_t>(ceil(num_rows / (1.0 * num_shards_)));
  *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
  dataset_size_ = *dataset_size;
  return Status::OK();
}

}  // namespace dataset
}  // namespace mindspore
@@ -61,6 +61,15 @@ class TextFileNode : public NonMappableSourceNode {
  /// \return Status Status::OK() if get shard id successfully
  Status GetShardId(int32_t *shard_id) override;

  /// \brief Base-class override for GetDatasetSize
  /// \param[in] size_getter Shared pointer to DatasetSizeGetter
  /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
  ///     dataset size at the expense of accuracy.
  /// \param[out] dataset_size the size of the dataset
  /// \return Status of the function
  Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                        int64_t *dataset_size) override;

 private:
  std::vector<std::string> dataset_files_;
  int32_t num_samples_;
@@ -169,5 +169,41 @@ Status TFRecordNode::GetShardId(int32_t *shard_id) {
  return Status::OK();
}

// Get Dataset size
Status TFRecordNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                                    int64_t *dataset_size) {
  if (dataset_size_ > 0) {
    *dataset_size = dataset_size_;
    return Status::OK();
  }
  int64_t num_rows;
  if (!shard_equal_rows_) {
    // Data will be sharded by file
    std::vector<std::string> shard_file_list;
    RETURN_IF_NOT_OK(GetShardFileList(&shard_file_list));
    RETURN_IF_NOT_OK(TFReaderOp::CountTotalRows(&num_rows, shard_file_list, 8, estimate));
  } else {
    // Data will be sharded by row
    RETURN_IF_NOT_OK(TFReaderOp::CountTotalRows(&num_rows, dataset_files_, 8, estimate));
    num_rows = static_cast<int64_t>(ceil(num_rows / (num_shards_ * 1.0)));
  }
  *dataset_size = num_samples_ > 0 ? std::min(num_rows, num_samples_) : num_rows;
  dataset_size_ = *dataset_size;
  return Status::OK();
}

// Get the file list of the specific shard ID
Status TFRecordNode::GetShardFileList(std::vector<std::string> *shard_filenames) {
  if (!shard_filenames->empty()) {
    RETURN_STATUS_UNEXPECTED("The initial file list must be empty.");
  }
  for (int index = 0; index < dataset_files_.size(); index++) {
    if (index % num_shards_ == shard_id_) {
      shard_filenames->push_back(dataset_files_.at(index));
    }
  }
  return Status::OK();
}

}  // namespace dataset
}  // namespace mindspore
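TFRecord is the only source in this diff that honors the estimate flag, by forwarding it to TFReaderOp::CountTotalRows. The two sharding strategies above reduce to the following sketch; CountRows here is a dummy placeholder for the real counter, and every name is invented:

#include <cmath>
#include <cstdint>
#include <string>
#include <vector>

// Dummy placeholder: the real TFReaderOp::CountTotalRows parses the files or,
// when estimate is true, extrapolates from a subset to save time.
int64_t CountRows(const std::vector<std::string> &files, bool estimate) {
  return static_cast<int64_t>(files.size()) * 1000;  // fixed value for the sketch
}

int64_t TFRecordRows(const std::vector<std::string> &all_files,
                     const std::vector<std::string> &this_shard_files, bool shard_equal_rows,
                     int32_t num_shards, bool estimate) {
  if (!shard_equal_rows) {
    // Shard by file: count only this shard's files, no division needed.
    return CountRows(this_shard_files, estimate);
  }
  // Shard by row: count everything, then split evenly across shards (rounding up).
  return static_cast<int64_t>(std::ceil(CountRows(all_files, estimate) / (1.0 * num_shards)));
}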
@@ -88,6 +88,20 @@ class TFRecordNode : public NonMappableSourceNode {
  /// \return Status Status::OK() if get shard id successfully
  Status GetShardId(int32_t *shard_id) override;

  /// \brief Base-class override for GetDatasetSize
  /// \param[in] size_getter Shared pointer to DatasetSizeGetter
  /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
  ///     dataset size at the expense of accuracy.
  /// \param[out] dataset_size the size of the dataset
  /// \return Status of the function
  Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                        int64_t *dataset_size) override;

  /// \brief Get the file list of the specific shard ID
  /// \param[out] shard_filenames the list of filenames for that specific shard ID
  /// \return Status of the function
  Status GetShardFileList(std::vector<std::string> *shard_filenames);

 private:
  std::vector<std::string> dataset_files_;
  std::string schema_path_;  // Path to the schema file; set when the schema parameter is given as a string
@@ -128,5 +128,20 @@ Status VOCNode::GetShardId(int32_t *shard_id) {
  return Status::OK();
}

// Get Dataset size
Status VOCNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                               int64_t *dataset_size) {
  if (dataset_size_ > 0) {
    *dataset_size = dataset_size_;
    return Status::OK();
  }
  int64_t num_rows = 0, sample_size;
  RETURN_IF_NOT_OK(VOCOp::CountTotalRows(dataset_dir_, task_, usage_, class_index_, &num_rows));
  sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
  *dataset_size = sample_size;
  dataset_size_ = *dataset_size;
  return Status::OK();
}

}  // namespace dataset
}  // namespace mindspore
@@ -61,6 +61,15 @@ class VOCNode : public MappableSourceNode {
  /// \return Status Status::OK() if get shard id successfully
  Status GetShardId(int32_t *shard_id) override;

  /// \brief Base-class override for GetDatasetSize
  /// \param[in] size_getter Shared pointer to DatasetSizeGetter
  /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
  ///     dataset size at the expense of accuracy.
  /// \param[out] dataset_size the size of the dataset
  /// \return Status of the function
  Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                        int64_t *dataset_size) override;

 private:
  const std::string kColumnImage = "image";
  const std::string kColumnTarget = "target";
@@ -19,6 +19,7 @@
#include <memory>
#include <string>
#include <vector>
#include <algorithm>

#include "minddata/dataset/engine/datasetops/take_op.h"
#include "minddata/dataset/util/status.h"
@@ -56,5 +57,19 @@ Status TakeNode::ValidateParams() {
  return Status::OK();
}

// Get Dataset size
Status TakeNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                                int64_t *dataset_size) {
  if (dataset_size_ > 0) {
    *dataset_size = dataset_size_;
    return Status::OK();
  }
  int64_t num_rows;
  RETURN_IF_NOT_OK(children_[0]->GetDatasetSize(size_getter, estimate, &num_rows));
  *dataset_size = std::min(static_cast<int64_t>(take_count_), num_rows);
  dataset_size_ = *dataset_size;
  return Status::OK();
}

}  // namespace dataset
}  // namespace mindspore
@@ -54,6 +54,15 @@ class TakeNode : public DatasetNode {
  /// \return Status Status::OK() if all the parameters are valid
  Status ValidateParams() override;

  /// \brief Base-class override for GetDatasetSize
  /// \param[in] size_getter Shared pointer to DatasetSizeGetter
  /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
  ///     dataset size at the expense of accuracy.
  /// \param[out] dataset_size the size of the dataset
  /// \return Status of the function
  Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                        int64_t *dataset_size) override;

 private:
  int32_t take_count_;
};