forked from mindspore-Ecosystem/mindspore
!5781 add executor group operation
Merge pull request !5781 from kisnwang/async-run-graph
This commit is contained in:
commit
6d9501d5ed
|
@ -16,6 +16,7 @@
|
|||
#include "backend/session/executor.h"
|
||||
#include "runtime/device/kernel_runtime_manager.h"
|
||||
#include "backend/session/executor_manager.h"
|
||||
#include "utils/comm_manager.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace session {
|
||||
|
@ -45,32 +46,6 @@ void UpdateOutputTensors(VectorRef *outputs,
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Recursively converts a BaseRef into a Python-tuple-shaped result.
//  - A tensor reference is returned unchanged.
//  - A VectorRef becomes a py::tuple whose slots hold either the tensor or the
//    nested tuple produced for the corresponding element.
//  - Anything else raises through MS_LOG(EXCEPTION).
BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref) {
  if (utils::isa<tensor::TensorPtr>(base_ref)) {
    return base_ref;
  }
  if (!utils::isa<VectorRef>(base_ref)) {
    MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
  }

  auto elements = utils::cast<VectorRef>(base_ref);
  const size_t count = elements.size();
  py::tuple output_tensors(count);
  for (size_t idx = 0; idx < count; ++idx) {
    // Each element is converted first; the result is a tensor or a PyObjectRef.
    auto item = TransformBaseRefListToTuple(elements[idx]);
    if (utils::isa<tensor::TensorPtr>(item)) {
      auto tensor_ptr = utils::cast<tensor::TensorPtr>(item);
      MS_EXCEPTION_IF_NULL(tensor_ptr);
      output_tensors[idx] = tensor_ptr;
      continue;
    }
    if (utils::isa<PyObjectRef>(item)) {
      py::object nested = utils::cast<PyObjectRef>(item).object_;
      py::tuple nested_tuple = py::cast<py::tuple>(nested);
      output_tensors[idx] = nested_tuple;
      continue;
    }
    MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
  }
  // The py::tuple is handed back as the BaseRef return value.
  return output_tensors;
}
|
||||
} // namespace
|
||||
void CompileNodesTask::Run() {
|
||||
MS_EXCEPTION_IF_NULL(session_);
|
||||
|
@ -104,6 +79,10 @@ void RunOpTask::Run() {
|
|||
session_->RunOp(*op_run_info_, graph_info_, input_tensors_, &outputs_);
|
||||
}
|
||||
|
||||
// Calls CommManager::CreateGroupSync with the stored group name and ranks and
// records the returned flag in result_.
void CreateCommGroupTask::Run() {
  result_ = CommManager::GetInstance().CreateGroupSync(group_name_, ranks_);
}
|
||||
|
||||
// Calls CommManager::DestroyGroup with the stored group name and records the
// returned flag in result_.
void DestroyCommGroupTask::Run() {
  result_ = CommManager::GetInstance().DestroyGroup(group_name_);
}
|
||||
|
||||
Executor::Executor(const std::string &device_name, uint32_t device_id) {
|
||||
device_name_ = device_name;
|
||||
device_id_ = device_id;
|
||||
|
@ -141,22 +120,8 @@ void Executor::WorkerLoop() {
|
|||
} catch (const std::exception &e) {
|
||||
exception_ptr_ = std::current_exception();
|
||||
}
|
||||
|
||||
auto task_type = task->type_;
|
||||
task = nullptr;
|
||||
if (task_type == kCompileNodes) {
|
||||
compile_cond_var_.notify_all();
|
||||
} else if (task_type == kCompileGraph) {
|
||||
compile_cond_var_.notify_all();
|
||||
} else if (task_type == kBuildGraph) {
|
||||
build_cond_var_.notify_all();
|
||||
} else if (task_type == kRunGraph) {
|
||||
run_cond_var_.notify_all();
|
||||
} else if (task_type == kBuildOp) {
|
||||
build_op_cond_var_.notify_all();
|
||||
} else if (task_type == kRunOp) {
|
||||
run_op_cond_var_.notify_all();
|
||||
}
|
||||
sync_cond_var_.notify_all();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -206,7 +171,7 @@ GraphId Executor::CompileGraphAsync(const SessionPtr &session, const AnfNodePtrL
|
|||
task->output_nodes_ = outputs;
|
||||
ready_tasks_.push(task);
|
||||
task_cond_var_.notify_all();
|
||||
compile_cond_var_.wait(lock);
|
||||
sync_cond_var_.wait(lock);
|
||||
CheckException();
|
||||
return task->graph_id_;
|
||||
}
|
||||
|
@ -219,7 +184,7 @@ GraphId Executor::CompileGraphAsync(const SessionPtr &session, NotNull<FuncGraph
|
|||
task->func_graph_ = func_graph;
|
||||
ready_tasks_.push(task);
|
||||
task_cond_var_.notify_all();
|
||||
compile_cond_var_.wait(lock);
|
||||
sync_cond_var_.wait(lock);
|
||||
CheckException();
|
||||
return task->graph_id_;
|
||||
}
|
||||
|
@ -232,7 +197,7 @@ void Executor::BuildGraphAsync(const SessionPtr &session, GraphId graphId) {
|
|||
task->graph_id_ = graphId;
|
||||
ready_tasks_.push(task);
|
||||
task_cond_var_.notify_all();
|
||||
build_cond_var_.wait(lock);
|
||||
sync_cond_var_.wait(lock);
|
||||
CheckException();
|
||||
}
|
||||
|
||||
|
@ -258,7 +223,7 @@ void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id,
|
|||
ready_tasks_.push(task);
|
||||
task_cond_var_.notify_all();
|
||||
py::gil_scoped_release release;
|
||||
run_cond_var_.wait(lock);
|
||||
sync_cond_var_.wait(lock);
|
||||
CheckException();
|
||||
}
|
||||
|
||||
|
@ -274,12 +239,12 @@ void Executor::BuildOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, c
|
|||
task->tensors_mask_ = tensors_mask;
|
||||
ready_tasks_.push(task);
|
||||
task_cond_var_.notify_all();
|
||||
build_op_cond_var_.wait(lock);
|
||||
sync_cond_var_.wait(lock);
|
||||
CheckException();
|
||||
}
|
||||
|
||||
py::tuple Executor::RunOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
|
||||
const std::vector<tensor::TensorPtr> &input_tensors) {
|
||||
void Executor::RunOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
|
||||
const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) {
|
||||
CheckException();
|
||||
std::unique_lock<std::mutex> lock(task_mutex_);
|
||||
auto task = std::make_shared<RunOpTask>();
|
||||
|
@ -289,18 +254,30 @@ py::tuple Executor::RunOpAsync(const SessionPtr &session, OpRunInfo *op_run_info
|
|||
task->input_tensors_ = input_tensors;
|
||||
ready_tasks_.push(task);
|
||||
task_cond_var_.notify_all();
|
||||
run_op_cond_var_.wait(lock);
|
||||
sync_cond_var_.wait(lock);
|
||||
CheckException();
|
||||
*outputs = task->outputs_;
|
||||
}
|
||||
|
||||
// Trans output to tuple
|
||||
auto output_tensors = TransformBaseRefListToTuple(task->outputs_);
|
||||
if (!utils::isa<PyObjectRef>(output_tensors) ||
|
||||
!py::isinstance<py::tuple>(utils::cast<PyObjectRef>(output_tensors).object_)) {
|
||||
MS_EXCEPTION(NotSupportError) << "The output tensors should be a tuple !";
|
||||
}
|
||||
py::object tuple_obj = utils::cast<PyObjectRef>(output_tensors).object_;
|
||||
py::tuple tuple_tensors = py::cast<py::tuple>(tuple_obj);
|
||||
return tuple_tensors;
|
||||
bool Executor::CreateCommGroup(const std::string &group_name, std::vector<uint32_t> ranks) {
|
||||
std::unique_lock<std::mutex> lock(task_mutex_);
|
||||
auto task = std::make_shared<CreateCommGroupTask>();
|
||||
task->group_name_ = group_name;
|
||||
task->ranks_ = ranks;
|
||||
ready_tasks_.push(task);
|
||||
task_cond_var_.notify_all();
|
||||
sync_cond_var_.wait(lock);
|
||||
return task->result_;
|
||||
}
|
||||
|
||||
bool Executor::DestroyCommGroup(const std::string &group_name) {
|
||||
std::unique_lock<std::mutex> lock(task_mutex_);
|
||||
auto task = std::make_shared<DestroyCommGroupTask>();
|
||||
task->group_name_ = group_name;
|
||||
ready_tasks_.push(task);
|
||||
task_cond_var_.notify_all();
|
||||
sync_cond_var_.wait(lock);
|
||||
return task->result_;
|
||||
}
|
||||
|
||||
void Executor::StopWorker() {
|
||||
|
|
|
@ -32,10 +32,22 @@
|
|||
#include "ir/tensor.h"
|
||||
#include "utils/any.h"
|
||||
#include "utils/contract.h"
|
||||
#include "utils/comm_manager.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace session {
|
||||
enum TaskType { kUnKnown, kExit, kCompileNodes, kCompileGraph, kBuildGraph, kBuildOp, kRunGraph, kRunOp };
|
||||
// Tag stored in Task::type_ identifying which kind of task an object is.
// Enumerator order (and therefore the implicit numeric values) is unchanged.
enum TaskType {
  kUnKnown,
  kExit,
  // Graph / single-op compilation and execution.
  kCompileNodes,
  kCompileGraph,
  kBuildGraph,
  kBuildOp,
  kRunGraph,
  kRunOp,
  // Communication group management.
  kCreateCommGroup,
  kDestroyCommGroup,
};
|
||||
|
||||
class Task {
|
||||
public:
|
||||
|
@ -106,6 +118,25 @@ class RunOpTask : public Task {
|
|||
VectorRef outputs_;
|
||||
};
|
||||
|
||||
class CreateCommGroupTask : public Task {
|
||||
public:
|
||||
CreateCommGroupTask() { type_ = kCreateCommGroup; }
|
||||
~CreateCommGroupTask() override = default;
|
||||
void Run() override;
|
||||
std::string group_name_;
|
||||
std::vector<uint32_t> ranks_;
|
||||
bool result_;
|
||||
};
|
||||
|
||||
class DestroyCommGroupTask : public Task {
|
||||
public:
|
||||
DestroyCommGroupTask() { type_ = kDestroyCommGroup; }
|
||||
~DestroyCommGroupTask() override = default;
|
||||
void Run() override;
|
||||
std::string group_name_;
|
||||
bool result_;
|
||||
};
|
||||
|
||||
class ExitTask : public Task {
|
||||
public:
|
||||
ExitTask() { type_ = kExit; }
|
||||
|
@ -125,9 +156,11 @@ class Executor {
|
|||
VectorRef *outputs);
|
||||
void BuildOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
|
||||
const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int> &tensors_mask);
|
||||
py::tuple RunOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
|
||||
const std::vector<tensor::TensorPtr> &input_tensors);
|
||||
void RunOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
|
||||
const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs);
|
||||
void OnRunGraphFinished();
|
||||
bool CreateCommGroup(const std::string &group_name, std::vector<uint32_t> ranks);
|
||||
bool DestroyCommGroup(const std::string &group_name);
|
||||
|
||||
private:
|
||||
void UpdateOutputTensors(VectorRef *outputs,
|
||||
|
@ -143,11 +176,7 @@ class Executor {
|
|||
std::mutex task_mutex_;
|
||||
std::mutex pending_task_mutex_;
|
||||
std::condition_variable task_cond_var_;
|
||||
std::condition_variable compile_cond_var_;
|
||||
std::condition_variable build_cond_var_;
|
||||
std::condition_variable run_cond_var_;
|
||||
std::condition_variable build_op_cond_var_;
|
||||
std::condition_variable run_op_cond_var_;
|
||||
std::condition_variable sync_cond_var_;
|
||||
std::queue<std::shared_ptr<Task>> ready_tasks_;
|
||||
std::list<std::shared_ptr<RunGraphTask>> pending_tasks_;
|
||||
std::shared_ptr<std::thread> worker_;
|
||||
|
|
|
@ -1344,10 +1344,10 @@ void SessionBasic::BuildOpAsync(OpRunInfo *op_run_info, const GraphInfo &graph_i
|
|||
executor_->BuildOpAsync(shared_from_this(), op_run_info, graph_info, input_tensors, tensors_mask);
|
||||
}
|
||||
|
||||
py::tuple SessionBasic::RunOpAsync(OpRunInfo *op_run_info, const GraphInfo &graph_info,
|
||||
const std::vector<tensor::TensorPtr> &input_tensors) {
|
||||
void SessionBasic::RunOpAsync(OpRunInfo *op_run_info, const GraphInfo &graph_info,
|
||||
const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) {
|
||||
MS_EXCEPTION_IF_NULL(executor_);
|
||||
return executor_->RunOpAsync(shared_from_this(), op_run_info, graph_info, input_tensors);
|
||||
executor_->RunOpAsync(shared_from_this(), op_run_info, graph_info, input_tensors, outputs);
|
||||
}
|
||||
|
||||
void SessionBasic::RunGraphAsync(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
|
||||
|
|
|
@ -90,7 +90,8 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
|||
void RunGraphAsync(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
|
||||
void BuildOpAsync(OpRunInfo *, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors,
|
||||
const std::vector<int> &tensors_mask);
|
||||
py::tuple RunOpAsync(OpRunInfo *, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors);
|
||||
void RunOpAsync(OpRunInfo *, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors,
|
||||
VectorRef *outputs);
|
||||
|
||||
virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback);
|
||||
|
||||
|
|
|
@ -15,12 +15,12 @@
|
|||
*/
|
||||
|
||||
#include "frontend/parallel/group_manager.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
|
||||
#include "frontend/parallel/device_manager.h"
|
||||
#include "backend/session/executor_manager.h"
|
||||
#include "utils/comm_manager.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
|
@ -96,8 +96,14 @@ Status GroupManager::CreateGroup(const std::string &group_name, const std::vecto
|
|||
vector<uint32_t> ranks;
|
||||
(void)std::transform(std::begin(devices), std::end(devices), std::back_inserter(ranks),
|
||||
[](const Device dev) { return (uint32_t)dev.rank(); });
|
||||
// Create group through the CommManager interface
|
||||
bool ret = CommManager::GetInstance().CreateGroupSync(group_name, ranks);
|
||||
// Create group through the executor
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
std::string device_name = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||
auto executor = session::ExecutorManager::Instance().GetExecutor(device_name, device_id);
|
||||
MS_EXCEPTION_IF_NULL(executor);
|
||||
bool ret = executor->CreateCommGroup(group_name, ranks);
|
||||
if (!ret) {
|
||||
MS_LOG(ERROR) << "Create group failed, group name is " << group_name;
|
||||
return Status::FAILED;
|
||||
|
@ -108,6 +114,20 @@ Status GroupManager::CreateGroup(const std::string &group_name, const std::vecto
|
|||
}
|
||||
}
|
||||
|
||||
// Destroys the communication group `group_name` through the executor obtained
// for the device target and device id read from the current MsContext.
// Returns Status::SUCCESS when the executor reports success, otherwise
// Status::FAILED.
Status GroupManager::DestroyGroup(const std::string &group_name) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  std::string target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  uint32_t id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);

  auto executor = session::ExecutorManager::Instance().GetExecutor(target, id);
  MS_EXCEPTION_IF_NULL(executor);

  const bool destroyed = executor->DestroyCommGroup(group_name);
  return destroyed ? Status::SUCCESS : Status::FAILED;
}
|
||||
|
||||
Status GroupManager::DestroyGroup(mindspore::parallel::Group *const group) {
|
||||
std::string name = (*group).name();
|
||||
auto it = groups_.find(name);
|
||||
|
@ -116,18 +136,14 @@ Status GroupManager::DestroyGroup(mindspore::parallel::Group *const group) {
|
|||
return Status::FAILED;
|
||||
}
|
||||
(void)groups_.erase(it);
|
||||
bool ret = CommManager::GetInstance().DestroyGroup(name);
|
||||
if (!ret) {
|
||||
return Status::FAILED;
|
||||
}
|
||||
return Status::SUCCESS;
|
||||
return DestroyGroup(name);
|
||||
}
|
||||
|
||||
Status GroupManager::DestroyAllGroups() {
|
||||
for (auto &it : groups_) {
|
||||
std::string name = it.first;
|
||||
bool ret = CommManager::GetInstance().DestroyGroup(name);
|
||||
if (!ret) {
|
||||
auto ret = DestroyGroup(name);
|
||||
if (ret != Status::SUCCESS) {
|
||||
return Status::FAILED;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -65,6 +65,7 @@ class GroupManager {
|
|||
void Clear();
|
||||
|
||||
private:
|
||||
Status DestroyGroup(const std::string &group_name);
|
||||
// the key is group name (name_)
|
||||
std::map<std::string, Group> groups_;
|
||||
std::string world_group_;
|
||||
|
|
|
@ -19,18 +19,22 @@
|
|||
#include <typeinfo>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <memory>
|
||||
#include <unordered_set>
|
||||
#include <algorithm>
|
||||
|
||||
#include "debug/trace.h"
|
||||
#include "pybind_api/ir/tensor_py.h"
|
||||
#include "ir/param_info.h"
|
||||
#include "ir/anf.h"
|
||||
#include "ir/tensor.h"
|
||||
#include "utils/any.h"
|
||||
#include "utils/utils.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "utils/context/context_extends.h"
|
||||
#include "utils/config_manager.h"
|
||||
#include "utils/convert_utils_py.h"
|
||||
#include "utils/base_ref_extends.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "frontend/operator/composite/composite.h"
|
||||
#include "frontend/operator/composite/do_signature.h"
|
||||
|
@ -554,6 +558,32 @@ void EraseValueNodeTensor(const std::vector<int> &tensors_mask, std::vector<tens
|
|||
*input_tensors = new_input_tensors;
|
||||
}
|
||||
|
||||
// Recursively converts a BaseRef into a Python-tuple-shaped result.
//  - A tensor reference is returned unchanged.
//  - A VectorRef becomes a py::tuple (wrapped in a PyObjectRef) whose slots
//    hold either the tensor or the nested tuple built for that element.
//  - Anything else raises through MS_LOG(EXCEPTION).
BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref) {
  if (utils::isa<tensor::TensorPtr>(base_ref)) {
    return base_ref;
  }
  if (!utils::isa<VectorRef>(base_ref)) {
    MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
  }

  auto elements = utils::cast<VectorRef>(base_ref);
  const size_t count = elements.size();
  py::tuple output_tensors(count);
  for (size_t idx = 0; idx < count; ++idx) {
    // Each element is converted first; the result is a tensor or a PyObjectRef.
    auto item = TransformBaseRefListToTuple(elements[idx]);
    if (utils::isa<tensor::TensorPtr>(item)) {
      auto tensor_ptr = utils::cast<tensor::TensorPtr>(item);
      MS_EXCEPTION_IF_NULL(tensor_ptr);
      output_tensors[idx] = tensor_ptr;
      continue;
    }
    if (utils::isa<PyObjectRef>(item)) {
      py::object nested = utils::cast<PyObjectRef>(item).object_;
      py::tuple nested_tuple = py::cast<py::tuple>(nested);
      output_tensors[idx] = nested_tuple;
      continue;
    }
    MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
  }
  return std::make_shared<PyObjectRef>(output_tensors);
}
|
||||
|
||||
py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) {
|
||||
MS_EXCEPTION_IF_NULL(op_exec_info);
|
||||
MS_LOG(INFO) << "Start run op[" << op_exec_info->op_name << "] with backend policy ms";
|
||||
|
@ -577,7 +607,19 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
|
|||
std::string graph_info = GetSingleOpGraphInfo(op_exec_info, input_tensors);
|
||||
session->BuildOpAsync(op_exec_info.get(), graph_info, input_tensors, tensors_mask);
|
||||
EraseValueNodeTensor(tensors_mask, &input_tensors);
|
||||
py::tuple result = session->RunOpAsync(op_exec_info.get(), graph_info, input_tensors);
|
||||
|
||||
VectorRef outputs;
|
||||
session->RunOpAsync(op_exec_info.get(), graph_info, input_tensors, &outputs);
|
||||
|
||||
// Trans output to tuple
|
||||
auto output_tensors = TransformBaseRefListToTuple(outputs);
|
||||
if (!utils::isa<PyObjectRef>(output_tensors) ||
|
||||
!py::isinstance<py::tuple>(utils::cast<PyObjectRef>(output_tensors).object_)) {
|
||||
MS_EXCEPTION(NotSupportError) << "The output tensors should be a tuple !";
|
||||
}
|
||||
py::object tuple_obj = utils::cast<PyObjectRef>(output_tensors).object_;
|
||||
py::tuple result = py::cast<py::tuple>(tuple_obj);
|
||||
|
||||
ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
|
||||
*status = PYNATIVE_SUCCESS;
|
||||
MS_LOG(INFO) << "End run op[" << op_exec_info->op_name << "] with backend policy ms";
|
||||
|
|
Loading…
Reference in New Issue