forked from mindspore-Ecosystem/mindspore
!25524 [Autotune] Add TreeModifier and Remove Worker logic
Merge pull request !25524 from h.farahat/tree_modifier
This commit is contained in:
commit
eff2318bb6
|
@ -27,6 +27,7 @@ set(SRC_FILES_LIST
|
|||
consumers/pull_based_tree_consumer.cc
|
||||
consumers/tree_consumer.cc
|
||||
serdes.cc
|
||||
tree_modifier.cc
|
||||
)
|
||||
if(ENABLE_PYTHON)
|
||||
set(SRC_FILES_LIST
|
||||
|
|
|
@ -116,6 +116,10 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
|
|||
return Status(StatusCode::kMDUnexpectedError, "Add new workers is not supported for non-ParallelOps");
|
||||
}
|
||||
|
||||
virtual Status RemoveWorkers(int32_t num_workers = 1) {
|
||||
return Status(StatusCode::kMDUnexpectedError, "Remove workers is not supported for non-ParallelOps");
|
||||
}
|
||||
|
||||
// \brief Inserts a operator as the parent current op.
|
||||
// \notes Inserted op will become the sole parent of the current op.
|
||||
// The existing parent of the current op will be transferred to the inserted op.
|
||||
|
@ -239,6 +243,8 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
|
|||
// \return - the column name map as a string
|
||||
std::string ColumnNameMapAsString() const;
|
||||
|
||||
OperatorConnector *OutputConnector() const { return out_connector_.get(); }
|
||||
|
||||
// \brief Getter function
|
||||
// \return connector size of current op
|
||||
int32_t ConnectorSize() const {
|
||||
|
@ -328,6 +334,10 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
|
|||
// \return Status
|
||||
virtual Status WaitForWorkers() { return Status::OK(); }
|
||||
|
||||
virtual int32_t NumWorkers() { return 0; }
|
||||
|
||||
virtual Status SendQuitFlagToWorker(int32_t worker_id) { return Status::OK(); }
|
||||
|
||||
// \brief Add callback to DatasetOp, only MapOp supports Callback at the moment
|
||||
void AddCallbacks(std::vector<std::shared_ptr<DSCallback>> callbacks) { callback_manager_.AddCallbacks(callbacks); }
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/engine/datasetops/map_op/map_op.h"
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
@ -161,9 +162,7 @@ Status MapOp::operator()() {
|
|||
|
||||
// Quit all workers, this code might never be reached if EpochCtrl is -1.
|
||||
for (int32_t wkr_id = 0; wkr_id < num_workers_; wkr_id++) {
|
||||
TensorRow quit_flag(TensorRow::kFlagQuit);
|
||||
auto quit = std::make_unique<MapWorkerJob>(quit_flag);
|
||||
RETURN_IF_NOT_OK(worker_in_queues_[NextWorkerID()]->Add(std::move(quit)));
|
||||
RETURN_IF_NOT_OK(SendQuitFlagToWorker(NextWorkerID()));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
@ -387,5 +386,11 @@ Status MapOp::WaitForWorkers() {
|
|||
wait_for_workers_post_.Clear();
|
||||
return Status::OK();
|
||||
}
|
||||
Status MapOp::SendQuitFlagToWorker(int32_t worker_id) {
|
||||
TensorRow quit_flag(TensorRow::kFlagQuit);
|
||||
auto quit = std::make_unique<MapWorkerJob>(quit_flag);
|
||||
RETURN_IF_NOT_OK(worker_in_queues_[worker_id]->Add(std::move(quit)));
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -180,6 +180,11 @@ class MapOp : public ParallelOp<std::unique_ptr<MapWorkerJob>, TensorRow> {
|
|||
// who does the increment wakes up the master.
|
||||
// @return - Status
|
||||
Status WaitForWorkers() override;
|
||||
|
||||
/// Send quit flag row to worker at worker_id to make it exit
|
||||
/// \param worker_id id of the worker
|
||||
/// \return Status code
|
||||
Status SendQuitFlagToWorker(int32_t worker_id) override;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_PARALLEL_OP_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
@ -98,8 +99,9 @@ class ParallelOp : public DatasetOp {
|
|||
RETURN_IF_NOT_OK(worker_out_queues_.Register(tree_->AllTasks()));
|
||||
RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks()));
|
||||
|
||||
RETURN_IF_NOT_OK(tree_->LaunchWorkers(
|
||||
num_workers_, std::bind(&ParallelOp::WorkerEntry, this, std::placeholders::_1), Name() + "::WorkerEntry", id()));
|
||||
RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_,
|
||||
std::bind(&ParallelOp::WorkerEntry, this, std::placeholders::_1),
|
||||
&worker_tasks_, Name() + "::WorkerEntry", id()));
|
||||
RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&ParallelOp::Collector, this), Name() + "::Collector", id()));
|
||||
|
||||
return Status::OK();
|
||||
|
@ -154,8 +156,11 @@ class ParallelOp : public DatasetOp {
|
|||
for (int32_t i = 0; i < num_new_workers; i++) {
|
||||
worker_in_queues_.AddQueue(tree_->AllTasks());
|
||||
worker_out_queues_.AddQueue(tree_->AllTasks());
|
||||
Task *new_task;
|
||||
RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask(
|
||||
Name() + "::WorkerEntry", std::bind(&ParallelOp::WorkerEntry, this, num_workers_), nullptr, id()));
|
||||
Name() + "::WorkerEntry", std::bind(&ParallelOp::WorkerEntry, this, num_workers_), &new_task, id()));
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(new_task != nullptr, "Cannot create a new worker.");
|
||||
worker_tasks_.push_back(new_task);
|
||||
num_workers_++;
|
||||
MS_LOG(INFO) << "A new worker has been added to op: " << Name() << "::" << id()
|
||||
<< " num_workers=" << num_workers_;
|
||||
|
@ -163,6 +168,25 @@ class ParallelOp : public DatasetOp {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
/// Add a new worker to the parallelOp. The function will have to wait for all workers to process current rows.
|
||||
/// Then it adds a new thread to the list.
|
||||
/// \note The caller of this function has to be the main thread of the Op, since it's the only entity responsible to
|
||||
/// push rows to workers_in_queue
|
||||
/// \return Status The status code returned
|
||||
Status RemoveWorkers(int32_t num_workers = 1) override {
|
||||
// wait for workers to process the current rows
|
||||
RETURN_IF_NOT_OK(WaitForWorkers());
|
||||
for (int32_t i = 0; i < num_workers; i++) {
|
||||
RETURN_IF_NOT_OK(SendQuitFlagToWorker(num_workers_ - 1));
|
||||
RETURN_IF_NOT_OK(worker_tasks_[num_workers_ - 1]->Join());
|
||||
RETURN_IF_NOT_OK(worker_in_queues_.RemoveLastQueue());
|
||||
worker_tasks_.pop_back();
|
||||
num_workers_--;
|
||||
MS_LOG(INFO) << "Worker ID " << num_workers_ << " is requested to be removed in operator: " << NameWithID()
|
||||
<< " num_workers=" << num_workers_;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
// Wait post used to perform the pausing logic
|
||||
WaitPost wait_for_workers_post_;
|
||||
|
||||
|
@ -175,14 +199,22 @@ class ParallelOp : public DatasetOp {
|
|||
/// The number of worker threads
|
||||
int32_t num_workers_;
|
||||
|
||||
std::vector<Task *> worker_tasks_;
|
||||
|
||||
int32_t NextWorkerID() {
|
||||
int32_t next_worker = next_worker_id_;
|
||||
next_worker_id_ = (next_worker_id_ + 1) % num_workers_;
|
||||
return next_worker;
|
||||
}
|
||||
|
||||
public:
|
||||
int32_t NumWorkers() override { return num_workers_; }
|
||||
|
||||
protected:
|
||||
std::atomic_int next_worker_id_;
|
||||
|
||||
std::map<int32_t, std::atomic_bool> quit_ack_;
|
||||
|
||||
/// The size of input/output worker queeus
|
||||
int32_t worker_connector_size_;
|
||||
/// queues to hold the input rows to workers
|
||||
|
|
|
@ -218,10 +218,10 @@ ExecutionTree::Iterator::Iterator(const std::shared_ptr<DatasetOp> &root) : ind_
|
|||
(void)nodes_.emplace_back(nullptr);
|
||||
}
|
||||
|
||||
// Given the number of workers, launches the worker entry function for each. Essentially a
|
||||
// Given the number of workers, launch the worker entry function for each worker. This is essentially a
|
||||
// wrapper for the TaskGroup handling that is stored inside the execution tree.
|
||||
Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func, std::string name,
|
||||
int32_t operator_id) {
|
||||
Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func,
|
||||
std::vector<Task *> *worker_tasks, std::string name, int32_t operator_id) {
|
||||
int32_t num_cpu_threads = GlobalContext::Instance()->config_manager()->num_cpu_threads();
|
||||
// this performs check that num_workers is positive and not unreasonably large which could happen
|
||||
// for example, un-initialized variable. uint16 max is 65536 which is large enough to cover everything
|
||||
|
@ -232,12 +232,24 @@ Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(ui
|
|||
MS_LOG(WARNING) << name + " is launched with " << std::to_string(num_workers) << " worker threads which exceeds "
|
||||
<< std::to_string(num_cpu_threads) << ", the maximum number of threads on this CPU.";
|
||||
}
|
||||
worker_tasks->resize(num_workers);
|
||||
for (int32_t i = 0; i < num_workers; ++i) {
|
||||
RETURN_IF_NOT_OK(tg_->CreateAsyncTask(name, std::bind(func, i), nullptr, operator_id));
|
||||
Task *task = nullptr;
|
||||
RETURN_IF_NOT_OK(tg_->CreateAsyncTask(name, std::bind(func, i), &task, operator_id));
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(task != nullptr, "Failed to create a new worker");
|
||||
(*worker_tasks)[i] = task;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Given the number of workers, launches the worker entry function for each. Essentially a
|
||||
// wrapper for the TaskGroup handling that is stored inside the execution tree.
|
||||
Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func, std::string name,
|
||||
int32_t operator_id) {
|
||||
std::vector<Task *> tasks;
|
||||
return LaunchWorkers(num_workers, func, &tasks, name, operator_id);
|
||||
}
|
||||
|
||||
// Walks the tree to perform modifications to the tree in post-order to get it ready for execution.
|
||||
Status ExecutionTree::Prepare() {
|
||||
if (root_ == nullptr) {
|
||||
|
|
|
@ -149,12 +149,14 @@ class ExecutionTree {
|
|||
/// wrapper for the TaskGroup handling that is stored inside the execution tree.
|
||||
/// \param num_workers - The number of workers to launch
|
||||
/// \param func - The function entry point that workers will execute
|
||||
/// \param[out] worker_tasks - output vector to hold generated tasks
|
||||
/// \param name - The description of worker to launch
|
||||
/// \param op_id - The id of corresponding operator, if not inherit from dataset op then it is -1.
|
||||
/// \return Status The status code returned
|
||||
Status LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func, std::vector<Task *> *worker_tasks,
|
||||
std::string name = "", int32_t operator_id = -1);
|
||||
Status LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func, std::string name = "",
|
||||
int32_t operator_id = -1);
|
||||
|
||||
/// \brief Getter method
|
||||
/// \return shared_ptr to the root operator
|
||||
std::shared_ptr<DatasetOp> root() const { return root_; }
|
||||
|
|
|
@ -30,11 +30,13 @@
|
|||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class DatasetNode;
|
||||
class TreeModifier;
|
||||
|
||||
class TreeAdapter {
|
||||
#ifndef ENABLE_SECURITY
|
||||
friend ProfilingManager;
|
||||
#endif
|
||||
friend TreeModifier;
|
||||
|
||||
public:
|
||||
// this flag is used to indicate the purpose of the creation of this tree adapter (type of the tree_consumer).
|
||||
|
@ -90,7 +92,7 @@ class TreeAdapter {
|
|||
ProfilingManager *GetProfilingManager() { return profiling_manager_.get(); }
|
||||
#endif
|
||||
|
||||
private:
|
||||
protected:
|
||||
// Run the mandatory pass checking the syntax and semantics of the IR tree
|
||||
Status PrePass(std::shared_ptr<DatasetNode> ir);
|
||||
|
||||
|
|
|
@ -0,0 +1,69 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/tree_modifier.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
Status AutotuneCallback::DSNStepBegin(const CallbackParam &cb_param) {
|
||||
// check if the queue is empty, no need to wait until a change request is ready
|
||||
if (!change_request_queue_->empty()) {
|
||||
ChangeRequestPtr change_request;
|
||||
RETURN_IF_NOT_OK(change_request_queue_->PopFront(&change_request));
|
||||
RETURN_IF_NOT_OK(change_request->ApplyChange(op_));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status AutotuneCallback::DSBegin(const CallbackParam &cb_param) { return Status::OK(); }
|
||||
|
||||
Status AutotuneCallback::DSEpochBegin(const CallbackParam &cb_param) { return Status::OK(); }
|
||||
|
||||
Status AutotuneCallback::DSEnd(const CallbackParam &cb_param) { return Status::OK(); }
|
||||
|
||||
Status AutotuneCallback::DSEpochEnd(const CallbackParam &cb_param) { return Status::OK(); }
|
||||
|
||||
Status AutotuneCallback::DSNStepEnd(const CallbackParam &cb_param) { return Status::OK(); }
|
||||
|
||||
bool AutotuneCallback::IsBeginNeeded() { return false; }
|
||||
|
||||
bool AutotuneCallback::IsEpochBeginNeeded() { return false; }
|
||||
|
||||
bool AutotuneCallback::IsNStepBeginNeeded() { return true; }
|
||||
|
||||
bool AutotuneCallback::IsEndNeeded() { return false; }
|
||||
|
||||
bool AutotuneCallback::IsEpochEndNeeded() { return false; }
|
||||
|
||||
bool AutotuneCallback::IsNStepEndNeeded() { return false; }
|
||||
|
||||
Status AutotuneCallback::PushChangeRequest(ChangeRequestPtr change_request) {
|
||||
RETURN_IF_NOT_OK(change_request_queue_->Add(change_request));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ChangeNumWorkersRequest::ApplyChange(DatasetOp *op) {
|
||||
int32_t diff = num_workers_ - op->NumWorkers();
|
||||
if (diff > 0) {
|
||||
RETURN_IF_NOT_OK(op->AddNewWorkers(diff));
|
||||
} else if (diff < 0) {
|
||||
RETURN_IF_NOT_OK(op->RemoveWorkers(-1 * diff));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,154 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_TREE_MODIFIER_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_TREE_MODIFIER_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/execution_tree.h"
|
||||
#include "minddata/dataset/engine/tree_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class DatasetNode;
|
||||
|
||||
/// A pure virtual class to be used as a base for all pipeline modification requests.
|
||||
class ChangeRequest {
|
||||
public:
|
||||
/// Default constructor
|
||||
ChangeRequest() = default;
|
||||
|
||||
/// Pure virtual method. Subclasses should override this function and implement the actual change to the give
|
||||
/// operator.
|
||||
/// \param op pointer to the operator that the change will be applied on
|
||||
/// \return Status return Status code
|
||||
virtual Status ApplyChange(DatasetOp *op) = 0;
|
||||
};
|
||||
|
||||
using ChangeRequestPtr = std::shared_ptr<ChangeRequest>;
|
||||
|
||||
/// ChangeRequest to add n workers to an operator.
|
||||
class ChangeNumWorkersRequest : public ChangeRequest {
|
||||
public:
|
||||
/// Constructor
|
||||
/// \param num_workers number of workeres to be added to the opertor. Default to 1.
|
||||
explicit ChangeNumWorkersRequest(int32_t num_workers = 1) : num_workers_(num_workers) {}
|
||||
virtual ~ChangeNumWorkersRequest() = default;
|
||||
|
||||
/// Actual change to add n workers
|
||||
/// \param op pointer to the operator that the change will be applied on
|
||||
/// \return Status return Status code
|
||||
Status ApplyChange(DatasetOp *op) override;
|
||||
|
||||
private:
|
||||
int32_t num_workers_;
|
||||
};
|
||||
|
||||
/// ChangeRequest to change the size of the oupout connector of an operator.
|
||||
class ResizeConnectorRequest : public ChangeRequest {
|
||||
public:
|
||||
/// Constructor
|
||||
/// \param new_size new queue size.
|
||||
explicit ResizeConnectorRequest(int32_t new_size) : new_size_(new_size) {}
|
||||
virtual ~ResizeConnectorRequest() = default;
|
||||
|
||||
/// Actual change to resize the output connector of the given operator
|
||||
/// \param op pointer to the operator that the change will be applied on
|
||||
/// \return Status return Status code
|
||||
Status ApplyChange(DatasetOp *op) override {
|
||||
RETURN_IF_NOT_OK(op->OutputConnector()->Resize(new_size_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
int32_t new_size_;
|
||||
};
|
||||
|
||||
/// A callback class used by Aututune to queue changes for opertors
|
||||
class AutotuneCallback : public DSCallback {
|
||||
public:
|
||||
AutotuneCallback(int32_t step_size, DatasetOp *op)
|
||||
: DSCallback(step_size), op_(op), change_request_queue_(std::make_unique<Queue<ChangeRequestPtr>>(10)) {}
|
||||
virtual ~AutotuneCallback() = default;
|
||||
|
||||
Status DSNStepBegin(const CallbackParam &cb_param) override;
|
||||
Status DSBegin(const CallbackParam &cb_param) override;
|
||||
Status DSEpochBegin(const CallbackParam &cb_param) override;
|
||||
Status DSEnd(const CallbackParam &cb_param) override;
|
||||
Status DSEpochEnd(const CallbackParam &cb_param) override;
|
||||
Status DSNStepEnd(const CallbackParam &cb_param) override;
|
||||
|
||||
bool IsBeginNeeded() override;
|
||||
bool IsEpochBeginNeeded() override;
|
||||
bool IsNStepBeginNeeded() override;
|
||||
bool IsEndNeeded() override;
|
||||
bool IsEpochEndNeeded() override;
|
||||
bool IsNStepEndNeeded() override;
|
||||
|
||||
/// Push a change request to the queue of the callback.
|
||||
/// \param change_request Shared pointer to the change request to be pushed to the queue.
|
||||
/// \return Status return Status code
|
||||
Status PushChangeRequest(ChangeRequestPtr change_request);
|
||||
|
||||
private:
|
||||
DatasetOp *op_;
|
||||
std::unique_ptr<Queue<ChangeRequestPtr>> change_request_queue_;
|
||||
};
|
||||
|
||||
/// Main class to handle modification of the ExecutionTree used by AutoTune
|
||||
class TreeModifier {
|
||||
// friend with TreeAdapter to access the ExeecutionTree
|
||||
friend TreeAdapter;
|
||||
|
||||
public:
|
||||
/// Constructor to create a TreeModifier given a TreeAdapter
|
||||
/// \param adapter TreeAdapter
|
||||
explicit TreeModifier(TreeAdapter *adapter) : TreeModifier(adapter->tree_.get()) {}
|
||||
|
||||
/// Constructor to create a TreeModifier given an ExecutionTree
|
||||
/// \param tree ExecutionTree
|
||||
explicit TreeModifier(ExecutionTree *tree) : tree_(tree) {
|
||||
// loop over all ops to create AutotuneCallback and register it.
|
||||
for (auto itr = tree_->begin(); itr != tree_->end(); ++itr) {
|
||||
auto cb = std::make_shared<AutotuneCallback>(1, itr.get().get());
|
||||
itr->AddCallbacks({cb});
|
||||
callbacks.insert(std::make_pair(itr->id(), cb));
|
||||
}
|
||||
}
|
||||
|
||||
/// Add changeRequest to the callback associated with the op.
|
||||
/// \param op_id Operator ID
|
||||
/// \param change_request Pointer to the change request
|
||||
/// \return Status return Status code
|
||||
Status AddChangeRequest(int32_t op_id, ChangeRequestPtr change_request) {
|
||||
RETURN_IF_NOT_OK(callbacks[op_id]->PushChangeRequest(change_request));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
ExecutionTree *tree_;
|
||||
std::map<int32_t, std::shared_ptr<AutotuneCallback>> callbacks;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_TREE_MODIFIER_H_
|
|
@ -271,6 +271,11 @@ class QueueList {
|
|||
queue_list_.emplace_back(std::make_unique<Queue<T>>(queue_list_[0]->capacity()));
|
||||
return queue_list_[queue_list_.size() - 1]->Register(vg);
|
||||
}
|
||||
Status RemoveLastQueue() {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(queue_list_.size() > 1, "Cannot remove more than the current queues.");
|
||||
queue_list_.pop_back();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
// Queue contains non-copyable objects, so it cannot be added to a vector due to the vector
|
||||
|
|
|
@ -31,6 +31,8 @@
|
|||
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/zip_node.h"
|
||||
|
||||
#include "minddata/dataset/engine/tree_modifier.h"
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
using mindspore::dataset::Tensor;
|
||||
|
||||
|
@ -145,3 +147,54 @@ TEST_F(MindDataTestTreeAdapter, TestProjectMapTreeAdapter) {
|
|||
const std::string err_msg = rc.ToString();
|
||||
EXPECT_TRUE(err_msg.find("EOF buffer encountered.") != err_msg.npos);
|
||||
}
|
||||
|
||||
// Feature: Basic test for TreeModifier
|
||||
// Description: Create simple tree and modify the tree by adding workers, change queue size and then removing workers
|
||||
// Expectation: No failures.
|
||||
TEST_F(MindDataTestTreeAdapter, TestSimpleTreeModifier) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestTreeAdapter-TestSimpleTreeModifier.";
|
||||
|
||||
// Create a CSVDataset, with single CSV file
|
||||
std::string train_file = datasets_root_path_ + "/testCSV/1.csv";
|
||||
std::vector<std::string> column_names = {"col1", "col2", "col3", "col4"};
|
||||
std::shared_ptr<Dataset> ds = CSV({train_file}, ',', {}, column_names, 0, ShuffleMode::kFalse);
|
||||
ASSERT_NE(ds, nullptr);
|
||||
ds = ds->Project({"col1"});
|
||||
ASSERT_NE(ds, nullptr);
|
||||
ds = ds->Repeat(2);
|
||||
ASSERT_NE(ds, nullptr);
|
||||
auto to_number = std::make_shared<text::ToNumber>(mindspore::DataType::kNumberTypeInt32);
|
||||
ASSERT_NE(to_number, nullptr);
|
||||
ds = ds->Map({to_number}, {"col1"}, {"col1"});
|
||||
ds->SetNumWorkers(1);
|
||||
|
||||
auto tree_adapter = std::make_shared<TreeAdapter>();
|
||||
// Disable IR optimization pass
|
||||
tree_adapter->SetOptimize(false);
|
||||
ASSERT_OK(tree_adapter->Compile(ds->IRNode(), 1));
|
||||
|
||||
auto tree_modifier = std::make_unique<TreeModifier>(tree_adapter.get());
|
||||
tree_modifier->AddChangeRequest(0, std::make_shared<ChangeNumWorkersRequest>(2));
|
||||
tree_modifier->AddChangeRequest(0, std::make_shared<ResizeConnectorRequest>(20));
|
||||
tree_modifier->AddChangeRequest(0, std::make_shared<ChangeNumWorkersRequest>());
|
||||
tree_modifier->AddChangeRequest(0, std::make_shared<ResizeConnectorRequest>(100));
|
||||
tree_modifier->AddChangeRequest(0, std::make_shared<ChangeNumWorkersRequest>(10));
|
||||
|
||||
std::vector<int32_t> expected_result = {1, 5, 9, 1, 5, 9};
|
||||
TensorRow row;
|
||||
|
||||
uint64_t i = 0;
|
||||
ASSERT_OK(tree_adapter->GetNext(&row));
|
||||
|
||||
while (row.size() != 0) {
|
||||
auto tensor = row[0];
|
||||
int32_t num;
|
||||
ASSERT_OK(tensor->GetItemAt(&num, {}));
|
||||
EXPECT_EQ(num, expected_result[i]);
|
||||
ASSERT_OK(tree_adapter->GetNext(&row));
|
||||
i++;
|
||||
}
|
||||
|
||||
// Expect 6 samples
|
||||
EXPECT_EQ(i, 6);
|
||||
}
|
Loading…
Reference in New Issue