forked from mindspore-Ecosystem/mindspore
!7710 Add c++ API for GetRepeatCount and GetBatchSize
Merge pull request !7710 from h.farahat/getBatchRepeat
This commit is contained in:
commit
d81bd7b17b
|
@ -655,6 +655,42 @@ Status Dataset::AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
|
|||
}
|
||||
return Status::OK();
|
||||
}
|
||||
int64_t Dataset::GetBatchSize() {
|
||||
int64_t batch_size;
|
||||
auto ds = shared_from_this();
|
||||
Status rc;
|
||||
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
|
||||
rc = runtime_context->Init();
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetBatchSize: Initializing RuntimeContext failed.";
|
||||
return -1;
|
||||
}
|
||||
rc = tree_getters_->Init(ds);
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetBatchSize: Initializing TreeGetters failed.";
|
||||
return -1;
|
||||
}
|
||||
rc = tree_getters_->GetBatchSize(&batch_size);
|
||||
return rc.IsError() ? -1 : batch_size;
|
||||
}
|
||||
int64_t Dataset::GetRepeatCount() {
|
||||
int64_t repeat_count;
|
||||
auto ds = shared_from_this();
|
||||
Status rc;
|
||||
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
|
||||
rc = runtime_context->Init();
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetRepeatCount: Initializing RuntimeContext failed.";
|
||||
return -1;
|
||||
}
|
||||
rc = tree_getters_->Init(ds);
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetRepeatCount: Initializing TreeGetters failed.";
|
||||
return -1;
|
||||
}
|
||||
rc = tree_getters_->GetRepeatCount(&repeat_count);
|
||||
return rc.IsError() ? 0 : repeat_count;
|
||||
}
|
||||
|
||||
SchemaObj::SchemaObj(const std::string &schema_file) : schema_file_(schema_file), num_rows_(0), dataset_type_("") {}
|
||||
|
||||
|
|
|
@ -430,4 +430,18 @@ Status TreeGetters::GetOutputShapes(std::vector<TensorShape> *shapes) {
|
|||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Queries the executed tree for its effective batch size via the root op.
// \param[out] batch_size The batch size found by walking the tree.
// \return Status error if the root is null or no valid batch size was found.
// NOTE(review): wrapping the pointer from tree_adapter_->GetRoot() in a
// brand-new shared_ptr creates an independent owner — confirm GetRoot()
// actually transfers ownership here, otherwise this risks a double delete.
Status TreeGetters::GetBatchSize(int64_t *batch_size) {
  std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
  CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
  *batch_size = root->GetTreeBatchSize();
  // GetTreeBatchSize() uses -1 as its "could not determine" sentinel.
  CHECK_FAIL_RETURN_UNEXPECTED(*batch_size != -1, "Error in finding the batch size.");
  return Status::OK();
}
|
||||
// Queries the executed tree for its repeat count via the root op.
// \param[out] repeat_count The repeat count found by walking the tree.
// \return Status error if the root is null.
// NOTE(review): same ownership concern as GetBatchSize above — verify that
// tree_adapter_->GetRoot() hands over ownership before wrapping it in a new
// shared_ptr.
Status TreeGetters::GetRepeatCount(int64_t *repeat_count) {
  std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
  CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
  *repeat_count = root->GetTreeRepeatCount();
  return Status::OK();
}
|
||||
} // namespace mindspore::dataset
|
||||
|
|
|
@ -162,6 +162,8 @@ class TreeGetters : public TreeConsumer {
|
|||
Status GetDatasetSize(int64_t *size);
|
||||
Status GetOutputTypes(std::vector<DataType> *types);
|
||||
Status GetOutputShapes(std::vector<TensorShape> *shapes);
|
||||
Status GetBatchSize(int64_t *batch_size);
|
||||
Status GetRepeatCount(int64_t *repeat_count);
|
||||
bool isInitialized();
|
||||
std::string Name() override { return "TreeGetters"; }
|
||||
Status GetRow(TensorRow *r);
|
||||
|
|
|
@ -555,6 +555,14 @@ Status BatchOp::GetDatasetSize(int64_t *dataset_size) {
|
|||
dataset_size_ = num_rows;
|
||||
return Status::OK();
|
||||
}
|
||||
// Reports this op's statically-known batch size for tree queries.
// Returns -1 when the size cannot be determined statically, i.e. when a
// Python per-batch callback computes it at runtime.
int64_t BatchOp::GetTreeBatchSize() {
#ifdef ENABLE_PYTHON
  // A user-supplied batch_size function means the size may vary per batch,
  // so no single static value can be reported.
  if (batch_size_func_) {
    return -1;
  }
#endif
  return start_batch_size_;
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -224,6 +224,8 @@ class BatchOp : public ParallelOp {
|
|||
/// \return Status of the function
|
||||
Status GetDatasetSize(int64_t *dataset_size) override;
|
||||
|
||||
int64_t GetTreeBatchSize() override;
|
||||
|
||||
protected:
|
||||
Status ComputeColMap() override;
|
||||
|
||||
|
|
|
@ -455,5 +455,17 @@ void DatasetOp::UpdateRepeatAndEpochCounter() {
|
|||
if (op_current_repeats_ % op_num_repeats_per_epoch_ == 0) op_current_epochs_++;
|
||||
MS_LOG(DEBUG) << Name() << " current repeats: " << op_current_repeats_ << ", current epochs: " << op_current_epochs_;
|
||||
}
|
||||
int64_t DatasetOp::GetTreeBatchSize() {
|
||||
if (!child_.empty()) {
|
||||
return child_[0]->GetTreeBatchSize();
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
int64_t DatasetOp::GetTreeRepeatCount() {
|
||||
if (!child_.empty()) {
|
||||
return child_[0]->GetTreeRepeatCount();
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -183,6 +183,14 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
|
|||
/// \return Status - The status code return
|
||||
virtual Status GetDatasetSize(int64_t *dataset_size);
|
||||
|
||||
/// \brief Gets the batch size
|
||||
/// \return int64_t - the batch size (note: this method returns int64_t, not Status)
|
||||
virtual int64_t GetTreeBatchSize();
|
||||
|
||||
/// \brief Gets the repeat count
|
||||
/// \return int64_t - the repeat count (note: this method returns int64_t, not Status)
|
||||
virtual int64_t GetTreeRepeatCount();
|
||||
|
||||
/// \brief Performs handling for when an eoe message is received.
|
||||
/// The base class implementation simply flows the eoe message to output. Derived classes
|
||||
/// may override if they need to perform special eoe handling.
|
||||
|
|
|
@ -119,5 +119,6 @@ Status EpochCtrlOp::Accept(NodePass *p, bool *modified) {
|
|||
// Downcast shared pointer then call the pre-visitation
|
||||
return p->RunOnNode(shared_from_base<EpochCtrlOp>(), modified);
|
||||
}
|
||||
int64_t EpochCtrlOp::GetTreeRepeatCount() { return child_[0]->GetTreeRepeatCount(); }
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -76,6 +76,8 @@ class EpochCtrlOp : public RepeatOp {
|
|||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status Accept(NodePass *p, bool *modified) override;
|
||||
|
||||
int64_t GetTreeRepeatCount() override;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -207,5 +207,6 @@ Status RepeatOp::GetDatasetSize(int64_t *dataset_size) {
|
|||
dataset_size_ = num_rows;
|
||||
return Status::OK();
|
||||
}
|
||||
int64_t RepeatOp::GetTreeRepeatCount() { return num_repeats_; }
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -138,6 +138,8 @@ class RepeatOp : public PipelineOp {
|
|||
/// \return Status of the function
|
||||
Status GetDatasetSize(int64_t *dataset_size) override;
|
||||
|
||||
int64_t GetTreeRepeatCount() override;
|
||||
|
||||
// \brief Adds an operator to the repeat ops list of tracked leaf/eoe nodes
|
||||
// \param[in] eoe_op The input leaf/eoe operator to add to the list
|
||||
void AddToEoeList(std::shared_ptr<DatasetOp> eoe_op) { eoe_ops_.push_back(std::move(eoe_op)); }
|
||||
|
|
|
@ -588,17 +588,25 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
|
|||
}
|
||||
|
||||
/// \brief Gets the dataset size
|
||||
/// \return status code
|
||||
/// \return int64_t
|
||||
int64_t GetDatasetSize();
|
||||
|
||||
/// \brief Gets the output type
|
||||
/// \return status code
|
||||
/// \return vector of DataType
|
||||
std::vector<DataType> GetOutputTypes();
|
||||
|
||||
/// \brief Gets the output shape
|
||||
/// \return status code
|
||||
/// \return vector of TensorShapes
|
||||
std::vector<TensorShape> GetOutputShapes();
|
||||
|
||||
/// \brief Gets the batch size
|
||||
/// \return int64_t
|
||||
int64_t GetBatchSize();
|
||||
|
||||
/// \brief Gets the repeat count
|
||||
/// \return int64_t
|
||||
int64_t GetRepeatCount();
|
||||
|
||||
/// \brief Setter function for runtime number of workers
|
||||
/// \param[in] num_workers The number of threads in this operator
|
||||
/// \return Shared pointer to the original object
|
||||
|
@ -668,16 +676,18 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
|
|||
/// 0<i<n, and one bucket for [bucket_boundaries[n-1], inf).
|
||||
/// \param[in] bucket_batch_sizes A list consisting of the batch sizes for each bucket.
|
||||
/// Must contain elements equal to the size of bucket_boundaries + 1.
|
||||
/// \param[in] element_length_function A function pointer that takes in TensorRow and outputs a TensorRow. The output
|
||||
/// \param[in] element_length_function A function pointer that takes in TensorRow and outputs a TensorRow. The
|
||||
/// output
|
||||
/// must contain a single tensor containing a single int32_t. If no value is provided, then size of column_names
|
||||
/// must be 1, and the size of the first dimension of that column will be taken as the length (default=nullptr)
|
||||
/// \param[in] pad_info Represents how to batch each column. The key corresponds to the column name, the value must
|
||||
/// be a tuple of 2 elements. The first element corresponds to the shape to pad to, and the second element
|
||||
/// corresponds to the value to pad with. If a column is not specified, then that column will be padded to the
|
||||
/// longest in the current batch, and 0 will be used as the padding value. Any unspecified dimensions will be
|
||||
/// padded to the longest in the current batch, unless if pad_to_bucket_boundary is true. If no padding is wanted,
|
||||
/// set pad_info to None (default=empty dictionary).
|
||||
/// \param[in] pad_to_bucket_boundary If true, will pad each unspecified dimension in pad_info to the bucket_boundary
|
||||
/// padded to the longest in the current batch, unless if pad_to_bucket_boundary is true. If no padding is
|
||||
/// wanted, set pad_info to None (default=empty dictionary).
|
||||
/// \param[in] pad_to_bucket_boundary If true, will pad each unspecified dimension in pad_info to the
|
||||
/// bucket_boundary
|
||||
/// minus 1. If there are any elements that fall into the last bucket, an error will occur (default=false).
|
||||
/// \param[in] drop_remainder If true, will drop the last batch for each bucket if it is not a full batch
|
||||
/// (default=false).
|
||||
|
|
|
@ -125,6 +125,37 @@ TEST_F(MindDataTestPipeline, TestCelebADefault) {
|
|||
iter->Stop();
|
||||
}
|
||||
|
||||
// Verifies Dataset::GetRepeatCount(): 1 by default, and that chaining
// Repeat() nodes reports the most recently applied (outermost) repeat count.
TEST_F(MindDataTestPipeline, TestGetRepeatCount) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGetRepeatCount.";

  // Create an ImageFolder Dataset
  std::string folder_path = datasets_root_path_ + "/testPK/data/";
  std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true);
  EXPECT_NE(ds, nullptr);
  // No Repeat applied yet: default repeat count is 1.
  EXPECT_EQ(ds->GetRepeatCount(), 1);
  ds = ds->Repeat(4);
  EXPECT_NE(ds, nullptr);
  EXPECT_EQ(ds->GetRepeatCount(), 4);
  ds = ds->Repeat(3);
  EXPECT_NE(ds, nullptr);
  // The outermost Repeat wins the tree query.
  EXPECT_EQ(ds->GetRepeatCount(), 3);
}
|
||||
|
||||
// Verifies Dataset::GetBatchSize(): 1 by default, and that chaining Batch()
// nodes reports the most recently applied (outermost) batch size.
// Fix: the log line said "TestGetRepeatCount" (copy-paste from the previous
// test); it now names this test correctly.
TEST_F(MindDataTestPipeline, TestGetBatchSize) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGetBatchSize.";

  // Create an ImageFolder Dataset
  std::string folder_path = datasets_root_path_ + "/testPK/data/";
  std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true)->Project({"label"});
  EXPECT_NE(ds, nullptr);
  // No Batch applied yet: default batch size is 1.
  EXPECT_EQ(ds->GetBatchSize(), 1);
  ds = ds->Batch(2);
  EXPECT_NE(ds, nullptr);
  EXPECT_EQ(ds->GetBatchSize(), 2);
  ds = ds->Batch(3);
  EXPECT_NE(ds, nullptr);
  // The outermost Batch wins the tree query.
  EXPECT_EQ(ds->GetBatchSize(), 3);
}
|
||||
TEST_F(MindDataTestPipeline, TestCelebAGetDatasetSize) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCelebAGetDatasetSize.";
|
||||
|
||||
|
|
Loading…
Reference in New Issue