diff --git a/mindspore/ccsrc/minddata/dataset/engine/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/CMakeLists.txt
index 7064ec144f0..884f294439e 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/CMakeLists.txt
+++ b/mindspore/ccsrc/minddata/dataset/engine/CMakeLists.txt
@@ -27,6 +27,7 @@ set(SRC_FILES_LIST
     consumers/pull_based_tree_consumer.cc
     consumers/tree_consumer.cc
     serdes.cc
+    tree_modifier.cc
 )
 if(ENABLE_PYTHON)
   set(SRC_FILES_LIST
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h
index dd50302e632..cd1e6e1f390 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h
@@ -116,6 +116,10 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
     return Status(StatusCode::kMDUnexpectedError, "Add new workers is not supported for non-ParallelOps");
   }
 
+  virtual Status RemoveWorkers(int32_t num_workers = 1) {
+    return Status(StatusCode::kMDUnexpectedError, "Remove workers is not supported for non-ParallelOps");
+  }
+
   // \brief Inserts an operator as the parent of the current op.
   // \notes Inserted op will become the sole parent of the current op.
   //     The existing parent of the current op will be transferred to the inserted op.
@@ -239,6 +243,8 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
   // \return - the column name map as a string
   std::string ColumnNameMapAsString() const;
 
+  OperatorConnector *OutputConnector() const { return out_connector_.get(); }
+
   // \brief Getter function
   // \return connector size of current op
   int32_t ConnectorSize() const {
@@ -328,6 +334,10 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
   // \return Status
   virtual Status WaitForWorkers() { return Status::OK(); }
 
+  virtual int32_t NumWorkers() { return 0; }
+
+  virtual Status SendQuitFlagToWorker(int32_t worker_id) { return Status::OK(); }
+
   // \brief Add callback to DatasetOp, only MapOp supports Callback at the moment
   void AddCallbacks(std::vector<std::shared_ptr<DSCallback>> callbacks) { callback_manager_.AddCallbacks(callbacks); }
 
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.cc
index cbaf0ff9e14..0bd022ddfcd 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.cc
@@ -14,6 +14,7 @@
  * limitations under the License.
  */
 #include "minddata/dataset/engine/datasetops/map_op/map_op.h"
+#include
 #include
 #include
 #include
@@ -161,9 +162,7 @@ Status MapOp::operator()() {
   // Quit all workers, this code might never be reached if EpochCtrl is -1.
   for (int32_t wkr_id = 0; wkr_id < num_workers_; wkr_id++) {
-    TensorRow quit_flag(TensorRow::kFlagQuit);
-    auto quit = std::make_unique<TensorRow>(quit_flag);
-    RETURN_IF_NOT_OK(worker_in_queues_[NextWorkerID()]->Add(std::move(quit)));
+    RETURN_IF_NOT_OK(SendQuitFlagToWorker(NextWorkerID()));
   }
 
   return Status::OK();
@@ -387,5 +386,11 @@ Status MapOp::WaitForWorkers() {
   wait_for_workers_post_.Clear();
   return Status::OK();
 }
+Status MapOp::SendQuitFlagToWorker(int32_t worker_id) {
+  TensorRow quit_flag(TensorRow::kFlagQuit);
+  auto quit = std::make_unique<TensorRow>(quit_flag);
+  RETURN_IF_NOT_OK(worker_in_queues_[worker_id]->Add(std::move(quit)));
+  return Status::OK();
+}
 }  // namespace dataset
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.h
index bf74fdba157..0aa67c1edf1 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.h
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.h
@@ -180,6 +180,11 @@ class MapOp : public ParallelOp<std::unique_ptr<MapWorkerJob>, TensorRow> {
   //     who does the increment wakes up the master.
   // @return - Status
   Status WaitForWorkers() override;
+
+  /// Send a quit flag row to the worker at worker_id to make it exit
+  /// \param worker_id id of the worker
+  /// \return Status code
+  Status SendQuitFlagToWorker(int32_t worker_id) override;
 };
 }  // namespace dataset
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.h
index 75884f85b30..9ae2386d10b 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.h
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.h
@@ -17,6 +17,7 @@
 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_PARALLEL_OP_H_
 
 #include
+#include
 #include
 #include
 #include
@@ -98,8 +99,9 @@ class ParallelOp : public DatasetOp {
     RETURN_IF_NOT_OK(worker_out_queues_.Register(tree_->AllTasks()));
     RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks()));
 
-    RETURN_IF_NOT_OK(tree_->LaunchWorkers(
-      num_workers_, std::bind(&ParallelOp::WorkerEntry, this, std::placeholders::_1), Name() + "::WorkerEntry", id()));
+    RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_,
+                                          std::bind(&ParallelOp::WorkerEntry, this, std::placeholders::_1),
+                                          &worker_tasks_, Name() + "::WorkerEntry", id()));
     RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&ParallelOp::Collector, this), Name() + "::Collector", id()));
     return Status::OK();
   }
@@ -154,8 +156,11 @@ class ParallelOp : public DatasetOp {
     for (int32_t i = 0; i < num_new_workers; i++) {
       worker_in_queues_.AddQueue(tree_->AllTasks());
       worker_out_queues_.AddQueue(tree_->AllTasks());
+      Task *new_task;
       RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask(
-        Name() + "::WorkerEntry", std::bind(&ParallelOp::WorkerEntry, this, num_workers_), nullptr, id()));
+        Name() + "::WorkerEntry", std::bind(&ParallelOp::WorkerEntry, this, num_workers_), &new_task, id()));
+      CHECK_FAIL_RETURN_UNEXPECTED(new_task != nullptr, "Cannot create a new worker.");
+      worker_tasks_.push_back(new_task);
       num_workers_++;
       MS_LOG(INFO) << "A new worker has been added to op: " << Name() << "::" << id()
                    << " num_workers=" << num_workers_;
@@ -163,6 +168,25 @@ class ParallelOp : public DatasetOp {
     return Status::OK();
   }
 
+  /// Remove workers from the ParallelOp. The function waits for all workers to finish processing the current rows,
+  /// then quits and joins the last worker threads.
+  /// \note The caller of this function has to be the main thread of the Op, since it's the only entity responsible to
+  ///     push rows to worker_in_queues_
+  /// \return Status The status code returned
+  Status RemoveWorkers(int32_t num_workers = 1) override {
+    // wait for workers to process the current rows
+    RETURN_IF_NOT_OK(WaitForWorkers());
+    for (int32_t i = 0; i < num_workers; i++) {
+      RETURN_IF_NOT_OK(SendQuitFlagToWorker(num_workers_ - 1));
+      RETURN_IF_NOT_OK(worker_tasks_[num_workers_ - 1]->Join());
+      RETURN_IF_NOT_OK(worker_in_queues_.RemoveLastQueue());
+      worker_tasks_.pop_back();
+      num_workers_--;
+      MS_LOG(INFO) << "Worker ID " << num_workers_ << " has been removed from operator: " << NameWithID()
+                   << " num_workers=" << num_workers_;
+    }
+    return Status::OK();
+  }
 
   // Wait post used to perform the pausing logic
   WaitPost wait_for_workers_post_;
@@ -175,14 +199,22 @@ class ParallelOp : public DatasetOp {
   /// The number of worker threads
   int32_t num_workers_;
 
+  std::vector<Task *> worker_tasks_;
+
   int32_t NextWorkerID() {
     int32_t next_worker = next_worker_id_;
     next_worker_id_ = (next_worker_id_ + 1) % num_workers_;
     return next_worker;
   }
 
+ public:
+  int32_t NumWorkers() override { return num_workers_; }
+
+ protected:
   std::atomic_int next_worker_id_;
 
+  std::map quit_ack_;
+
   /// The size of input/output worker queues
   int32_t worker_connector_size_;
   /// queues to hold the input rows to workers
diff --git a/mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc b/mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc
index 61eaed9d3ae..c16ea15b323 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc
@@ -218,10 +218,10 @@ ExecutionTree::Iterator::Iterator(const std::shared_ptr<DatasetOp> &root) : ind_
   (void)nodes_.emplace_back(nullptr);
 }
 
-// Given the number of workers, launches the worker entry function for each. Essentially a
+// Given the number of workers, launch the worker entry function for each worker. This is essentially a
 // wrapper for the TaskGroup handling that is stored inside the execution tree.
-Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func, std::string name,
-                                    int32_t operator_id) {
+Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func,
+                                    std::vector<Task *> *worker_tasks, std::string name, int32_t operator_id) {
   int32_t num_cpu_threads = GlobalContext::Instance()->config_manager()->num_cpu_threads();
   // this performs a check that num_workers is positive and not unreasonably large, which could happen
   // for example with an un-initialized variable. uint16 max is 65535, which is large enough to cover everything
@@ -232,12 +232,24 @@ Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func,
+  worker_tasks->resize(num_workers);
   for (int32_t i = 0; i < num_workers; ++i) {
-    RETURN_IF_NOT_OK(tg_->CreateAsyncTask(name, std::bind(func, i), nullptr, operator_id));
+    Task *task = nullptr;
+    RETURN_IF_NOT_OK(tg_->CreateAsyncTask(name, std::bind(func, i), &task, operator_id));
+    CHECK_FAIL_RETURN_UNEXPECTED(task != nullptr, "Failed to create a new worker");
+    (*worker_tasks)[i] = task;
   }
 
   return Status::OK();
 }
 
+// Overload that launches the worker entry function for each worker without returning the created tasks. Essentially a
+// wrapper for the TaskGroup handling that is stored inside the execution tree.
+Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func, std::string name,
+                                    int32_t operator_id) {
+  std::vector<Task *> tasks;
+  return LaunchWorkers(num_workers, func, &tasks, name, operator_id);
+}
+
 // Walks the tree to perform modifications to the tree in post-order to get it ready for execution.
 Status ExecutionTree::Prepare() {
   if (root_ == nullptr) {
diff --git a/mindspore/ccsrc/minddata/dataset/engine/execution_tree.h b/mindspore/ccsrc/minddata/dataset/engine/execution_tree.h
index bde233f1369..fac835e74c5 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/execution_tree.h
+++ b/mindspore/ccsrc/minddata/dataset/engine/execution_tree.h
@@ -149,12 +149,14 @@ class ExecutionTree {
   /// wrapper for the TaskGroup handling that is stored inside the execution tree.
   /// \param num_workers - The number of workers to launch
   /// \param func - The function entry point that workers will execute
+  /// \param[out] worker_tasks - output vector to hold the generated tasks
   /// \param name - The description of worker to launch
   /// \param op_id - The id of the corresponding operator; if not inherited from a dataset op then it is -1.
   /// \return Status The status code returned
+  Status LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func, std::vector<Task *> *worker_tasks,
+                       std::string name = "", int32_t operator_id = -1);
   Status LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func, std::string name = "",
                        int32_t operator_id = -1);
-
   /// \brief Getter method
   /// \return shared_ptr to the root operator
   std::shared_ptr<DatasetOp> root() const { return root_; }
diff --git a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h
index b82cd08fc22..823912cd602 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h
+++ b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h
@@ -30,11 +30,13 @@ namespace mindspore {
 namespace dataset {
 
 class DatasetNode;
+class TreeModifier;
 
 class TreeAdapter {
 #ifndef ENABLE_SECURITY
   friend ProfilingManager;
 #endif
+  friend TreeModifier;
 
  public:
   // this flag is used to indicate the purpose of the creation of this tree adapter (type of the tree_consumer).
@@ -90,7 +92,7 @@ class TreeAdapter {
   ProfilingManager *GetProfilingManager() { return profiling_manager_.get(); }
 #endif
 
- private:
+ protected:
   // Run the mandatory pass checking the syntax and semantics of the IR tree
   Status PrePass(std::shared_ptr<DatasetNode> ir);
diff --git a/mindspore/ccsrc/minddata/dataset/engine/tree_modifier.cc b/mindspore/ccsrc/minddata/dataset/engine/tree_modifier.cc
new file mode 100644
index 00000000000..91a8de65ae4
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/tree_modifier.cc
@@ -0,0 +1,69 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "minddata/dataset/engine/tree_modifier.h" + +namespace mindspore { +namespace dataset { + +Status AutotuneCallback::DSNStepBegin(const CallbackParam &cb_param) { + // check if the queue is empty, no need to wait until a change request is ready + if (!change_request_queue_->empty()) { + ChangeRequestPtr change_request; + RETURN_IF_NOT_OK(change_request_queue_->PopFront(&change_request)); + RETURN_IF_NOT_OK(change_request->ApplyChange(op_)); + } + return Status::OK(); +} + +Status AutotuneCallback::DSBegin(const CallbackParam &cb_param) { return Status::OK(); } + +Status AutotuneCallback::DSEpochBegin(const CallbackParam &cb_param) { return Status::OK(); } + +Status AutotuneCallback::DSEnd(const CallbackParam &cb_param) { return Status::OK(); } + +Status AutotuneCallback::DSEpochEnd(const CallbackParam &cb_param) { return Status::OK(); } + +Status AutotuneCallback::DSNStepEnd(const CallbackParam &cb_param) { return Status::OK(); } + +bool AutotuneCallback::IsBeginNeeded() { return false; } + +bool AutotuneCallback::IsEpochBeginNeeded() { return false; } + +bool AutotuneCallback::IsNStepBeginNeeded() { return true; } + +bool AutotuneCallback::IsEndNeeded() { return false; } + +bool AutotuneCallback::IsEpochEndNeeded() { return false; } + +bool AutotuneCallback::IsNStepEndNeeded() { return false; } + +Status AutotuneCallback::PushChangeRequest(ChangeRequestPtr change_request) { + RETURN_IF_NOT_OK(change_request_queue_->Add(change_request)); + return Status::OK(); +} + +Status ChangeNumWorkersRequest::ApplyChange(DatasetOp *op) { + int32_t diff = num_workers_ - op->NumWorkers(); + if (diff > 0) { + RETURN_IF_NOT_OK(op->AddNewWorkers(diff)); + } else if (diff < 0) { + RETURN_IF_NOT_OK(op->RemoveWorkers(-1 * diff)); + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/tree_modifier.h b/mindspore/ccsrc/minddata/dataset/engine/tree_modifier.h new file mode 100644 index 00000000000..29e73d04545 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/tree_modifier.h @@ -0,0 +1,154 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_TREE_MODIFIER_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_TREE_MODIFIER_H_ + +#include +#include +#include +#include +#include +#include + +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/engine/tree_adapter.h" + +namespace mindspore { +namespace dataset { +class DatasetNode; + +/// A pure virtual class to be used as a base for all pipeline modification requests. +class ChangeRequest { + public: + /// Default constructor + ChangeRequest() = default; + + /// Pure virtual method. Subclasses should override this function and implement the actual change to the give + /// operator. 
+  /// \param op pointer to the operator that the change will be applied to
+  /// \return Status code
+  virtual Status ApplyChange(DatasetOp *op) = 0;
+};
+
+using ChangeRequestPtr = std::shared_ptr<ChangeRequest>;
+
+/// ChangeRequest to change the number of workers of an operator.
+class ChangeNumWorkersRequest : public ChangeRequest {
+ public:
+  /// Constructor
+  /// \param num_workers target number of workers for the operator. Defaults to 1.
+  explicit ChangeNumWorkersRequest(int32_t num_workers = 1) : num_workers_(num_workers) {}
+  virtual ~ChangeNumWorkersRequest() = default;
+
+  /// Actual change to adjust the number of workers
+  /// \param op pointer to the operator that the change will be applied to
+  /// \return Status code
+  Status ApplyChange(DatasetOp *op) override;
+
+ private:
+  int32_t num_workers_;
+};
+
+/// ChangeRequest to change the size of the output connector of an operator.
+class ResizeConnectorRequest : public ChangeRequest {
+ public:
+  /// Constructor
+  /// \param new_size new queue size.
+  explicit ResizeConnectorRequest(int32_t new_size) : new_size_(new_size) {}
+  virtual ~ResizeConnectorRequest() = default;
+
+  /// Actual change to resize the output connector of the given operator
+  /// \param op pointer to the operator that the change will be applied to
+  /// \return Status code
+  Status ApplyChange(DatasetOp *op) override {
+    RETURN_IF_NOT_OK(op->OutputConnector()->Resize(new_size_));
+    return Status::OK();
+  }
+
+ private:
+  int32_t new_size_;
+};
+
+/// A callback class used by AutoTune to queue changes for operators
+class AutotuneCallback : public DSCallback {
+ public:
+  AutotuneCallback(int32_t step_size, DatasetOp *op)
+      : DSCallback(step_size), op_(op), change_request_queue_(std::make_unique<Queue<ChangeRequestPtr>>(10)) {}
+  virtual ~AutotuneCallback() = default;
+
+  Status DSNStepBegin(const CallbackParam &cb_param) override;
+  Status DSBegin(const CallbackParam &cb_param) override;
+  Status DSEpochBegin(const CallbackParam &cb_param) override;
+  Status DSEnd(const CallbackParam &cb_param) override;
+  Status DSEpochEnd(const CallbackParam &cb_param) override;
+  Status DSNStepEnd(const CallbackParam &cb_param) override;
+
+  bool IsBeginNeeded() override;
+  bool IsEpochBeginNeeded() override;
+  bool IsNStepBeginNeeded() override;
+  bool IsEndNeeded() override;
+  bool IsEpochEndNeeded() override;
+  bool IsNStepEndNeeded() override;
+
+  /// Push a change request to the queue of the callback.
+  /// \param change_request Shared pointer to the change request to be pushed to the queue.
+  /// \return Status code
+  Status PushChangeRequest(ChangeRequestPtr change_request);
+
+ private:
+  DatasetOp *op_;
+  std::unique_ptr<Queue<ChangeRequestPtr>> change_request_queue_;
+};
+
+/// Main class to handle modification of the ExecutionTree used by AutoTune
+class TreeModifier {
+  // friend of TreeAdapter to access the ExecutionTree
+  friend TreeAdapter;
+
+ public:
+  /// Constructor to create a TreeModifier given a TreeAdapter
+  /// \param adapter TreeAdapter
+  explicit TreeModifier(TreeAdapter *adapter) : TreeModifier(adapter->tree_.get()) {}
+
+  /// Constructor to create a TreeModifier given an ExecutionTree
+  /// \param tree ExecutionTree
+  explicit TreeModifier(ExecutionTree *tree) : tree_(tree) {
+    // Loop over all ops to create an AutotuneCallback and register it with each op.
+    for (auto itr = tree_->begin(); itr != tree_->end(); ++itr) {
+      auto cb = std::make_shared<AutotuneCallback>(1, itr.get().get());
+      itr->AddCallbacks({cb});
+      callbacks.insert(std::make_pair(itr->id(), cb));
+    }
+  }
+
+  /// Add a ChangeRequest to the callback associated with the op.
+  /// \param op_id Operator ID
+  /// \param change_request Pointer to the change request
+  /// \return Status code
+  Status AddChangeRequest(int32_t op_id, ChangeRequestPtr change_request) {
+    RETURN_IF_NOT_OK(callbacks[op_id]->PushChangeRequest(change_request));
+    return Status::OK();
+  }
+
+ private:
+  ExecutionTree *tree_;
+  std::map<int32_t, std::shared_ptr<AutotuneCallback>> callbacks;
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_TREE_MODIFIER_H_
diff --git a/mindspore/ccsrc/minddata/dataset/util/queue.h b/mindspore/ccsrc/minddata/dataset/util/queue.h
index 94dddec4aba..a0d9d31e4f0 100644
--- a/mindspore/ccsrc/minddata/dataset/util/queue.h
+++ b/mindspore/ccsrc/minddata/dataset/util/queue.h
@@ -271,6 +271,11 @@ class QueueList {
     queue_list_.emplace_back(std::make_unique<Queue<T>>(queue_list_[0]->capacity()));
     return queue_list_[queue_list_.size() - 1]->Register(vg);
   }
+  Status RemoveLastQueue() {
+    CHECK_FAIL_RETURN_UNEXPECTED(queue_list_.size() > 1, "Cannot remove the last queue; at least one queue must remain.");
+    queue_list_.pop_back();
+    return Status::OK();
+  }
 
  private:
   // Queue contains non-copyable objects, so it cannot be added to a vector due to the vector
diff --git a/tests/ut/cpp/dataset/ir_tree_adapter_test.cc b/tests/ut/cpp/dataset/ir_tree_adapter_test.cc
index af89a186f06..9edd5aaf530 100644
--- a/tests/ut/cpp/dataset/ir_tree_adapter_test.cc
+++ b/tests/ut/cpp/dataset/ir_tree_adapter_test.cc
@@ -31,6 +31,8 @@
 #include "minddata/dataset/engine/ir/datasetops/skip_node.h"
 #include "minddata/dataset/engine/ir/datasetops/zip_node.h"
 
+#include "minddata/dataset/engine/tree_modifier.h"
+
 using namespace mindspore::dataset;
 using mindspore::dataset::Tensor;
 
@@ -145,3 +147,54 @@ TEST_F(MindDataTestTreeAdapter, TestProjectMapTreeAdapter) {
   const std::string err_msg = rc.ToString();
   EXPECT_TRUE(err_msg.find("EOF buffer encountered.") != err_msg.npos);
 }
+
+// Feature: Basic test for TreeModifier
+// Description: Create a simple tree and modify it by adding workers, changing the queue size, and removing workers
+// Expectation: No failures.
+TEST_F(MindDataTestTreeAdapter, TestSimpleTreeModifier) {
+  MS_LOG(INFO) << "Doing MindDataTestTreeAdapter-TestSimpleTreeModifier.";
+
+  // Create a CSVDataset, with a single CSV file
+  std::string train_file = datasets_root_path_ + "/testCSV/1.csv";
+  std::vector<std::string> column_names = {"col1", "col2", "col3", "col4"};
+  std::shared_ptr<Dataset> ds = CSV({train_file}, ',', {}, column_names, 0, ShuffleMode::kFalse);
+  ASSERT_NE(ds, nullptr);
+
+  ds = ds->Project({"col1"});
+  ASSERT_NE(ds, nullptr);
+
+  ds = ds->Repeat(2);
+  ASSERT_NE(ds, nullptr);
+
+  auto to_number = std::make_shared<transforms::TypeCast>(mindspore::DataType::kNumberTypeInt32);
+  ASSERT_NE(to_number, nullptr);
+
+  ds = ds->Map({to_number}, {"col1"}, {"col1"});
+  ds->SetNumWorkers(1);
+
+  auto tree_adapter = std::make_shared<TreeAdapter>();
+  // Disable IR optimization pass
+  tree_adapter->SetOptimize(false);
+  ASSERT_OK(tree_adapter->Compile(ds->IRNode(), 1));
+
+  auto tree_modifier = std::make_unique<TreeModifier>(tree_adapter.get());
+  tree_modifier->AddChangeRequest(0, std::make_shared<ChangeNumWorkersRequest>(2));
+  tree_modifier->AddChangeRequest(0, std::make_shared<ResizeConnectorRequest>(20));
+  tree_modifier->AddChangeRequest(0, std::make_shared<ChangeNumWorkersRequest>());
+  tree_modifier->AddChangeRequest(0, std::make_shared<ResizeConnectorRequest>(100));
+  tree_modifier->AddChangeRequest(0, std::make_shared<ChangeNumWorkersRequest>(10));
+
+  std::vector<int32_t> expected_result = {1, 5, 9, 1, 5, 9};
+  TensorRow row;
+
+  uint64_t i = 0;
+  ASSERT_OK(tree_adapter->GetNext(&row));
+
+  while (row.size() != 0) {
+    auto tensor = row[0];
+    int32_t num;
+    ASSERT_OK(tensor->GetItemAt<int32_t>(&num, {}));
+    EXPECT_EQ(num, expected_result[i]);
+    ASSERT_OK(tree_adapter->GetNext(&row));
+    i++;
+  }
+
+  // Expect 6 samples
+  EXPECT_EQ(i, 6);
+}
\ No newline at end of file
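---

Note (not part of the patch): the core pattern this change introduces is a deferred change-request queue. A tuner thread pushes ChangeRequest objects onto a per-operator queue, and the pipeline pops at most one request at each step boundary (AutotuneCallback::DSNStepBegin), so operator state is only mutated while no row is in flight. The following self-contained sketch models that protocol in isolation; FakeOp, SetWorkers, ResizeQueue, and RequestQueue are illustrative stand-ins, not MindSpore APIs.

#include <cstdio>
#include <deque>
#include <memory>
#include <mutex>

// Stand-in for DatasetOp: only the two tunable knobs this patch touches.
class FakeOp {
 public:
  int num_workers() const { return num_workers_; }
  void set_num_workers(int n) { num_workers_ = n; }
  int queue_size() const { return queue_size_; }
  void set_queue_size(int n) { queue_size_ = n; }

 private:
  int num_workers_ = 1;
  int queue_size_ = 16;
};

// Base request type; mirrors ChangeRequest::ApplyChange(DatasetOp *).
class ChangeRequest {
 public:
  virtual ~ChangeRequest() = default;
  virtual void Apply(FakeOp *op) = 0;
};

// Analogue of ChangeNumWorkersRequest: carries a *target* worker count,
// not a delta (ApplyChange in the patch computes the delta itself).
class SetWorkers : public ChangeRequest {
 public:
  explicit SetWorkers(int n) : n_(n) {}
  void Apply(FakeOp *op) override { op->set_num_workers(n_); }

 private:
  int n_;
};

// Analogue of ResizeConnectorRequest.
class ResizeQueue : public ChangeRequest {
 public:
  explicit ResizeQueue(int n) : n_(n) {}
  void Apply(FakeOp *op) override { op->set_queue_size(n_); }

 private:
  int n_;
};

// Thread-safe request queue. ApplyOne() pops at most one request without
// blocking, matching the empty() check in AutotuneCallback::DSNStepBegin.
class RequestQueue {
 public:
  void Push(std::unique_ptr<ChangeRequest> r) {
    std::lock_guard<std::mutex> lk(mu_);
    q_.push_back(std::move(r));
  }
  void ApplyOne(FakeOp *op) {
    std::unique_ptr<ChangeRequest> r;
    {
      std::lock_guard<std::mutex> lk(mu_);
      if (q_.empty()) return;
      r = std::move(q_.front());
      q_.pop_front();
    }
    r->Apply(op);  // applied at a step boundary, with no rows in flight
  }

 private:
  std::mutex mu_;
  std::deque<std::unique_ptr<ChangeRequest>> q_;
};

int main() {
  FakeOp op;
  RequestQueue rq;
  // A tuner thread would push these; here we push inline.
  rq.Push(std::make_unique<SetWorkers>(4));
  rq.Push(std::make_unique<ResizeQueue>(64));
  // The pipeline drains one request per "step".
  for (int step = 0; step < 3; ++step) {
    rq.ApplyOne(&op);
    std::printf("step %d: workers=%d queue=%d\n", step, op.num_workers(), op.queue_size());
  }
  return 0;
}

Applying at most one request per step keeps each mutation cheap and bounds the latency any single step pays for retuning, which is why DSNStepBegin checks for a pending request instead of blocking on the queue.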