forked from mindspore-Ecosystem/mindspore
commit
a6679511ed
|
@ -106,7 +106,11 @@ void BuildGraphTask::Run() {
|
|||
void RunGraphTask::Run() {
|
||||
MS_EXCEPTION_IF_NULL(session_);
|
||||
try {
|
||||
auto graph = session_->GetGraph(graph_id_);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
graph->ResetGraphRunningStatus();
|
||||
session_->RunGraphImpl(graph_id_, input_tensors_, &outputs_);
|
||||
graph->OnRunGraphFinished();
|
||||
UpdateOutputTensors(&outputs_, tensor_to_node_);
|
||||
} catch (const std::exception &e) {
|
||||
MsException::GetInstance().SetException();
|
||||
|
@ -205,6 +209,7 @@ void Executor::OnRunGraphFinished() {
|
|||
if (new_ready_tasks.size() > 0) {
|
||||
task_cond_var_.notify_all();
|
||||
}
|
||||
reenter_cond_var_.notify_all();
|
||||
}
|
||||
|
||||
bool Executor::IsTaskReady(const std::shared_ptr<RunGraphTask> &task) {
|
||||
|
@ -215,6 +220,12 @@ bool Executor::IsTaskReady(const std::shared_ptr<RunGraphTask> &task) {
|
|||
return false;
|
||||
}
|
||||
}
|
||||
auto session = task->session_;
|
||||
MS_EXCEPTION_IF_NULL(session);
|
||||
auto graph = session->GetGraph(task->graph_id_);
|
||||
if (graph != nullptr) {
|
||||
return graph->IsPreGraphFinished();
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -300,6 +311,14 @@ void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id,
|
|||
SyncRunTask(task);
|
||||
return;
|
||||
}
|
||||
auto graph = session->GetGraph(task->graph_id_);
|
||||
if (graph != nullptr) {
|
||||
if (!graph->IsPostGraphFinished()) {
|
||||
mindspore::ScopedLongRunning long_running;
|
||||
std::unique_lock<std::mutex> lock(reenter_mutex_);
|
||||
reenter_cond_var_.wait(lock, [graph] { return graph->IsPostGraphFinished(); });
|
||||
}
|
||||
}
|
||||
|
||||
bool ready = IsTaskReady(task);
|
||||
if (!ready) {
|
||||
|
|
|
@ -179,8 +179,10 @@ class Executor {
|
|||
std::string device_name_;
|
||||
std::mutex task_mutex_;
|
||||
std::mutex pending_task_mutex_;
|
||||
std::mutex reenter_mutex_;
|
||||
std::condition_variable task_cond_var_;
|
||||
std::condition_variable sync_cond_var_;
|
||||
std::condition_variable reenter_cond_var_;
|
||||
std::queue<std::shared_ptr<Task>> ready_tasks_;
|
||||
std::list<std::shared_ptr<RunGraphTask>> pending_tasks_;
|
||||
std::vector<std::shared_ptr<Task>> done_tasks_;
|
||||
|
|
|
@ -17,15 +17,16 @@
|
|||
#define MINDSPORE_CCSRC_BACKEND_SESSION_KERNEL_GRAPH_H
|
||||
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <string>
|
||||
#include <queue>
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
#include <set>
|
||||
#include <stack>
|
||||
#include <unordered_set>
|
||||
#include <stack>
|
||||
#include <atomic>
|
||||
#include "ir/func_graph.h"
|
||||
#include "ir/anf.h"
|
||||
#include "ir/graph_utils.h"
|
||||
|
@ -50,6 +51,51 @@ class KernelGraph : public FuncGraph {
|
|||
summary_node_exist_ = false;
|
||||
stream_distinction_label_ = kInvalidDistincLabel;
|
||||
}
|
||||
|
||||
KernelGraph(const KernelGraph &graph) : FuncGraph(graph) {
|
||||
inputs_ = graph.inputs_;
|
||||
child_graph_result_ = graph.child_graph_result_;
|
||||
execution_order_ = graph.execution_order_;
|
||||
graph_id_ = graph.graph_id_;
|
||||
stream_distinction_label_ = graph.stream_distinction_label_;
|
||||
front_backend_anf_map_ = graph.front_backend_anf_map_;
|
||||
backend_front_anf_map_ = graph.backend_front_anf_map_;
|
||||
tensor_to_value_node_map_ = graph.tensor_to_value_node_map_;
|
||||
graph_value_nodes_ = graph.graph_value_nodes_;
|
||||
node_input_num_ = graph.node_input_num_;
|
||||
node_input_edges_ = graph.node_input_edges_;
|
||||
ref_out_in_map_ = graph.ref_out_in_map_;
|
||||
node_output_edges_ = graph.node_output_edges_;
|
||||
summary_nodes_ = graph.summary_nodes_;
|
||||
executable_ = graph.executable_;
|
||||
summary_node_exist_ = graph.summary_node_exist_;
|
||||
valid_inputs_ = graph.valid_inputs_;
|
||||
child_graph_order_ = graph.child_graph_order_;
|
||||
input_ctrl_tensors_ = graph.input_ctrl_tensors_;
|
||||
parent_graph_ = graph.parent_graph_;
|
||||
start_label_ = graph.start_label_;
|
||||
end_goto_ = graph.end_goto_;
|
||||
null_output_ = graph.null_output_;
|
||||
front_to_internal_outputs_map_ = graph.front_to_internal_outputs_map_;
|
||||
internal_outputs_to_front_map_ = graph.internal_outputs_to_front_map_;
|
||||
internal_outputs_tensor_map_ = graph.internal_outputs_tensor_map_;
|
||||
current_epoch_ = graph.current_epoch_;
|
||||
tuple_parameter_to_make_tuple_map_ = graph.tuple_parameter_to_make_tuple_map_;
|
||||
visited_nodes_ = graph.visited_nodes_;
|
||||
edge_to_ = graph.edge_to_;
|
||||
loop_nodes_ = graph.loop_nodes_;
|
||||
input_nodes_ = graph.input_nodes_;
|
||||
pre_graphs_ = graph.pre_graphs_;
|
||||
post_graphs_ = graph.post_graphs_;
|
||||
size_t pre_graph_finished_count = graph.pre_graph_finished_count_;
|
||||
pre_graph_finished_count_ = pre_graph_finished_count;
|
||||
size_t post_graph_finished_count = graph.post_graph_finished_count_;
|
||||
post_graph_finished_count_ = post_graph_finished_count;
|
||||
first_step_ = graph.first_step_;
|
||||
has_optimizer_ = graph.has_optimizer_;
|
||||
is_dynamic_shape_ = graph.is_dynamic_shape_;
|
||||
}
|
||||
|
||||
~KernelGraph() override;
|
||||
|
||||
MS_DECLARE_PARENT(KernelGraph, FuncGraph);
|
||||
|
@ -189,6 +235,47 @@ class KernelGraph : public FuncGraph {
|
|||
void SetInputNodes();
|
||||
const std::vector<AnfNodePtr> &input_nodes() const { return input_nodes_; }
|
||||
bool has_optimizer() const { return has_optimizer_; }
|
||||
// handle graph dependency
|
||||
void AddPreGraph(const std::shared_ptr<session::KernelGraph> &graph) {
|
||||
if (graph != nullptr) {
|
||||
pre_graphs_[graph->graph_id()] = graph;
|
||||
}
|
||||
}
|
||||
void AddPostGraph(const std::shared_ptr<session::KernelGraph> &graph) {
|
||||
if (graph != nullptr) {
|
||||
post_graphs_[graph->graph_id()] = graph;
|
||||
}
|
||||
}
|
||||
|
||||
bool IsPreGraphFinished() { return pre_graphs_.size() == pre_graph_finished_count_; }
|
||||
bool IsPostGraphFinished() {
|
||||
if (first_step_) {
|
||||
return true;
|
||||
}
|
||||
return post_graphs_.size() == post_graph_finished_count_;
|
||||
}
|
||||
void IncPreGraphFinishedCount() { pre_graph_finished_count_++; }
|
||||
void IncPostGraphFinishedCount() { post_graph_finished_count_++; }
|
||||
void ResetGraphRunningStatus() {
|
||||
first_step_ = false;
|
||||
post_graph_finished_count_ = 0;
|
||||
pre_graph_finished_count_ = 0;
|
||||
}
|
||||
void OnRunGraphFinished() {
|
||||
for (auto post_graph : post_graphs_) {
|
||||
auto post_graph_ptr = post_graph.second.lock();
|
||||
if (post_graph_ptr != nullptr) {
|
||||
post_graph_ptr->IncPreGraphFinishedCount();
|
||||
}
|
||||
}
|
||||
for (auto pre_graph : pre_graphs_) {
|
||||
auto pre_graph_ptr = pre_graph.second.lock();
|
||||
if (pre_graph_ptr != nullptr) {
|
||||
pre_graph_ptr->IncPostGraphFinishedCount();
|
||||
}
|
||||
}
|
||||
}
|
||||
// end of handle graph dependency
|
||||
|
||||
private:
|
||||
// remove value node form graph
|
||||
|
@ -218,6 +305,7 @@ class KernelGraph : public FuncGraph {
|
|||
uint32_t GetLoopNum(std::map<AnfNodePtr, size_t> none_zero_nodes);
|
||||
void GetLoopNodesByDFS(AnfNodePtr node, uint32_t *loop_num);
|
||||
|
||||
// members
|
||||
std::shared_ptr<std::vector<AnfNodePtr>> inputs_;
|
||||
std::vector<AnfNodePtr> child_graph_result_;
|
||||
std::vector<CNodePtr> execution_order_;
|
||||
|
@ -265,6 +353,11 @@ class KernelGraph : public FuncGraph {
|
|||
std::map<AnfNodePtr, AnfNodePtr> edge_to_;
|
||||
std::stack<AnfNodePtr> loop_nodes_;
|
||||
std::vector<AnfNodePtr> input_nodes_;
|
||||
std::unordered_map<uint32_t, std::weak_ptr<session::KernelGraph>> pre_graphs_;
|
||||
std::unordered_map<uint32_t, std::weak_ptr<session::KernelGraph>> post_graphs_;
|
||||
std::atomic<size_t> pre_graph_finished_count_{0};
|
||||
std::atomic<size_t> post_graph_finished_count_{0};
|
||||
bool first_step_{true};
|
||||
bool has_optimizer_{false};
|
||||
bool is_dynamic_shape_{false};
|
||||
};
|
||||
|
|
|
@ -358,7 +358,7 @@ GraphId SessionBasic::GetGraphIdByNode(const AnfNodePtr &front_anf) const {
|
|||
KernelGraphPtr SessionBasic::GetGraph(mindspore::GraphId graph_id) const {
|
||||
auto it = graphs_.find(graph_id);
|
||||
if (it == graphs_.end()) {
|
||||
MS_LOG(WARNING) << "Can't find graph " << graph_id;
|
||||
MS_LOG(INFO) << "Can't find graph " << graph_id;
|
||||
return nullptr;
|
||||
}
|
||||
return it->second;
|
||||
|
|
|
@ -57,11 +57,25 @@ LinConvertResult MsBackend::MsConvert(const GraphSegmentPtr &segment, const std:
|
|||
result.outputs = outputs;
|
||||
result.graph_id = kInvalidGraphId;
|
||||
GraphId graph_id = kInvalidGraphId;
|
||||
auto current_session = target_sess_;
|
||||
if (target != target_device_ && !target.empty()) {
|
||||
CreateOtherSession(target);
|
||||
graph_id = other_sess_->CompileGraph(segment, outputs);
|
||||
} else {
|
||||
graph_id = target_sess_->CompileGraph(segment, outputs);
|
||||
current_session = other_sess_;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(current_session);
|
||||
graph_id = current_session->CompileGraph(segment, outputs);
|
||||
segment->graph_id_ = graph_id;
|
||||
auto graph = current_session->GetGraph(graph_id);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
for (auto &pre_segment : segment->pre_segments_) {
|
||||
MS_EXCEPTION_IF_NULL(pre_segment);
|
||||
auto pre_graph = target_sess_->GetGraph(pre_segment->graph_id_);
|
||||
if (pre_graph == nullptr) {
|
||||
pre_graph = other_sess_->GetGraph(pre_segment->graph_id_);
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(pre_graph);
|
||||
pre_graph->AddPostGraph(graph);
|
||||
graph->AddPreGraph(pre_graph);
|
||||
}
|
||||
|
||||
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_PRECOMPILE_ONLY)) {
|
||||
|
|
|
@ -246,6 +246,55 @@ std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &
|
|||
return result;
|
||||
}
|
||||
|
||||
void AddSegmentDependency(const FuncGraphPtr &graph, const std::string &default_target,
|
||||
const std::map<AnfNodePtr, GraphSegmentPtr> &node_to_segment) {
|
||||
std::stack<AnfNodePtr> to_visit;
|
||||
std::map<AnfNodePtr, size_t> nodes_ref;
|
||||
std::map<AnfNodePtr, std::vector<AnfNodePtr>> control_edges;
|
||||
CalcNodeRefCount(graph, &nodes_ref, &control_edges);
|
||||
to_visit.push(graph->get_return());
|
||||
while (!to_visit.empty()) {
|
||||
auto &node = to_visit.top();
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
to_visit.pop();
|
||||
if (!node->isa<CNode>()) {
|
||||
continue;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto node_inputs = cnode->inputs();
|
||||
auto ctrl_inputs = control_edges.find(node);
|
||||
if (ctrl_inputs != control_edges.end()) {
|
||||
node_inputs.insert(node_inputs.end(), ctrl_inputs->second.begin(), ctrl_inputs->second.end());
|
||||
}
|
||||
GraphSegmentPtr node_segment{nullptr};
|
||||
auto node_iter = node_to_segment.find(node);
|
||||
if (node_iter != node_to_segment.end()) {
|
||||
node_segment = node_iter->second;
|
||||
}
|
||||
for (auto &input : node_inputs) {
|
||||
if (node_segment != nullptr && !node_segment->is_cut_ && input->isa<CNode>()) {
|
||||
GraphSegmentPtr input_segment{nullptr};
|
||||
auto input_iter = node_to_segment.find(input);
|
||||
if (input_iter != node_to_segment.end()) {
|
||||
input_segment = input_iter->second;
|
||||
}
|
||||
if (input_segment != nullptr && input_segment != node_segment && !input_segment->is_cut_) {
|
||||
node_segment->AddPreSegment(input_segment);
|
||||
}
|
||||
}
|
||||
auto ref_iter = nodes_ref.find(input);
|
||||
if (ref_iter != nodes_ref.end()) {
|
||||
ref_iter->second--;
|
||||
if (ref_iter->second != 0) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
to_visit.push(input);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> ParallelSplitSort(const FuncGraphPtr &graph, const std::string &default_target) {
|
||||
std::vector<AnfNodePtr> result;
|
||||
std::stack<AnfNodePtr> handle_nodes;
|
||||
|
@ -404,10 +453,10 @@ std::vector<GraphSegmentPtr> GraphPartition::Partition(const FuncGraphPtr &graph
|
|||
auto nodes = TopoSort(graph->get_return());
|
||||
MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size();
|
||||
bool contain_multi_target = ContainMultiTarget(nodes);
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
if (contain_multi_target) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
if (graph != nullptr) {
|
||||
nodes = SplitSort(graph, default_target);
|
||||
} else {
|
||||
|
@ -417,15 +466,22 @@ std::vector<GraphSegmentPtr> GraphPartition::Partition(const FuncGraphPtr &graph
|
|||
}
|
||||
std::vector<GraphSegmentPtr> segments;
|
||||
std::vector<AnfNodePtr> segment_nodes;
|
||||
std::map<AnfNodePtr, GraphSegmentPtr> node_to_segment;
|
||||
auto new_segment = [&segments, &segment_nodes, &node_to_segment]() {
|
||||
if (segment_nodes.size() != 0) {
|
||||
auto segment = std::make_shared<GraphSegment>(segment_nodes, false);
|
||||
segments.emplace_back(segment);
|
||||
for (auto node : segment_nodes) {
|
||||
node_to_segment[node] = segment;
|
||||
}
|
||||
segment_nodes.clear();
|
||||
}
|
||||
};
|
||||
std::string last_target;
|
||||
for (auto &node : nodes) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (IsCut(node)) {
|
||||
if (segment_nodes.size() != 0) {
|
||||
auto segment = std::make_shared<GraphSegment>(segment_nodes, false);
|
||||
segments.emplace_back(segment);
|
||||
segment_nodes.clear();
|
||||
}
|
||||
new_segment();
|
||||
segment_nodes.emplace_back(node);
|
||||
auto segment = std::make_shared<GraphSegment>(segment_nodes, true);
|
||||
segments.push_back(segment);
|
||||
|
@ -433,10 +489,8 @@ std::vector<GraphSegmentPtr> GraphPartition::Partition(const FuncGraphPtr &graph
|
|||
} else if (node->isa<CNode>()) {
|
||||
if (contain_multi_target) {
|
||||
std::string cur_target = GetCNodeTarget(node);
|
||||
if (cur_target != last_target && !last_target.empty() && segment_nodes.size() != 0) {
|
||||
auto segment = std::make_shared<GraphSegment>(segment_nodes, false);
|
||||
segments.emplace_back(segment);
|
||||
segment_nodes.clear();
|
||||
if (cur_target != last_target && !last_target.empty()) {
|
||||
new_segment();
|
||||
}
|
||||
last_target = cur_target;
|
||||
}
|
||||
|
@ -444,6 +498,9 @@ std::vector<GraphSegmentPtr> GraphPartition::Partition(const FuncGraphPtr &graph
|
|||
}
|
||||
}
|
||||
MS_LOG(DEBUG) << "Segment size:" << segments.size();
|
||||
if (contain_multi_target) {
|
||||
AddSegmentDependency(graph, default_target, node_to_segment);
|
||||
}
|
||||
return segments;
|
||||
}
|
||||
} // namespace compile
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <set>
|
||||
|
||||
#include "base/base.h"
|
||||
#include "base/user_data.h"
|
||||
|
@ -490,8 +491,11 @@ std::string GetCNodeTarget(const AnfNodePtr &node);
|
|||
bool ContainMultiTarget(const std::vector<AnfNodePtr> &nodes);
|
||||
struct GraphSegment {
|
||||
GraphSegment(const std::vector<AnfNodePtr> &nodes, bool is_cut) : nodes_(nodes), is_cut_(is_cut) {}
|
||||
void AddPreSegment(const std::shared_ptr<GraphSegment> &segment) { (void)pre_segments_.insert(segment); }
|
||||
std::vector<AnfNodePtr> nodes_;
|
||||
std::set<std::shared_ptr<GraphSegment>> pre_segments_;
|
||||
bool is_cut_{false};
|
||||
uint32_t graph_id_{0};
|
||||
};
|
||||
using GraphSegmentPtr = std::shared_ptr<GraphSegment>;
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue