forked from mindspore-Ecosystem/mindspore
Added GetDatasetSize
This commit is contained in:
parent
7276198580
commit
0e03f5b0dd
|
@ -179,6 +179,26 @@ Dataset::Dataset() {
|
|||
rows_per_buffer_ = cfg->rows_per_buffer();
|
||||
connector_que_size_ = cfg->op_connector_size();
|
||||
worker_connector_size_ = cfg->worker_connector_size();
|
||||
tree_getters_ = std::make_shared<TreeGetters>();
|
||||
}
|
||||
|
||||
int64_t Dataset::GetDatasetSize() {
|
||||
int64_t dataset_size;
|
||||
auto ds = shared_from_this();
|
||||
Status rc;
|
||||
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
|
||||
rc = runtime_context->Init();
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed.";
|
||||
return -1;
|
||||
}
|
||||
rc = tree_getters_->Init(ds);
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetDatasetSize: Initializing TreeGetters failed.";
|
||||
return -1;
|
||||
}
|
||||
rc = tree_getters_->GetDatasetSize(&dataset_size);
|
||||
return rc.IsError() ? -1 : dataset_size;
|
||||
}
|
||||
|
||||
// Constructor to initialize the cache
|
||||
|
|
|
@ -351,4 +351,32 @@ Status SaveToDisk::TransfromTensor(const unsigned char *src, const TensorShape &
|
|||
}
|
||||
#endif
|
||||
|
||||
// Construct a TreeGetters with a fresh tree adapter; the cached size starts at
// the -1 "not yet computed" sentinel.
TreeGetters::TreeGetters() {
  dataset_size_ = -1;  // sentinel: size has not been computed yet
  tree_adapter_ = std::make_unique<TreeAdapter>();
}
|
||||
|
||||
// Build and prepare the execution tree for the given dataset (single epoch).
Status TreeGetters::Init(std::shared_ptr<api::Dataset> d) {
  return tree_adapter_->BuildAndPrepare(std::move(d), 1);
}
|
||||
|
||||
// Return the number of rows in the dataset, computing and caching it on first
// use. The root operator is asked first; if it cannot answer analytically
// (reports -1), the pipeline is drained and rows are counted one by one.
Status TreeGetters::GetDatasetSize(int64_t *dataset_size) {
  if (dataset_size_ == -1) {
    // First call: ask the tree's root operator for its size.
    std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
    CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
    RETURN_IF_NOT_OK(root->GetDatasetSize(dataset_size));
    dataset_size_ = *dataset_size;
    TensorRow row;
    if (*dataset_size == -1) {
      // Analytic size unavailable: iterate the whole pipeline and count rows.
      int64_t counted = 0;
      RETURN_IF_NOT_OK(tree_adapter_->GetNext(&row));
      while (row.size() != 0) {
        ++counted;
        RETURN_IF_NOT_OK(tree_adapter_->GetNext(&row));
      }
      dataset_size_ = counted;
    }
  }
  // Serve the cached value (either analytic or counted).
  *dataset_size = dataset_size_;
  return Status::OK();
}
|
||||
} // namespace mindspore::dataset
|
||||
|
|
|
@ -152,9 +152,10 @@ class ToDevice : public TreeConsumer {
|
|||
|
||||
/// Consumer that is used to get some pipeline information
|
||||
class TreeGetters : public TreeConsumer {
|
||||
Status GetDatasetSize(int32_t *size) {
|
||||
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
|
||||
}
|
||||
public:
|
||||
TreeGetters();
|
||||
Status Init(std::shared_ptr<api::Dataset> d) override;
|
||||
Status GetDatasetSize(int64_t *size);
|
||||
Status GetBatchSize(int32_t *batch_size) {
|
||||
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
|
||||
}
|
||||
|
@ -173,6 +174,11 @@ class TreeGetters : public TreeConsumer {
|
|||
Status GetOutputNames(std::vector<std::string> *names) {
|
||||
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
|
||||
}
|
||||
|
||||
std::string Name() override { return "TreeGetters"; }
|
||||
|
||||
private:
|
||||
int64_t dataset_size_;
|
||||
};
|
||||
|
||||
} // namespace mindspore::dataset
|
||||
|
|
|
@ -531,5 +531,30 @@ Status BatchOp::ComputeColMap() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get the number of batches this op will produce, caching the result.
// \param[out] dataset_size the size of the dataset (batch count, or -1 when
//             a Python batch-size callback makes it unknowable statically)
// \return Status of the function
Status BatchOp::GetDatasetSize(int64_t *dataset_size) {
  if (dataset_size_ > 0) {
    *dataset_size = dataset_size_;
    return Status::OK();
  }
#ifdef ENABLE_PYTHON
  // A per-batch callback can change the batch size dynamically, so the result
  // cannot be computed here; -1 tells TreeGetters to count by iteration.
  if (batch_size_func_) {
    *dataset_size = -1;
    return Status::OK();
  }
#endif
  int64_t num_rows;
  RETURN_IF_NOT_OK(child_[0]->GetDatasetSize(&num_rows));
  if (num_rows > 0 && start_batch_size_ > 0) {
    if (drop_) {
      // Incomplete final batch is dropped: integer division floors naturally
      // (previously wrapped in floor() after a double round-trip — redundant).
      num_rows = num_rows / start_batch_size_;
    } else {
      // Bug fix: ceil(num_rows / start_batch_size_) performed integer division
      // first, so ceil() was a no-op and the partial final batch was lost.
      // Use the integer ceiling-division idiom instead.
      num_rows = (num_rows + start_batch_size_ - 1) / start_batch_size_;
    }
  }
  *dataset_size = num_rows;
  dataset_size_ = num_rows;  // cache for subsequent calls
  return Status::OK();
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -219,6 +219,11 @@ class BatchOp : public ParallelOp {
|
|||
static Status PadColumns(std::unique_ptr<TensorQTable> *table, const PadInfo &pad_info,
|
||||
const std::unordered_map<std::string, int32_t> &column_name_id_map);
|
||||
|
||||
/// \brief Base-class override for GetDatasetSize
|
||||
/// \param[out] dataset_size the size of the dataset
|
||||
/// \return Status of the function
|
||||
Status GetDatasetSize(int64_t *dataset_size) override;
|
||||
|
||||
protected:
|
||||
Status ComputeColMap() override;
|
||||
|
||||
|
|
|
@ -231,5 +231,13 @@ Status BucketBatchByLengthOp::ComputeColMap() {
|
|||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get Dataset size
|
||||
Status BucketBatchByLengthOp::GetDatasetSize(int64_t *dataset_size) {
|
||||
// We are returning -1 because we can't easily calculate GetDatasetSize. Returning -1 will make TreeGetters to
|
||||
// iterate over the dataset and count the size
|
||||
*dataset_size = dataset_size_;
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -112,6 +112,11 @@ class BucketBatchByLengthOp : public PipelineOp {
|
|||
|
||||
std::string Name() const override { return kBucketBatchByLengthOp; }
|
||||
|
||||
/// \brief Base-class override for GetDatasetSize
|
||||
/// \param[out] dataset_size the size of the dataset
|
||||
/// \return Status of the function
|
||||
Status GetDatasetSize(int64_t *dataset_size) override;
|
||||
|
||||
// << Stream output operator overload
|
||||
// @notes This allows you to write the debug print info using stream operators
|
||||
// @param out - reference to the output stream being overloaded
|
||||
|
|
|
@ -195,5 +195,13 @@ Status ConcatOp::PreAccept(NodePass *p, bool *modified) {
|
|||
// Downcast shared pointer then call visitor
|
||||
return p->PreRunOnNode(shared_from_base<ConcatOp>(), modified);
|
||||
}
|
||||
|
||||
// Get Dataset size
|
||||
Status ConcatOp::GetDatasetSize(int64_t *dataset_size) {
|
||||
// We are returning -1 because we can't easily calculate GetDatasetSize. Returning -1 will make TreeGetters to
|
||||
// iterate over the dataset and count the size
|
||||
*dataset_size = dataset_size_;
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -111,6 +111,11 @@ class ConcatOp : public PipelineOp {
|
|||
/// \return Status of the node visit
|
||||
Status PreAccept(NodePass *p, bool *modified) override;
|
||||
|
||||
/// \brief Base-class override for GetDatasetSize
|
||||
/// \param[out] dataset_size the size of the dataset
|
||||
/// \return Status of the function
|
||||
Status GetDatasetSize(int64_t *dataset_size) override;
|
||||
|
||||
private:
|
||||
Status Verify(int32_t id, const std::unique_ptr<DataBuffer> &buf);
|
||||
|
||||
|
|
|
@ -50,7 +50,8 @@ DatasetOp::DatasetOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler
|
|||
op_num_repeats_per_epoch_(kInfiniteRepeat),
|
||||
op_current_repeats_(0),
|
||||
op_current_epochs_(0),
|
||||
out_connector_(nullptr) {
|
||||
out_connector_(nullptr),
|
||||
dataset_size_(-1) {
|
||||
// The operator starts out with an invalid operator id. The only way to
|
||||
// get it out of invalid state is to assign the operator to an execution tree.
|
||||
}
|
||||
|
@ -290,6 +291,17 @@ Status DatasetOp::GetNextInput(std::unique_ptr<DataBuffer> *p_buffer, int32_t wo
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Gets the dataset size
|
||||
Status DatasetOp::GetDatasetSize(int64_t *dataset_size) {
|
||||
if (dataset_size_ > 0) {
|
||||
*dataset_size = dataset_size_;
|
||||
return Status::OK();
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(child_.size() == 1, "Can't get the dataset size for the current tree.");
|
||||
|
||||
return child_[0]->GetDatasetSize(dataset_size);
|
||||
}
|
||||
|
||||
// Performs handling for when an eoe message is received.
|
||||
// The base class implementation simply flows the eoe message to output. Derived classes
|
||||
// may override if they need to perform special eoe handling.
|
||||
|
|
|
@ -179,6 +179,10 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
|
|||
/// \return Status - The error code return
|
||||
Status GetNextInput(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id = 0, int32_t child_index = 0);
|
||||
|
||||
/// \brief Gets the dataset size
|
||||
/// \return Status - The status code return
|
||||
virtual Status GetDatasetSize(int64_t *dataset_size);
|
||||
|
||||
/// \brief Performs handling for when an eoe message is received.
|
||||
/// The base class implementation simply flows the eoe message to output. Derived classes
|
||||
/// may override if they need to perform special eoe handling.
|
||||
|
@ -406,6 +410,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
|
|||
std::unordered_map<std::string, int32_t> column_name_id_map_; // Mapping between col index and col name
|
||||
std::mutex column_name_map_mutex_; // For protecting shared access to the column map
|
||||
CallbackManager callback_manager_; // Manages callbacks associated with a DatasetOp
|
||||
int64_t dataset_size_; // Size of the dataset
|
||||
|
||||
private:
|
||||
/// Sets the operator id.
|
||||
|
|
|
@ -278,5 +278,14 @@ Status FilterOp::PreAccept(NodePass *p, bool *modified) {
|
|||
// Downcast shared pointer then call visitor
|
||||
return p->PreRunOnNode(shared_from_base<FilterOp>(), modified);
|
||||
}
|
||||
|
||||
// Get Dataset size
|
||||
Status FilterOp::GetDatasetSize(int64_t *dataset_size) {
|
||||
// We are returning -1 because we can't easily calculate GetDatasetSize. Returning -1 will make TreeGetters to
|
||||
// iterate over the dataset and count the size
|
||||
*dataset_size = dataset_size_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -137,6 +137,11 @@ class FilterOp : public ParallelOp {
|
|||
// @return Name of the current Op
|
||||
std::string Name() const override { return kFilterOp; }
|
||||
|
||||
/// \brief Base-class override for GetDatasetSize
|
||||
/// \param[out] dataset_size the size of the dataset
|
||||
/// \return Status of the function
|
||||
Status GetDatasetSize(int64_t *dataset_size) override;
|
||||
|
||||
private:
|
||||
// predicate_func python callable which returns a boolean value.
|
||||
py::function predicate_func_;
|
||||
|
|
|
@ -191,5 +191,21 @@ Status RepeatOp::Accept(NodePass *p, bool *modified) {
|
|||
// Downcast shared pointer then call visitor
|
||||
return p->RunOnNode(shared_from_base<RepeatOp>(), modified);
|
||||
}
|
||||
|
||||
// Get Dataset size
|
||||
Status RepeatOp::GetDatasetSize(int64_t *dataset_size) {
|
||||
if (dataset_size_ > 0 || num_repeats_ == -1) {
|
||||
*dataset_size = dataset_size_;
|
||||
return Status::OK();
|
||||
}
|
||||
int64_t num_rows;
|
||||
RETURN_IF_NOT_OK(child_[0]->GetDatasetSize(&num_rows));
|
||||
if (num_rows > 0 && num_repeats_ > 0) {
|
||||
num_rows = num_rows * num_repeats_;
|
||||
}
|
||||
*dataset_size = num_rows;
|
||||
dataset_size_ = num_rows;
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -133,6 +133,11 @@ class RepeatOp : public PipelineOp {
|
|||
/// \@return Status - The error code return
|
||||
Status Reset() override;
|
||||
|
||||
/// \brief Base-class override for GetDatasetSize
|
||||
/// \param[out] dataset_size the size of the dataset
|
||||
/// \return Status of the function
|
||||
Status GetDatasetSize(int64_t *dataset_size) override;
|
||||
|
||||
// \brief Adds an operator to the repeat ops list of tracked leaf/eoe nodes
|
||||
// \param[in] eoe_op The input leaf/eoe operator to add to the list
|
||||
void AddToEoeList(std::shared_ptr<DatasetOp> eoe_op) { eoe_ops_.push_back(std::move(eoe_op)); }
|
||||
|
|
|
@ -134,5 +134,21 @@ Status SkipOp::Accept(NodePass *p, bool *modified) {
|
|||
// Downcast shared pointer then call visitor
|
||||
return p->RunOnNode(shared_from_base<SkipOp>(), modified);
|
||||
}
|
||||
|
||||
// Get Dataset size
|
||||
Status SkipOp::GetDatasetSize(int64_t *dataset_size) {
|
||||
if (dataset_size_ > 0) {
|
||||
*dataset_size = dataset_size_;
|
||||
return Status::OK();
|
||||
}
|
||||
int64_t num_rows;
|
||||
RETURN_IF_NOT_OK(child_[0]->GetDatasetSize(&num_rows));
|
||||
*dataset_size = 0;
|
||||
if (max_skips_ >= 0 && max_skips_ < num_rows) {
|
||||
*dataset_size = num_rows - max_skips_;
|
||||
}
|
||||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -80,6 +80,11 @@ class SkipOp : public PipelineOp {
|
|||
// @return - Status of the node visit.
|
||||
Status Accept(NodePass *p, bool *modified) override;
|
||||
|
||||
/// \brief Base-class override for GetDatasetSize
|
||||
/// \param[out] dataset_size the size of the dataset
|
||||
/// \return Status of the function
|
||||
Status GetDatasetSize(int64_t *dataset_size) override;
|
||||
|
||||
// Op name getter
|
||||
// @return Name of the current Op
|
||||
std::string Name() const override { return kSkipOp; }
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
#include "minddata/dataset/engine/datasetops/source/celeba_op.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <iomanip>
|
||||
#include "minddata/dataset/core/config_manager.h"
|
||||
|
@ -445,5 +446,64 @@ Status CelebAOp::ComputeColMap() {
|
|||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get Dataset size
|
||||
Status CelebAOp::GetDatasetSize(int64_t *dataset_size) {
|
||||
int64_t num_rows, sample_size;
|
||||
std::string line;
|
||||
Path folder_path(folder_path_);
|
||||
std::ifstream attr_file((folder_path / "list_attr_celeba.txt").toString());
|
||||
if (!attr_file.is_open()) {
|
||||
std::string attr_file_name = (folder_path / "list_attr_celeba.txt").toString();
|
||||
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open Celeba attr file: " + attr_file_name);
|
||||
}
|
||||
|
||||
std::string rows_num;
|
||||
(void)getline(attr_file, rows_num);
|
||||
try {
|
||||
num_rows = static_cast<int64_t>(std::stoul(rows_num)); // First line is rows number in attr file
|
||||
} catch (std::invalid_argument &e) {
|
||||
RETURN_STATUS_UNEXPECTED(
|
||||
"Invalid data, failed to convert rows_num from attr_file to unsigned long, invalid argument: " + rows_num);
|
||||
} catch (std::out_of_range &e) {
|
||||
RETURN_STATUS_UNEXPECTED(
|
||||
"Invalid data, failed to convert rows_num from attr_file to unsigned long, out of range: " + rows_num);
|
||||
}
|
||||
if (usage_ != "all") {
|
||||
int64_t partition_num = 0;
|
||||
char usage_type;
|
||||
if (usage_ == "train") {
|
||||
usage_type = '0';
|
||||
} else {
|
||||
if (usage_ == "valid") {
|
||||
usage_type = '1';
|
||||
} else {
|
||||
if (usage_ == "test")
|
||||
usage_type = '2';
|
||||
else
|
||||
RETURN_STATUS_UNEXPECTED("Invalid usage.");
|
||||
}
|
||||
}
|
||||
if (!partition_file_.is_open()) {
|
||||
partition_file_.open((folder_path / "list_eval_partition.txt").toString());
|
||||
}
|
||||
if (partition_file_.is_open()) {
|
||||
while (getline(partition_file_, line)) {
|
||||
int start = line.find(' ');
|
||||
if (line.at(start + 1) == usage_type) {
|
||||
partition_num++;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
std::string partition_file_name = "list_eval_partition.txt";
|
||||
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open Celeba partition file: " + partition_file_name);
|
||||
}
|
||||
num_rows = std::min(num_rows, partition_num);
|
||||
}
|
||||
|
||||
sample_size = sampler_->GetNumSamples();
|
||||
*dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -179,6 +179,11 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
|
|||
// @return Name of the current Op
|
||||
std::string Name() const override { return "CelebAOp"; }
|
||||
|
||||
/// \brief Base-class override for GetDatasetSize
|
||||
/// \param[out] dataset_size the size of the dataset
|
||||
/// \return Status of the function
|
||||
Status GetDatasetSize(int64_t *dataset_size) override;
|
||||
|
||||
private:
|
||||
// Called first when function is called
|
||||
// @return
|
||||
|
|
|
@ -507,5 +507,21 @@ Status CifarOp::ComputeColMap() {
|
|||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get Dataset size
|
||||
Status CifarOp::GetDatasetSize(int64_t *dataset_size) {
|
||||
if (dataset_size_ > 0) {
|
||||
*dataset_size = dataset_size_;
|
||||
return Status::OK();
|
||||
}
|
||||
int64_t num_rows, sample_size;
|
||||
num_rows = num_rows_;
|
||||
if (num_rows_ <= 0)
|
||||
RETURN_IF_NOT_OK(CountTotalRows(folder_path_, usage_, cifar_type_ == CifarType::kCifar10, &num_rows));
|
||||
sample_size = sampler_->GetNumSamples();
|
||||
*dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
|
||||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -175,6 +175,11 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
|
|||
// @return Name of the current Op
|
||||
std::string Name() const override { return "CifarOp"; }
|
||||
|
||||
/// \brief Base-class override for GetDatasetSize
|
||||
/// \param[out] dataset_size the size of the dataset
|
||||
/// \return Status of the function
|
||||
Status GetDatasetSize(int64_t *dataset_size) override;
|
||||
|
||||
private:
|
||||
// Initialize Sampler, calls sampler->Init() within
|
||||
// @return Status - The error code return
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
#include "minddata/dataset/engine/datasetops/source/clue_op.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <fstream>
|
||||
|
@ -563,5 +564,20 @@ Status ClueOp::Accept(NodePass *p, bool *modified) {
|
|||
// Downcast shared pointer then call visitor
|
||||
return p->RunOnNode(shared_from_base<ClueOp>(), modified);
|
||||
}
|
||||
|
||||
// Get Dataset size
|
||||
Status ClueOp::GetDatasetSize(int64_t *dataset_size) {
|
||||
if (dataset_size_ > 0) {
|
||||
*dataset_size = dataset_size_;
|
||||
return Status::OK();
|
||||
}
|
||||
int64_t num_rows, sample_size;
|
||||
if (num_rows_per_shard_ <= 0) RETURN_IF_NOT_OK(CalculateNumRowsPerShard());
|
||||
sample_size = num_samples_;
|
||||
num_rows = num_rows_per_shard_;
|
||||
*dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
|
||||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -193,6 +193,11 @@ class ClueOp : public ParallelOp {
|
|||
// @return - Status of the node visit.
|
||||
Status Accept(NodePass *p, bool *modified) override;
|
||||
|
||||
/// \brief Base-class override for GetDatasetSize
|
||||
/// \param[out] dataset_size the size of the dataset
|
||||
/// \return Status of the function
|
||||
Status GetDatasetSize(int64_t *dataset_size) override;
|
||||
|
||||
private:
|
||||
// The entry point for when workers are launched.
|
||||
// @param worker_id - the id of the worker that is executing this function.
|
||||
|
|
|
@ -679,5 +679,36 @@ Status CocoOp::ComputeColMap() {
|
|||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get Dataset size
|
||||
Status CocoOp::GetDatasetSize(int64_t *dataset_size) {
|
||||
if (dataset_size_ > 0) {
|
||||
*dataset_size = dataset_size_;
|
||||
return Status::OK();
|
||||
}
|
||||
int64_t num_rows = 0, sample_size;
|
||||
std::string task_type;
|
||||
switch (task_type_) {
|
||||
case TaskType::Detection:
|
||||
task_type = "Detection";
|
||||
break;
|
||||
case TaskType::Keypoint:
|
||||
task_type = "Keypoint";
|
||||
break;
|
||||
case TaskType::Panoptic:
|
||||
task_type = "Panoptic";
|
||||
break;
|
||||
case TaskType::Stuff:
|
||||
task_type = "Stuff";
|
||||
break;
|
||||
}
|
||||
if (image_ids_.size() == 0) {
|
||||
RETURN_IF_NOT_OK(CountTotalRows(image_folder_path_, annotation_path_, task_type, &num_rows));
|
||||
}
|
||||
sample_size = sampler_->GetNumSamples();
|
||||
*dataset_size = sample_size != 0 ? std::min(num_rows, sample_size) : num_rows;
|
||||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -209,6 +209,11 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
|
|||
// @return Name of the current Op
|
||||
std::string Name() const override { return "CocoOp"; }
|
||||
|
||||
/// \brief Base-class override for GetDatasetSize
|
||||
/// \param[out] dataset_size the size of the dataset
|
||||
/// \return Status of the function
|
||||
Status GetDatasetSize(int64_t *dataset_size) override;
|
||||
|
||||
private:
|
||||
// Initialize Sampler, calls sampler->Init() within
|
||||
// @return Status - The error code return
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
#include "minddata/dataset/engine/datasetops/source/csv_op.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <iomanip>
|
||||
#include <stdexcept>
|
||||
|
@ -914,5 +915,20 @@ Status CsvOp::Accept(NodePass *p, bool *modified) {
|
|||
// Downcast shared pointer then call visitor
|
||||
return p->RunOnNode(shared_from_base<CsvOp>(), modified);
|
||||
}
|
||||
|
||||
// Get Dataset size
|
||||
Status CsvOp::GetDatasetSize(int64_t *dataset_size) {
|
||||
if (dataset_size_ > 0) {
|
||||
*dataset_size = dataset_size_;
|
||||
return Status::OK();
|
||||
}
|
||||
int64_t num_rows, sample_size;
|
||||
if (num_rows_per_shard_ <= 0) RETURN_IF_NOT_OK(CalculateNumRowsPerShard());
|
||||
sample_size = num_samples_;
|
||||
num_rows = num_rows_per_shard_;
|
||||
*dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
|
||||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -318,6 +318,11 @@ class CsvOp : public ParallelOp {
|
|||
// @return - Status of the node visit.
|
||||
Status Accept(NodePass *p, bool *modified) override;
|
||||
|
||||
/// \brief Base-class override for GetDatasetSize
|
||||
/// \param[out] dataset_size the size of the dataset
|
||||
/// \return Status of the function
|
||||
Status GetDatasetSize(int64_t *dataset_size) override;
|
||||
|
||||
private:
|
||||
// The entry point for when workers are launched.
|
||||
// @param worker_id - the id of the worker that is executing this function.
|
||||
|
|
|
@ -453,5 +453,20 @@ Status ImageFolderOp::ComputeColMap() {
|
|||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get Dataset size
|
||||
Status ImageFolderOp::GetDatasetSize(int64_t *dataset_size) {
|
||||
if (dataset_size_ > 0) {
|
||||
*dataset_size = dataset_size_;
|
||||
return Status::OK();
|
||||
}
|
||||
int64_t sample_size, num_rows, num_classes;
|
||||
num_rows = num_rows_;
|
||||
if (num_rows_ <= 0) RETURN_IF_NOT_OK(CountRowsAndClasses(folder_path_, extensions_, &num_rows, &num_classes));
|
||||
sample_size = sampler_->GetNumSamples();
|
||||
*dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
|
||||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -213,6 +213,11 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
|
|||
// @return Name of the current Op
|
||||
std::string Name() const override { return "ImageFolderOp"; }
|
||||
|
||||
/// \brief Base-class override for GetDatasetSize
|
||||
/// \param[out] dataset_size the size of the dataset
|
||||
/// \return Status of the function
|
||||
Status GetDatasetSize(int64_t *dataset_size) override;
|
||||
|
||||
private:
|
||||
// Initialize Sampler, calls sampler->Init() within
|
||||
// @return Status - The error code return
|
||||
|
|
|
@ -453,5 +453,23 @@ Status ManifestOp::ComputeColMap() {
|
|||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get Dataset size
|
||||
Status ManifestOp::GetDatasetSize(int64_t *dataset_size) {
|
||||
if (dataset_size_ > 0) {
|
||||
*dataset_size = dataset_size_;
|
||||
return Status::OK();
|
||||
}
|
||||
int64_t num_rows, sample_size;
|
||||
std::shared_ptr<ManifestOp> op;
|
||||
RETURN_IF_NOT_OK(Builder().SetManifestFile(file_).SetClassIndex(class_index_).SetUsage(usage_).Build(&op));
|
||||
RETURN_IF_NOT_OK(op->ParseManifestFile());
|
||||
num_rows = static_cast<int64_t>(op->image_labelname_.size());
|
||||
sample_size = sampler_->GetNumSamples();
|
||||
*dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
|
||||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -183,6 +183,11 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
|
|||
// @return Name of the current Op
|
||||
std::string Name() const override { return "ManifestOp"; }
|
||||
|
||||
/// \brief Base-class override for GetDatasetSize
|
||||
/// \param[out] dataset_size the size of the dataset
|
||||
/// \return Status of the function
|
||||
Status GetDatasetSize(int64_t *dataset_size) override;
|
||||
|
||||
private:
|
||||
// Initialize Sampler, calls sampler->Init() within
|
||||
// @return Status - The error code return
|
||||
|
|
|
@ -38,6 +38,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
using mindrecord::kInt64Len;
|
||||
using mindrecord::MSRStatus;
|
||||
using mindrecord::Schema;
|
||||
|
@ -476,5 +477,23 @@ Status MindRecordOp::ComputeColMap() {
|
|||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get Dataset size
|
||||
Status MindRecordOp::GetDatasetSize(int64_t *dataset_size) {
|
||||
if (dataset_size_ > 0) {
|
||||
*dataset_size = dataset_size_;
|
||||
return Status::OK();
|
||||
}
|
||||
int64_t num_rows = num_rows_, sample_size;
|
||||
if (num_rows_ <= 0) {
|
||||
std::shared_ptr<ShardOperator> op;
|
||||
RETURN_IF_NOT_OK(CountTotalRows(dataset_file_, load_dataset_, op, &num_rows, num_padded_));
|
||||
}
|
||||
sample_size = operators_[0]->GetNumSamples(num_rows, 0);
|
||||
*dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
|
||||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -212,6 +212,11 @@ class MindRecordOp : public ParallelOp {
|
|||
// @return Name of the current Op
|
||||
std::string Name() const override { return "MindRecordOp"; }
|
||||
|
||||
/// \brief Base-class override for GetDatasetSize
|
||||
/// \param[out] dataset_size the size of the dataset
|
||||
/// \return Status of the function
|
||||
Status GetDatasetSize(int64_t *dataset_size) override;
|
||||
|
||||
private:
|
||||
Status GetBufferFromReader(std::unique_ptr<DataBuffer> *fetched_buffer, int64_t buffer_id, int32_t worker_id);
|
||||
|
||||
|
|
|
@ -470,5 +470,20 @@ Status MnistOp::ComputeColMap() {
|
|||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get Dataset size
|
||||
Status MnistOp::GetDatasetSize(int64_t *dataset_size) {
|
||||
if (dataset_size_ > 0) {
|
||||
*dataset_size = dataset_size_;
|
||||
return Status::OK();
|
||||
}
|
||||
int64_t num_rows, sample_size;
|
||||
num_rows = num_rows_;
|
||||
if (num_rows_ <= 0) RETURN_IF_NOT_OK(CountTotalRows(folder_path_, usage_, &num_rows));
|
||||
sample_size = sampler_->GetNumSamples();
|
||||
*dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
|
||||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -168,6 +168,11 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
|
|||
// @return Name of the current Op
|
||||
std::string Name() const override { return "MnistOp"; }
|
||||
|
||||
/// \brief Base-class override for GetDatasetSize
|
||||
/// \param[out] dataset_size the size of the dataset
|
||||
/// \return Status of the function
|
||||
Status GetDatasetSize(int64_t *dataset_size) override;
|
||||
|
||||
private:
|
||||
// Initialize Sampler, calls sampler->Init() within
|
||||
// @return Status - The error code return
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
*/
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/random_data_op.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <iomanip>
|
||||
#include <random>
|
||||
#include "minddata/dataset/engine/execution_tree.h"
|
||||
|
@ -418,5 +420,19 @@ Status RandomDataOp::ComputeColMap() {
|
|||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get Dataset size
|
||||
Status RandomDataOp::GetDatasetSize(int64_t *dataset_size) {
|
||||
if (dataset_size_ > 0) {
|
||||
*dataset_size = dataset_size_;
|
||||
return Status::OK();
|
||||
}
|
||||
int64_t num_rows, sample_size = 0;
|
||||
num_rows = total_rows_ != 0 ? total_rows_ : data_schema_->num_rows();
|
||||
if (sampler_ != nullptr) sample_size = sampler_->GetNumSamples();
|
||||
*dataset_size = sample_size != 0 ? std::min(num_rows, sample_size) : num_rows;
|
||||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -203,6 +203,11 @@ class RandomDataOp : public ParallelOp {
|
|||
// @return Name of the current Op
|
||||
std::string Name() const override { return "RandomDataOp"; }
|
||||
|
||||
/// \brief Base-class override for GetDatasetSize
|
||||
/// \param[out] dataset_size the size of the dataset
|
||||
/// \return Status of the function
|
||||
Status GetDatasetSize(int64_t *dataset_size) override;
|
||||
|
||||
private:
|
||||
/**
|
||||
* The entry point code for when workers are launched
|
||||
|
|
|
@ -2,20 +2,20 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc"
|
|||
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
||||
|
||||
set(DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SRC_FILES
|
||||
distributed_sampler.cc
|
||||
pk_sampler.cc
|
||||
random_sampler.cc
|
||||
sampler.cc
|
||||
sequential_sampler.cc
|
||||
subset_random_sampler.cc
|
||||
weighted_random_sampler.cc
|
||||
)
|
||||
distributed_sampler.cc
|
||||
pk_sampler.cc
|
||||
random_sampler.cc
|
||||
sampler.cc
|
||||
sequential_sampler.cc
|
||||
subset_random_sampler.cc
|
||||
weighted_random_sampler.cc
|
||||
)
|
||||
|
||||
if (ENABLE_PYTHON)
|
||||
set(DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SRC_FILES
|
||||
${DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SRC_FILES}
|
||||
python_sampler.cc
|
||||
)
|
||||
endif()
|
||||
${DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SRC_FILES}
|
||||
python_sampler.cc
|
||||
)
|
||||
endif ()
|
||||
|
||||
add_library(engine-datasetops-source-sampler OBJECT ${DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SRC_FILES})
|
||||
|
|
|
@ -129,6 +129,8 @@ Status Sampler::SetNumSamples(int64_t num_samples) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
int64_t Sampler::GetNumSamples() {
  // Simple accessor for the configured sample count.
  return num_samples_;
}
|
||||
|
||||
Status Sampler::SetNumRowsInDataset(int64_t num_rows) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(num_rows > 0, "Invalid parameter, num_rows must be greater than 0.");
|
||||
num_rows_ = num_rows;
|
||||
|
|
|
@ -98,6 +98,11 @@ class Sampler {
|
|||
// @return status error code
|
||||
Status SetNumSamples(int64_t num_samples);
|
||||
|
||||
// getter for num samples
|
||||
// @param num_samples - the number of samples to return.
|
||||
// @return status error code
|
||||
int64_t GetNumSamples();
|
||||
|
||||
// setter for num or records in the dataset
|
||||
// @param num_rows - the number of records
|
||||
// @return status error code
|
||||
|
|
|
@ -519,5 +519,20 @@ Status TextFileOp::Accept(NodePass *p, bool *modified) {
|
|||
// Downcast shared pointer then call visitor
|
||||
return p->RunOnNode(shared_from_base<TextFileOp>(), modified);
|
||||
}
|
||||
|
||||
// Get Dataset size
|
||||
Status TextFileOp::GetDatasetSize(int64_t *dataset_size) {
|
||||
if (dataset_size_ > 0) {
|
||||
*dataset_size = dataset_size_;
|
||||
return Status::OK();
|
||||
}
|
||||
int64_t num_rows, sample_size;
|
||||
sample_size = total_rows_;
|
||||
if (num_rows_per_shard_ <= 0) RETURN_IF_NOT_OK(CalculateNumRowsPerShard());
|
||||
num_rows = total_rows_;
|
||||
*dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
|
||||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -198,6 +198,11 @@ class TextFileOp : public ParallelOp {
|
|||
// @return - Status of the node visit.
|
||||
Status Accept(NodePass *p, bool *modified) override;
|
||||
|
||||
/// \brief Base-class override for GetDatasetSize
|
||||
/// \param[out] dataset_size the size of the dataset
|
||||
/// \return Status of the function
|
||||
Status GetDatasetSize(int64_t *dataset_size) override;
|
||||
|
||||
private:
|
||||
// The entry point for when workers are launched.
|
||||
// @param worker_id - the id of the worker that is executing this function.
|
||||
|
|
|
@ -1062,5 +1062,27 @@ Status TFReaderOp::PrepareNodePostAction() {
|
|||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get Dataset size
|
||||
Status TFReaderOp::GetDatasetSize(int64_t *dataset_size) {
|
||||
if (dataset_size_ > 0) {
|
||||
*dataset_size = dataset_size_;
|
||||
return Status::OK();
|
||||
}
|
||||
int64_t num_rows, sample_size;
|
||||
num_rows = num_rows_;
|
||||
if (num_rows_ <= 0) {
|
||||
if (equal_rows_per_shard_) {
|
||||
RETURN_IF_NOT_OK(CalculateNumRowsPerShard());
|
||||
num_rows = num_rows_per_shard_;
|
||||
} else {
|
||||
RETURN_IF_NOT_OK(CountTotalRows(&num_rows, dataset_files_list_));
|
||||
}
|
||||
}
|
||||
sample_size = total_rows_ != 0 ? total_rows_ : data_schema_->num_rows();
|
||||
*dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
|
||||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -257,6 +257,11 @@ class TFReaderOp : public ParallelOp {
|
|||
// before providing their own implementations.
|
||||
Status PrepareNodePostAction() override;
|
||||
|
||||
/// \brief Base-class override for GetDatasetSize
|
||||
/// \param[out] dataset_size the size of the dataset
|
||||
/// \return Status of the function
|
||||
Status GetDatasetSize(int64_t *dataset_size) override;
|
||||
|
||||
private:
|
||||
// The entry point for when workers are launched.
|
||||
// @param worker_id - the id of the worker that is executing this function.
|
||||
|
|
|
@ -513,5 +513,33 @@ Status VOCOp::ComputeColMap() {
|
|||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get Dataset size
|
||||
Status VOCOp::GetDatasetSize(int64_t *dataset_size) {
|
||||
if (dataset_size_ > 0) {
|
||||
*dataset_size = dataset_size_;
|
||||
return Status::OK();
|
||||
}
|
||||
int64_t num_rows = 0, sample_size;
|
||||
if (image_ids_.size() == 0) {
|
||||
if (task_type_ == TaskType::Detection) {
|
||||
std::shared_ptr<VOCOp> op;
|
||||
RETURN_IF_NOT_OK(
|
||||
Builder().SetDir(folder_path_).SetTask("Detection").SetUsage(usage_).SetClassIndex(class_index_).Build(&op));
|
||||
RETURN_IF_NOT_OK(op->ParseImageIds());
|
||||
RETURN_IF_NOT_OK(op->ParseAnnotationIds());
|
||||
num_rows = static_cast<int64_t>(op->image_ids_.size());
|
||||
} else if (task_type_ == TaskType::Segmentation) {
|
||||
std::shared_ptr<VOCOp> op;
|
||||
RETURN_IF_NOT_OK(Builder().SetDir(folder_path_).SetTask("Segmentation").SetUsage(usage_).Build(&op));
|
||||
RETURN_IF_NOT_OK(op->ParseImageIds());
|
||||
num_rows = static_cast<int64_t>(op->image_ids_.size());
|
||||
}
|
||||
}
|
||||
sample_size = sampler_->GetNumSamples();
|
||||
*dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
|
||||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -216,6 +216,11 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
|
|||
// @return Name of the current Op
|
||||
std::string Name() const override { return "VOCOp"; }
|
||||
|
||||
/// \brief Base-class override for GetDatasetSize
|
||||
/// \param[out] dataset_size the size of the dataset
|
||||
/// \return Status of the function
|
||||
Status GetDatasetSize(int64_t *dataset_size) override;
|
||||
|
||||
private:
|
||||
// Initialize Sampler, calls sampler->Init() within
|
||||
// @return Status - The error code return
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#include <iomanip>
|
||||
#include <utility>
|
||||
|
||||
#include <algorithm>
|
||||
#include "utils/ms_utils.h"
|
||||
#include "minddata/dataset/core/config_manager.h"
|
||||
#include "minddata/dataset/engine/data_buffer.h"
|
||||
|
@ -131,5 +132,18 @@ Status TakeOp::Accept(NodePass *p, bool *modified) {
|
|||
// Downcast shared pointer then call visitor
|
||||
return p->RunOnNode(shared_from_base<TakeOp>(), modified);
|
||||
}
|
||||
|
||||
// Get Dataset size
|
||||
Status TakeOp::GetDatasetSize(int64_t *dataset_size) {
|
||||
if (dataset_size_ > 0) {
|
||||
*dataset_size = dataset_size_;
|
||||
return Status::OK();
|
||||
}
|
||||
int64_t num_rows;
|
||||
RETURN_IF_NOT_OK(child_[0]->GetDatasetSize(&num_rows));
|
||||
*dataset_size = std::min(static_cast<int64_t>(max_takes_), num_rows);
|
||||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -88,6 +88,11 @@ class TakeOp : public PipelineOp {
|
|||
// @return Name of the current Op
|
||||
std::string Name() const override { return kTakeOp; }
|
||||
|
||||
/// \brief Base-class override for GetDatasetSize
|
||||
/// \param[out] dataset_size the size of the dataset
|
||||
/// \return Status of the function
|
||||
Status GetDatasetSize(int64_t *dataset_size) override;
|
||||
|
||||
private:
|
||||
int32_t max_takes_; // The number of takes that the user requested
|
||||
int32_t take_count_; // A counter for the current number of executed takes
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/engine/datasetops/zip_op.h"
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include <iomanip>
|
||||
#include "minddata/dataset/core/constants.h"
|
||||
|
@ -251,6 +252,24 @@ Status ZipOp::Accept(NodePass *p, bool *modified) {
|
|||
return p->RunOnNode(shared_from_base<ZipOp>(), modified);
|
||||
}
|
||||
|
||||
// Get Dataset size
|
||||
Status ZipOp::GetDatasetSize(int64_t *dataset_size) {
|
||||
if (dataset_size_ > 0) {
|
||||
*dataset_size = dataset_size_;
|
||||
return Status::OK();
|
||||
}
|
||||
std::vector<int32_t> dataset_sizes;
|
||||
int64_t child_dataset_size;
|
||||
for (auto child : child_) {
|
||||
RETURN_IF_NOT_OK(child->GetDatasetSize(&child_dataset_size));
|
||||
dataset_sizes.push_back(child_dataset_size);
|
||||
}
|
||||
|
||||
*dataset_size = *std::min_element(dataset_sizes.begin(), dataset_sizes.end());
|
||||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ZipOp::ComputeColMap() {
|
||||
if (column_name_id_map_.empty()) {
|
||||
column_name_id_map_ = {};
|
||||
|
|
|
@ -120,6 +120,11 @@ class ZipOp : public PipelineOp {
|
|||
// @return Name of the current Op
|
||||
std::string Name() const override { return kZipOp; }
|
||||
|
||||
/// \brief Base-class override for GetDatasetSize
|
||||
/// \param[out] dataset_size the size of the dataset
|
||||
/// \return Status of the function
|
||||
Status GetDatasetSize(int64_t *dataset_size) override;
|
||||
|
||||
private:
|
||||
// Handles preprocessing of the main loop, used when starting new epoch
|
||||
Status prepare(TensorQTable *const table);
|
||||
|
|
|
@ -47,6 +47,9 @@ class TreeAdapter {
|
|||
// 2. GetNext will return empty row when eoe/eof is obtained
|
||||
Status GetNext(TensorRow *);
|
||||
|
||||
// This function will return the root of the execution tree.
|
||||
std::weak_ptr<DatasetOp> GetRoot() { return tree_ != nullptr ? tree_->root() : nullptr; }
|
||||
|
||||
// This function will return the column_name_map once BuildAndPrepare() is called
|
||||
std::unordered_map<std::string, int32_t> GetColumnNameMap() const { return column_name_map_; }
|
||||
|
||||
|
|
|
@ -28,6 +28,7 @@
|
|||
#include "mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache.h"
|
||||
#include "minddata/dataset/core/constants.h"
|
||||
|
||||
#include "minddata/dataset/engine/consumers/tree_consumer.h"
|
||||
#include "minddata/dataset/engine/data_schema.h"
|
||||
#include "minddata/dataset/include/iterator.h"
|
||||
#include "minddata/dataset/include/samplers.h"
|
||||
|
@ -49,6 +50,7 @@ class DataSchema;
|
|||
class Tensor;
|
||||
class TensorShape;
|
||||
class TreeAdapter;
|
||||
class TreeGetters;
|
||||
#ifndef ENABLE_ANDROID
|
||||
class Vocab;
|
||||
#endif
|
||||
|
@ -570,6 +572,10 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
|
|||
/// \return Status Status::OK() if all the parameters are valid
|
||||
virtual Status ValidateParams() = 0;
|
||||
|
||||
/// \brief Gets the dataset size
|
||||
/// \return status code
|
||||
int64_t GetDatasetSize();
|
||||
|
||||
/// \brief Setter function for runtime number of workers
|
||||
/// \param[in] num_workers The number of threads in this operator
|
||||
/// \return Shared pointer to the original object
|
||||
|
@ -750,6 +756,7 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
|
|||
protected:
|
||||
std::vector<std::shared_ptr<Dataset>> children;
|
||||
std::shared_ptr<Dataset> parent;
|
||||
std::shared_ptr<TreeGetters> tree_getters_;
|
||||
|
||||
int32_t num_workers_;
|
||||
int32_t rows_per_buffer_;
|
||||
|
|
|
@ -73,6 +73,17 @@ TEST_F(MindDataTestPipeline, TestCifar10Dataset) {
|
|||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestCifar10GetDatasetSize) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCifar10GetDatasetSize.";
|
||||
|
||||
// Create a Cifar10 Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
|
||||
std::shared_ptr<Dataset> ds = Cifar10(folder_path, "all");
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 10000);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestCifar100Dataset) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCifar100Dataset.";
|
||||
|
||||
|
@ -108,6 +119,17 @@ TEST_F(MindDataTestPipeline, TestCifar100Dataset) {
|
|||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestCifar100GetDatasetSize) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCifar100GetDatasetSize.";
|
||||
|
||||
// Create a Cifar100 Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testCifar100Data/";
|
||||
std::shared_ptr<Dataset> ds = Cifar100(folder_path, "all", RandomSampler(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 10);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestCifar100DatasetFail1) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCifar100DatasetFail1.";
|
||||
|
||||
|
|
|
@ -162,6 +162,19 @@ TEST_F(MindDataTestPipeline, TestCLUEDatasetBasic) {
|
|||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestCLUEGetDatasetSize) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCLUEGetDatasetSize.";
|
||||
|
||||
// Create a CLUEFile Dataset, with single CLUE file
|
||||
std::string clue_file = datasets_root_path_ + "/testCLUE/afqmc/train.json";
|
||||
std::string task = "AFQMC";
|
||||
std::string usage = "train";
|
||||
std::shared_ptr<Dataset> ds = CLUE({clue_file}, task, usage, 2);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 2);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestCLUEDatasetCMNLI) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCLUEDatasetCMNLI.";
|
||||
|
||||
|
|
|
@ -91,6 +91,18 @@ TEST_F(MindDataTestPipeline, TestCocoDefault) {
|
|||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestCocoGetDatasetSize) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCocoGetDatasetSize.";
|
||||
// Create a Coco Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testCOCO/train";
|
||||
std::string annotation_file = datasets_root_path_ + "/testCOCO/annotations/train.json";
|
||||
|
||||
std::shared_ptr<Dataset> ds = Coco(folder_path, annotation_file);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 6);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestCocoDetection) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCocoDetection.";
|
||||
// Create a Coco Dataset
|
||||
|
|
|
@ -101,6 +101,18 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetBasic) {
|
|||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestCSVGetDatasetSize) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCSVGetDatasetSize.";
|
||||
|
||||
// Create a CSVDataset, with single CSV file
|
||||
std::string train_file = datasets_root_path_ + "/testCSV/1.csv";
|
||||
std::vector<std::string> column_names = {"col1", "col2", "col3", "col4"};
|
||||
std::shared_ptr<Dataset> ds = CSV({train_file}, ',', {}, column_names, 0, ShuffleMode::kFalse);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 3);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestCSVDatasetMultiFiles) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCSVDatasetMultiFiles.";
|
||||
|
||||
|
|
|
@ -67,6 +67,17 @@ TEST_F(MindDataTestPipeline, TestManifestBasic) {
|
|||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestManifestGetDatasetSize) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestManifestGetDatasetSize.";
|
||||
|
||||
std::string file_path = datasets_root_path_ + "/testManifestData/cpp.json";
|
||||
// Create a Manifest Dataset
|
||||
std::shared_ptr<Dataset> ds = Manifest(file_path);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 2);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestManifestDecode) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestManifestDecode.";
|
||||
|
||||
|
@ -91,7 +102,7 @@ TEST_F(MindDataTestPipeline, TestManifestDecode) {
|
|||
auto shape = image->shape();
|
||||
MS_LOG(INFO) << "Tensor image shape size: " << shape.Size();
|
||||
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
|
||||
EXPECT_GT(shape.Size(), 1); // Verify decode=true took effect
|
||||
EXPECT_GT(shape.Size(), 1); // Verify decode=true took effect
|
||||
iter->GetNextRow(&row);
|
||||
}
|
||||
|
||||
|
|
|
@ -71,6 +71,19 @@ TEST_F(MindDataTestPipeline, TestMindDataSuccess1) {
|
|||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestMindDataGetDatasetSize) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMindDataGetDatasetSize with string file pattern.";
|
||||
|
||||
// Create a MindData Dataset
|
||||
// Pass one mindrecord shard file to parse dataset info, and search for other mindrecord files with same dataset info,
|
||||
// thus all records in imagenet.mindrecord0 ~ imagenet.mindrecord3 will be read
|
||||
std::string file_path = datasets_root_path_ + "/../mindrecord/testMindDataSet/testImageNetData/imagenet.mindrecord0";
|
||||
std::shared_ptr<Dataset> ds = MindData(file_path);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 20);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestMindDataSuccess2) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMindDataSuccess2 with a vector of single mindrecord file.";
|
||||
|
||||
|
|
|
@ -368,6 +368,34 @@ TEST_F(MindDataTestPipeline, TestConcatSuccess) {
|
|||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestConcatGetDatasetSize) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestConcatGetDatasetSize.";
|
||||
|
||||
// Create an ImageFolder Dataset
|
||||
// Column names: {"image", "label"}
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Cifar10 Dataset
|
||||
// Column names: {"image", "label"}
|
||||
folder_path = datasets_root_path_ + "/testCifar10Data/";
|
||||
std::shared_ptr<Dataset> ds2 = Cifar10(folder_path, "all", RandomSampler(false, 9));
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
|
||||
// Create a Project operation on ds
|
||||
ds = ds->Project({"image"});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
ds2 = ds2->Project({"image"});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Concat operation on the ds
|
||||
ds = ds->Concat({ds2});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 19);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestConcatSuccess2) {
|
||||
// Test "+" operator to concat two datasets
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestConcatSuccess2.";
|
||||
|
@ -461,6 +489,27 @@ TEST_F(MindDataTestPipeline, TestImageFolderBatchAndRepeat) {
|
|||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestPipelineGetDatasetSize) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPipelineGetDatasetSize.";
|
||||
|
||||
// Create an ImageFolder Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Repeat operation on ds
|
||||
int32_t repeat_num = 2;
|
||||
ds = ds->Repeat(repeat_num);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Batch operation on ds
|
||||
int32_t batch_size = 2;
|
||||
ds = ds->Batch(batch_size);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 10);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestProjectMap) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestProjectMap.";
|
||||
|
||||
|
@ -914,6 +963,22 @@ TEST_F(MindDataTestPipeline, TestSkipDataset) {
|
|||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestSkipGetDatasetSize) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSkipGetDatasetSize.";
|
||||
|
||||
// Create an ImageFolder Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Skip operation on ds
|
||||
int32_t count = 3;
|
||||
ds = ds->Skip(count);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 7);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestSkipDatasetError1) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSkipDatasetError1.";
|
||||
|
||||
|
@ -966,6 +1031,21 @@ TEST_F(MindDataTestPipeline, TestTakeDatasetDefault) {
|
|||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestTakeGetDatasetSize) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTakeGetDatasetSize.";
|
||||
|
||||
// Create an ImageFolder Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 7));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Take operation on ds, dafault count = -1
|
||||
ds = ds->Take(2);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 2);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestTakeDatasetError1) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTakeDatasetError1.";
|
||||
|
||||
|
@ -1190,6 +1270,44 @@ TEST_F(MindDataTestPipeline, TestZipSuccess) {
|
|||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestZipGetDatasetSize) {
|
||||
// Testing the member zip() function
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestZipGetDatasetSize.";
|
||||
|
||||
// Create an ImageFolder Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 2));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Project operation on ds
|
||||
std::vector<std::string> column_project = {"image"};
|
||||
ds = ds->Project(column_project);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an ImageFolder Dataset
|
||||
std::shared_ptr<Dataset> ds1 = ImageFolder(folder_path, true, RandomSampler(false, 3));
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
|
||||
// Create a Rename operation on ds (so that the 3 datasets we are going to zip have distinct column names)
|
||||
ds1 = ds1->Rename({"image", "label"}, {"col1", "col2"});
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
|
||||
folder_path = datasets_root_path_ + "/testCifar10Data/";
|
||||
std::shared_ptr<Dataset> ds2 = Cifar10(folder_path, "all", RandomSampler(false, 5));
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
|
||||
// Create a Project operation on ds
|
||||
column_project = {"label"};
|
||||
ds2 = ds2->Project(column_project);
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
|
||||
// Create a Zip operation on the datasets
|
||||
ds = ds->Zip({ds1, ds2});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 2);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestZipSuccess2) {
|
||||
// Testing the static zip() function
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestZipSuccess2.";
|
||||
|
|
|
@ -87,6 +87,19 @@ TEST_F(MindDataTestPipeline, TestRandomDatasetBasic1) {
|
|||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestRandomDatasetGetDatasetSize) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomDatasetGetDatasetSize.";
|
||||
|
||||
// Create a RandomDataset
|
||||
std::shared_ptr<SchemaObj> schema = Schema();
|
||||
schema->add_column("image", mindspore::TypeId::kNumberTypeUInt8, {2});
|
||||
schema->add_column("label", mindspore::TypeId::kNumberTypeUInt8, {1});
|
||||
std::shared_ptr<Dataset> ds = RandomData(50, schema);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 50);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestRandomDatasetBasic2) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomDatasetBasic2.";
|
||||
|
||||
|
|
|
@ -96,6 +96,32 @@ TEST_F(MindDataTestPipeline, TestTextFileDatasetBasic) {
|
|||
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestTextFileGetDatasetSize) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileGetDatasetSize.";
|
||||
// Test TextFile Dataset with single text file and many default inputs
|
||||
|
||||
// Set configuration
|
||||
uint32_t original_seed = GlobalContext::config_manager()->seed();
|
||||
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
|
||||
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
|
||||
GlobalContext::config_manager()->set_seed(987);
|
||||
GlobalContext::config_manager()->set_num_parallel_workers(4);
|
||||
|
||||
// Create a TextFile Dataset, with single text file
|
||||
// Note: 1.txt has 3 rows
|
||||
// Use 2 samples
|
||||
// Use defaults for other input parameters
|
||||
std::string tf_file1 = datasets_root_path_ + "/testTextFileDataset/1.txt";
|
||||
std::shared_ptr<Dataset> ds = TextFile({tf_file1}, 2);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 2);
|
||||
|
||||
// Restore configuration
|
||||
GlobalContext::config_manager()->set_seed(original_seed);
|
||||
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestTextFileDatasetFail1) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetFail1.";
|
||||
|
||||
|
|
|
@ -98,6 +98,36 @@ TEST_F(MindDataTestPipeline, TestTFRecordDatasetBasic) {
|
|||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestTFRecordDatasetBasicGetDatasetSize) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTFRecordDatasetBasicGetDatasetSize.";
|
||||
|
||||
// Create a TFRecord Dataset
|
||||
std::string file_path = datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0001.data";
|
||||
std::string schema_path = datasets_root_path_ + "/test_tf_file_3_images2/datasetSchema.json";
|
||||
std::shared_ptr<Dataset> ds = TFRecord({file_path}, schema_path, {"image"}, 0);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Repeat operation on ds
|
||||
int32_t repeat_num = 2;
|
||||
ds = ds->Repeat(repeat_num);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create objects for the tensor ops
|
||||
std::shared_ptr<TensorOperation> random_horizontal_flip_op = vision::RandomHorizontalFlip(0.5);
|
||||
EXPECT_NE(random_horizontal_flip_op, nullptr);
|
||||
|
||||
// Create a Map operation on ds
|
||||
ds = ds->Map({random_horizontal_flip_op}, {}, {}, {"image"});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Batch operation on ds
|
||||
int32_t batch_size = 1;
|
||||
ds = ds->Batch(batch_size);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 6);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestTFRecordDatasetShuffle) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTFRecordDatasetShuffle.";
|
||||
// This case is to verify if the list of datafiles are sorted in lexicographical order.
|
||||
|
|
|
@ -86,6 +86,22 @@ TEST_F(MindDataTestPipeline, TestVOCClassIndex) {
|
|||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestVOCGetDatasetSize) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestVOCGetDatasetSize.";
|
||||
|
||||
// Create a VOC Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testVOC2012_2";
|
||||
std::map<std::string, int32_t> class_index;
|
||||
class_index["car"] = 0;
|
||||
class_index["cat"] = 1;
|
||||
class_index["train"] = 9;
|
||||
|
||||
std::shared_ptr<Dataset> ds = VOC(folder_path, "Detection", "train", class_index, false, SequentialSampler(0, 6));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 6);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestVOCDetection) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestVOCDetection.";
|
||||
|
||||
|
|
|
@ -125,6 +125,17 @@ TEST_F(MindDataTestPipeline, TestCelebADefault) {
|
|||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestCelebAGetDatasetSize) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCelebAGetDatasetSize.";
|
||||
|
||||
// Create a CelebA Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testCelebAData/";
|
||||
std::shared_ptr<Dataset> ds = CelebA(folder_path, "valid");
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 1);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestCelebAException) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCelebAException.";
|
||||
|
||||
|
@ -179,6 +190,17 @@ TEST_F(MindDataTestPipeline, TestImageFolderFailWithWrongExtension) {
|
|||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestImageFolderGetDatasetSize) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestImageFolderGetDatasetSize.";
|
||||
|
||||
// Create an ImageFolder Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 44);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestImageFolderFailWithNullSampler) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestImageFolderFailWithNullSampler.";
|
||||
|
||||
|
@ -199,6 +221,16 @@ TEST_F(MindDataTestPipeline, TestImageFolderFailWithWrongSampler) {
|
|||
EXPECT_EQ(ds, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestMnistGetDatasetSize) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMnistGetDatasetSize.";
|
||||
|
||||
// Create a Mnist Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testMnistData/";
|
||||
std::shared_ptr<Dataset> ds = Mnist(folder_path, "all", RandomSampler(false, 20));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 20);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestMnistFailWithWrongDatasetDir) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMnistFailWithWrongDatasetDir.";
|
||||
|
||||
|
|
Loading…
Reference in New Issue