!7710 Add c++ API for GetRepeatCount and GetBatchSize

Merge pull request !7710 from h.farahat/getBatchRepeat
This commit is contained in:
mindspore-ci-bot 2020-10-26 09:25:21 +08:00 committed by Gitee
commit d81bd7b17b
13 changed files with 136 additions and 7 deletions

View File

@ -655,6 +655,42 @@ Status Dataset::AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
}
return Status::OK();
}
int64_t Dataset::GetBatchSize() {
int64_t batch_size;
auto ds = shared_from_this();
Status rc;
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "GetBatchSize: Initializing RuntimeContext failed.";
return -1;
}
rc = tree_getters_->Init(ds);
if (rc.IsError()) {
MS_LOG(ERROR) << "GetBatchSize: Initializing TreeGetters failed.";
return -1;
}
rc = tree_getters_->GetBatchSize(&batch_size);
return rc.IsError() ? -1 : batch_size;
}
int64_t Dataset::GetRepeatCount() {
int64_t repeat_count;
auto ds = shared_from_this();
Status rc;
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "GetRepeatCount: Initializing RuntimeContext failed.";
return -1;
}
rc = tree_getters_->Init(ds);
if (rc.IsError()) {
MS_LOG(ERROR) << "GetRepeatCount: Initializing TreeGetters failed.";
return -1;
}
rc = tree_getters_->GetRepeatCount(&repeat_count);
return rc.IsError() ? 0 : repeat_count;
}
// Constructs a SchemaObj bound to the given schema file path; num_rows_ starts
// at 0 and dataset_type_ starts empty.
SchemaObj::SchemaObj(const std::string &schema_file) : schema_file_(schema_file), num_rows_(0), dataset_type_("") {}

View File

@ -430,4 +430,18 @@ Status TreeGetters::GetOutputShapes(std::vector<TensorShape> *shapes) {
}
return Status::OK();
}
// Queries the batch size from the root of the execution tree.
// A result of -1 from the root means the size could not be determined
// (e.g. it is computed dynamically), which is reported as an error.
Status TreeGetters::GetBatchSize(int64_t *batch_size) {
  auto tree_root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
  CHECK_FAIL_RETURN_UNEXPECTED(tree_root != nullptr, "Root is a nullptr.");
  *batch_size = tree_root->GetTreeBatchSize();
  CHECK_FAIL_RETURN_UNEXPECTED(*batch_size != -1, "Error in finding the batch size.");
  return Status::OK();
}
// Queries the repeat count from the root of the execution tree.
Status TreeGetters::GetRepeatCount(int64_t *repeat_count) {
  auto tree_root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
  CHECK_FAIL_RETURN_UNEXPECTED(tree_root != nullptr, "Root is a nullptr.");
  *repeat_count = tree_root->GetTreeRepeatCount();
  return Status::OK();
}
} // namespace mindspore::dataset

View File

@ -162,6 +162,8 @@ class TreeGetters : public TreeConsumer {
Status GetDatasetSize(int64_t *size);
Status GetOutputTypes(std::vector<DataType> *types);
Status GetOutputShapes(std::vector<TensorShape> *shapes);
/// \brief Gets the batch size reported by the tree's root op
/// \param[out] batch_size The batch size (left as -1 on failure)
/// \return Status of the query
Status GetBatchSize(int64_t *batch_size);
/// \brief Gets the repeat count reported by the tree's root op
/// \param[out] repeat_count The repeat count
/// \return Status of the query
Status GetRepeatCount(int64_t *repeat_count);
bool isInitialized();
std::string Name() override { return "TreeGetters"; }
Status GetRow(TensorRow *r);

View File

@ -555,6 +555,14 @@ Status BatchOp::GetDatasetSize(int64_t *dataset_size) {
dataset_size_ = num_rows;
return Status::OK();
}
// Reports this op's batch size for tree-wide GetBatchSize queries.
// Returns -1 when a Python batch-size callback is set, because the batch
// size is then computed per batch and no single static value exists.
int64_t BatchOp::GetTreeBatchSize() {
#ifdef ENABLE_PYTHON
  if (batch_size_func_) {
    return -1;
  }
#endif
  return start_batch_size_;
}
} // namespace dataset
} // namespace mindspore

View File

@ -224,6 +224,8 @@ class BatchOp : public ParallelOp {
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
/// \brief Gets the batch size of this op (-1 if it is computed dynamically)
/// \return int64_t
int64_t GetTreeBatchSize() override;
protected:
Status ComputeColMap() override;

View File

@ -455,5 +455,17 @@ void DatasetOp::UpdateRepeatAndEpochCounter() {
if (op_current_repeats_ % op_num_repeats_per_epoch_ == 0) op_current_epochs_++;
MS_LOG(DEBUG) << Name() << " current repeats: " << op_current_repeats_ << ", current epochs: " << op_current_epochs_;
}
// Default tree batch size: a leaf op reports 1; any other op delegates
// to its first child so the query propagates down the tree.
int64_t DatasetOp::GetTreeBatchSize() {
  if (child_.empty()) {
    return 1;
  }
  return child_[0]->GetTreeBatchSize();
}
// Default tree repeat count: a leaf op reports 1; any other op delegates
// to its first child so the query propagates down the tree.
int64_t DatasetOp::GetTreeRepeatCount() {
  if (child_.empty()) {
    return 1;
  }
  return child_[0]->GetTreeRepeatCount();
}
} // namespace dataset
} // namespace mindspore

View File

@ -183,6 +183,14 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// \return Status - The status code return
virtual Status GetDatasetSize(int64_t *dataset_size);
/// \brief Gets the batch size of the tree rooted at this op
/// \return int64_t - The batch size (not a Status; -1 indicates failure)
virtual int64_t GetTreeBatchSize();
/// \brief Gets the repeat count of the tree rooted at this op
/// \return int64_t - The repeat count (not a Status)
virtual int64_t GetTreeRepeatCount();
/// \brief Performs handling for when an eoe message is received.
/// The base class implementation simply flows the eoe message to output. Derived classes
/// may override if they need to perform special eoe handling.

View File

@ -119,5 +119,6 @@ Status EpochCtrlOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call the pre-visitation
return p->RunOnNode(shared_from_base<EpochCtrlOp>(), modified);
}
// EpochCtrl's own num_repeats_ counts epochs rather than repeats, so skip this
// op and report the child's repeat count. Guard against a missing child instead
// of dereferencing child_[0] unchecked (base class returns 1 in that case).
int64_t EpochCtrlOp::GetTreeRepeatCount() { return child_.empty() ? 1 : child_[0]->GetTreeRepeatCount(); }
} // namespace dataset
} // namespace mindspore

View File

@ -76,6 +76,8 @@ class EpochCtrlOp : public RepeatOp {
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
/// \brief Gets the repeat count of the subtree below this op (epoch count excluded)
/// \return int64_t
int64_t GetTreeRepeatCount() override;
};
} // namespace dataset
} // namespace mindspore

View File

@ -207,5 +207,6 @@ Status RepeatOp::GetDatasetSize(int64_t *dataset_size) {
dataset_size_ = num_rows;
return Status::OK();
}
// A RepeatOp is the source of repeats, so it reports its configured count
// rather than delegating to its child.
int64_t RepeatOp::GetTreeRepeatCount() {
  return num_repeats_;
}
} // namespace dataset
} // namespace mindspore

View File

@ -138,6 +138,8 @@ class RepeatOp : public PipelineOp {
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
/// \brief Gets the number of repeats this op is configured to perform
/// \return int64_t
int64_t GetTreeRepeatCount() override;
// \brief Adds an operator to the repeat ops list of tracked leaf/eoe nodes
// \param[in] eoe_op The input leaf/eoe operator to add to the list
void AddToEoeList(std::shared_ptr<DatasetOp> eoe_op) { eoe_ops_.push_back(std::move(eoe_op)); }

View File

@ -588,17 +588,25 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
}
/// \brief Gets the dataset size
/// \return status code
/// \return int64_t
int64_t GetDatasetSize();
/// \brief Gets the output type
/// \return status code
/// \return vector of DataType
std::vector<DataType> GetOutputTypes();
/// \brief Gets the output shape
/// \return status code
/// \return vector of TensorShapes
std::vector<TensorShape> GetOutputShapes();
/// \brief Gets the batch size
/// \return int64_t
int64_t GetBatchSize();
/// \brief Gets the repeat count
/// \return int64_t
int64_t GetRepeatCount();
/// \brief Setter function for runtime number of workers
/// \param[in] num_workers The number of threads in this operator
/// \return Shared pointer to the original object
@ -668,16 +676,18 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
/// 0<i<n, and one bucket for [bucket_boundaries[n-1], inf).
/// \param[in] bucket_batch_sizes A list consisting of the batch sizes for each bucket.
/// Must contain elements equal to the size of bucket_boundaries + 1.
/// \param[in] element_length_function A function pointer that takes in TensorRow and outputs a TensorRow. The output
/// \param[in] element_length_function A function pointer that takes in TensorRow and outputs a TensorRow. The
/// output
/// must contain a single tensor containing a single int32_t. If no value is provided, then size of column_names
/// must be 1, and the size of the first dimension of that column will be taken as the length (default=nullptr)
/// \param[in] pad_info Represents how to batch each column. The key corresponds to the column name, the value must
/// be a tuple of 2 elements. The first element corresponds to the shape to pad to, and the second element
/// corresponds to the value to pad with. If a column is not specified, then that column will be padded to the
/// longest in the current batch, and 0 will be used as the padding value. Any unspecified dimensions will be
/// padded to the longest in the current batch, unless if pad_to_bucket_boundary is true. If no padding is wanted,
/// set pad_info to None (default=empty dictionary).
/// \param[in] pad_to_bucket_boundary If true, will pad each unspecified dimension in pad_info to the bucket_boundary
/// padded to the longest in the current batch, unless if pad_to_bucket_boundary is true. If no padding is
/// wanted, set pad_info to None (default=empty dictionary).
/// \param[in] pad_to_bucket_boundary If true, will pad each unspecified dimension in pad_info to the
/// bucket_boundary
/// minus 1. If there are any elements that fall into the last bucket, an error will occur (default=false).
/// \param[in] drop_remainder If true, will drop the last batch for each bucket if it is not a full batch
/// (default=false).

View File

@ -125,6 +125,37 @@ TEST_F(MindDataTestPipeline, TestCelebADefault) {
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestGetRepeatCount) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGetRepeatCount.";
  // Create an ImageFolder Dataset
  std::string folder_path = datasets_root_path_ + "/testPK/data/";
  std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true);
  EXPECT_NE(ds, nullptr);
  // With no Repeat op in the tree, the repeat count defaults to 1.
  EXPECT_EQ(ds->GetRepeatCount(), 1);
  ds = ds->Repeat(4);
  EXPECT_NE(ds, nullptr);
  EXPECT_EQ(ds->GetRepeatCount(), 4);
  // A second Repeat becomes the new tree root, so its count (3) is reported,
  // not the product of the two repeats.
  ds = ds->Repeat(3);
  EXPECT_NE(ds, nullptr);
  EXPECT_EQ(ds->GetRepeatCount(), 3);
}
TEST_F(MindDataTestPipeline, TestGetBatchSize) {
  // Fix: log line previously said "TestGetRepeatCount" (copy-paste error).
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGetBatchSize.";
  // Create an ImageFolder Dataset
  std::string folder_path = datasets_root_path_ + "/testPK/data/";
  std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true)->Project({"label"});
  EXPECT_NE(ds, nullptr);
  // With no Batch op in the tree, the batch size defaults to 1.
  EXPECT_EQ(ds->GetBatchSize(), 1);
  ds = ds->Batch(2);
  EXPECT_NE(ds, nullptr);
  EXPECT_EQ(ds->GetBatchSize(), 2);
  // The outermost Batch is the tree root, so its size (3) is reported.
  ds = ds->Batch(3);
  EXPECT_NE(ds, nullptr);
  EXPECT_EQ(ds->GetBatchSize(), 3);
}
TEST_F(MindDataTestPipeline, TestCelebAGetDatasetSize) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCelebAGetDatasetSize.";