refactor vm module for multigraph sink

rick_sanchez 2020-04-30 15:55:52 +08:00
parent 2860fd9338
commit 6146424596
6 changed files with 123 additions and 41 deletions

mindspore/ccsrc/session/ascend_session.cc Executable file → Normal file

@@ -564,42 +564,67 @@ AnfNodePtr AscendSession::CreateFakeOutput(GraphId fake_graph_id, const AnfNodeP
   return create_parameter_from_cnode(output_item_with_index.first, output_item_with_index.second);
 }
 
-void AscendSession::SetFinalGraphOutput(const BaseRef &output) {
-  auto final_graph = GetGraph(final_graph_id_);
-  MS_EXCEPTION_IF_NULL(final_graph);
-  if (!utils::isa<AnfNodePtr>(output)) {
-    if (!utils::isa<ValuePtr>(output)) {
-      MS_LOG(EXCEPTION) << "Unknown output type:" << output.ToString();
-    }
-    auto value_ptr = utils::cast<ValuePtr>(output);
-    auto value_node = NewValueNode(value_ptr);
-    MS_EXCEPTION_IF_NULL(value_node);
-    auto kernel_info = std::make_shared<device::KernelInfo>();
-    value_node->set_kernel_info(kernel_info);
-    value_node->set_abstract(abstract::FromValue(value_ptr));
-    final_graph->set_output(final_graph->NewCNode({NewValueNode(prim::kPrimMakeTuple), value_node}));
-    final_graph->set_executable(false);
-    MS_LOG(INFO) << "Not anf output[" << output.ToString() << "]";
-    return;
-  }
+void AscendSession::SetFinalGraphOutput(const AnfNodePtr &node) {
   // get the backend anf node related to the output node of front
-  auto output_anf_node = utils::cast<AnfNodePtr>(output);
-  auto output_from_graph_id = GetGraphIdByNode(output_anf_node);
+  auto output_from_graph_id = GetGraphIdByNode(node);
   auto output_from_graph = GetGraph(output_from_graph_id);
-  MS_EXCEPTION_IF_NULL(output_anf_node);
-  MS_LOG(INFO) << "Set the output[" << output_anf_node->DebugString() << "] of graph[" << output_from_graph_id
+  MS_EXCEPTION_IF_NULL(node);
+  MS_LOG(INFO) << "Set the output[" << node->DebugString() << "] of graph[" << output_from_graph_id
                << "] to final graph";
   MS_EXCEPTION_IF_NULL(output_from_graph);
+  auto final_graph = GetGraph(final_graph_id_);
+  MS_EXCEPTION_IF_NULL(final_graph);
   // if output is from final graph,it remarks no child graph exist
   if (final_graph_id_ == output_from_graph_id) {
-    MS_LOG(INFO) << "No child graph,output is " << output_anf_node->DebugString();
-    final_graph->set_output(ConstructOutput({output_anf_node}, final_graph));
+    MS_LOG(INFO) << "No child graph,output is " << node->DebugString();
+    final_graph->set_output(ConstructOutput({node}, final_graph));
     final_graph->set_executable(false);
     return;
   }
   final_graph->set_output(output_from_graph->output());
 }
 
+void AscendSession::SetFinalGraphOutput(const ValuePtr &value) {
+  auto value_node = NewValueNode(value);
+  auto kernel_info = std::make_shared<device::KernelInfo>();
+  value_node->set_kernel_info(kernel_info);
+  value_node->set_abstract(abstract::FromValue(value));
+  auto final_graph = GetGraph(final_graph_id_);
+  MS_EXCEPTION_IF_NULL(final_graph);
+  final_graph->set_output(final_graph->NewCNode({NewValueNode(prim::kPrimMakeTuple), value_node}));
+  final_graph->set_executable(false);
+  MS_LOG(INFO) << "Not anf output[" << value->ToString() << "]";
+}
+
+void AscendSession::SetFinalGraphOutput(const VectorRef &vec_output) {
+  for (auto &output : vec_output) {
+    if (utils::isa<AnfNodePtr>(output)) {
+      auto output_anf_node = utils::cast<AnfNodePtr>(output);
+      SetFinalGraphOutput(output_anf_node);
+    } else if (utils::isa<ValuePtr>(output)) {
+      auto value = utils::cast<ValuePtr>(output);
+      SetFinalGraphOutput(value);
+    } else {
+      MS_LOG(EXCEPTION) << "Unknown output type:" << output.ToString();
+    }
+  }
+}
+
+void AscendSession::SetFinalGraphOutput(const BaseRef &output) {
+  if (utils::isa<AnfNodePtr>(output)) {
+    auto output_anf_node = utils::cast<AnfNodePtr>(output);
+    SetFinalGraphOutput(output_anf_node);
+  } else if (utils::isa<ValuePtr>(output)) {
+    auto value = utils::cast<ValuePtr>(output);
+    SetFinalGraphOutput(value);
+  } else if (utils::isa<VectorRef>(output)) {
+    auto vec_output = utils::cast<VectorRef>(output);
+    SetFinalGraphOutput(vec_output);
+  } else {
+    MS_LOG(EXCEPTION) << "Unknown output type:" << output.ToString();
+  }
+}
+
 KernelGraphPtr AscendSession::GetGraph(mindspore::GraphId graph_id) {
   auto it = graphs_.find(graph_id);
   if (it == graphs_.end()) {
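Note: the change above replaces the single BaseRef-based SetFinalGraphOutput with typed overloads for AnfNodePtr, ValuePtr, and VectorRef, keeping the BaseRef version only as a dispatcher that inspects the runtime type and forwards to the matching overload. Below is a minimal standalone sketch of that dispatch shape; the names (Session, SetOutput, Node) and the use of std::variant are illustrative stand-ins, not MindSpore code.

// Illustrative stand-ins: std::variant plays the role of BaseRef, Node of
// AnfNodePtr, int of ValuePtr, std::vector<int> of VectorRef.
#include <iostream>
#include <string>
#include <variant>
#include <vector>

struct Node { std::string name; };
using Ref = std::variant<Node, int, std::vector<int>>;

struct Session {
  void SetOutput(const Node &node) { std::cout << "node output: " << node.name << '\n'; }
  void SetOutput(int value) { std::cout << "value output: " << value << '\n'; }
  void SetOutput(const std::vector<int> &vec) {
    for (int value : vec) SetOutput(value);  // fan out element by element
  }
  // The catch-all entry point only dispatches on the runtime type, like the
  // remaining BaseRef overload above.
  void SetOutput(const Ref &any) {
    std::visit([this](const auto &typed) { SetOutput(typed); }, any);
  }
};

int main() {
  Session session;
  session.SetOutput(Ref{Node{"conv1"}});
  session.SetOutput(Ref{42});
  session.SetOutput(Ref{std::vector<int>{1, 2, 3}});
  return 0;
}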


@@ -88,6 +88,10 @@ class AscendSession : public SessionBasic {
   size_t SetChildGraphInput(const KernelGraphPtr &graph, const ValuePtr &value, size_t input_index);
   size_t SetChildGraphInput(const KernelGraphPtr &graph, const VectorRef &vec_args, size_t input_index);
+  void SetFinalGraphOutput(const AnfNodePtr &node);
+  void SetFinalGraphOutput(const ValuePtr &value);
+  void SetFinalGraphOutput(const VectorRef &vec_output);
   // merge execution order list of child graphs
   void MergeGraphExecOrder();
   // insert assion op to sync data bettween different graphs


@@ -243,7 +243,7 @@ void CompileGraph::AddSinkSwitch(const CNodePtr &node) {
   AddInst(Instruction::kCall, args);
   args.clear();
-  args.emplace_back(true);
+  args.emplace_back(node->input(1));
   AddInst(Instruction::kSwitchReturn, args);
   args.clear();


@@ -141,17 +141,31 @@ void FinalVM::Popsp() {
   }
 }
 
+void FinalVM::PushStatus(bool is_switch_call) { ret_status_.push(is_switch_call); }
+
+bool FinalVM::PopStatus() {
+  if (ret_status_.empty()) {
+    return false;
+  }
+  bool status = ret_status_.top();
+  ret_status_.pop();
+  return status;
+}
+
 void FinalVM::DoJmp(const BaseRef &jmp_orig) {
   MS_LOG(DEBUG) << "Start";
   BaseRef jmp = jmp_orig;
   if (backend_->simu_flag()) {
+    bool is_switch_call = false;
     if (utils::isa<StructSimuSwitch>(jmp)) {  // need to inherit from Base
       MS_LOG(DEBUG) << "Start jump StructSwitch";
       auto simu_value = utils::cast<std::shared_ptr<StructSimuSwitch>>(jmp);
       jmp = simu_value->fn_;
       backend_->set_curr_switch(simu_value->value_);
+      is_switch_call = true;
     }
+    PushStatus(is_switch_call);
   }
 
   if (utils::isa<StructPartial>(jmp)) {  // need to inherit from Base
@@ -255,6 +269,13 @@ void FinalVM::InstSwitchReturn(const VectorRef &args) {
     MS_LOG(ERROR) << __FUNCTION__ << " requires one parameter, while the input size is " << args.size() << ".";
     return;
   }
+
+  auto rv = Ref(-1);
+  if (utils::isa<AnfNodePtr>(rv) || utils::isa<VectorRef>(rv)) {
+    auto &c = args[0];
+    cond_out_[c] = rv;
+  }
+
   Pop(1);
   Popsp();
 }
@@ -272,8 +293,20 @@ void FinalVM::InstReturn(const VectorRef &args) {
   int height = utils::cast<int>(args[1]);
   auto rv = Ref(rpos);
-  if (backend_->simu_flag() && backend_->is_switch_call()) {
-    backend_->SetSwitchGraph();
+  if (backend_->simu_flag()) {
+    auto c = backend_->curr_switch();
+    auto status = PopStatus();
+    if (status) {
+      auto iter = cond_out_.find(c);
+      if (iter != cond_out_.end()) {
+        rv = MergeArgs(rv, iter->second);
+        cond_out_.erase(iter);
+      }
+    }
+    if (backend_->is_switch_call()) {
+      backend_->SetSwitchGraph();
+    }
   }
   Pop(height);
@@ -383,23 +416,32 @@ void FinalVM::MergeJmpArgs(const BaseRef &jmp, const BaseRef &c) {
   for (size_t i = 0; i < new_args.size(); ++i) {
     auto &old_arg = old_args[i];
     auto &new_arg = new_args[i];
-    if (utils::isa<VectorRef>(old_arg)) {
-      auto old_vec_ref = utils::cast<VectorRef>(old_arg);
-      if (utils::isa<VectorRef>(new_arg)) {
-        auto new_vec_ref = utils::cast<VectorRef>(new_arg);
-        std::copy(new_vec_ref.begin(), new_vec_ref.end(), std::back_inserter(old_vec_ref));
-      }
-      new_arg = old_vec_ref;
-    } else if (utils::isa<VectorRef>(new_arg)) {
-      auto new_vec_ref = utils::cast<VectorRef>(new_arg);
-      new_vec_ref.push_back(old_arg);
-      new_arg = new_vec_ref;
-    } else {
-      new_arg = VectorRef({new_arg, old_arg});
-    }
+    new_arg = MergeArgs(old_arg, new_arg);
   }
 }
 
+BaseRef FinalVM::MergeArgs(const BaseRef &first, const BaseRef &second) {
+  MS_LOG(DEBUG) << __FUNCTION__ << ": " << first.ToString() << ", " << second.ToString();
+  if (utils::isa<VectorRef>(first)) {
+    auto old_vec_ref = utils::cast<VectorRef>(first);
+    if (utils::isa<VectorRef>(second)) {
+      auto new_vec_ref = utils::cast<VectorRef>(second);
+      std::copy(new_vec_ref.begin(), new_vec_ref.end(), std::back_inserter(old_vec_ref));
+    } else {
+      old_vec_ref.push_back(second);
+    }
+    return old_vec_ref;
+  }
+  if (utils::isa<VectorRef>(second)) {
+    auto new_vec_ref = utils::cast<VectorRef>(second);
+    new_vec_ref.push_back(first);
+    return new_vec_ref;
+  }
+  return VectorRef({first, second});
+}
+
 void FinalVM::InstRealSwitch(const VectorRef &args) {
   const size_t args_size = 3;
   if (args.size() != args_size) {
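Note: in the hunks above, InstSwitchReturn now records the value returned from a switch branch in cond_out_, keyed by the current switch condition, and InstReturn merges that recorded value into its own return value through the new MergeArgs helper; MergeJmpArgs reuses the same helper. Below is a standalone sketch of the merge rule MergeArgs implements, with std::vector<int> standing in for VectorRef and int for a scalar BaseRef; it is illustrative only, not MindSpore code.

// Illustrative stand-ins: std::vector<int> plays the role of VectorRef,
// int the role of a scalar BaseRef.
#include <iostream>
#include <variant>
#include <vector>

using Item = std::variant<int, std::vector<int>>;

// Mirrors the branch structure of FinalVM::MergeArgs: list + list concatenates,
// list + scalar appends, scalar + list appends the scalar to the list,
// scalar + scalar builds a two-element list.
std::vector<int> MergeArgs(const Item &first, const Item &second) {
  if (const auto *first_vec = std::get_if<std::vector<int>>(&first)) {
    std::vector<int> out = *first_vec;
    if (const auto *second_vec = std::get_if<std::vector<int>>(&second)) {
      out.insert(out.end(), second_vec->begin(), second_vec->end());
    } else {
      out.push_back(std::get<int>(second));
    }
    return out;
  }
  if (const auto *second_vec = std::get_if<std::vector<int>>(&second)) {
    std::vector<int> out = *second_vec;
    out.push_back(std::get<int>(first));
    return out;
  }
  return {std::get<int>(first), std::get<int>(second)};
}

int main() {
  auto merged = MergeArgs(std::vector<int>{1, 2}, 3);  // yields {1, 2, 3}
  for (int v : merged) std::cout << v << ' ';
  std::cout << '\n';
  return 0;
}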


@@ -125,17 +125,22 @@ class FinalVM {
   void Popp();
   void Pushsp();
   void Popsp();
+  void PushStatus(bool is_switch_call);
+  bool PopStatus();
   void DoJmp(const BaseRef &jmp);
   void MergeJmpArgs(const BaseRef &jmp, const BaseRef &c);
+  BaseRef MergeArgs(const BaseRef &first, const BaseRef &second);
 
  private:
   InstSet insts_;
   std::deque<BaseRef> insts_stack_;
   std::stack<int> retp_;
   std::stack<int> retsp_;
+  std::stack<bool> ret_status_;
   int pc_;
   int sp_;
   std::unordered_map<BaseRef, BaseRef, BaseRefHash> cond_jmp_;
+  std::unordered_map<BaseRef, BaseRef, BaseRefHash> cond_out_;
   BackendPtr backend_;
   const InstFunctionMap inst_function_map = {
     {Instruction::kCall, [this](const VectorRef &args) { InstCall(args); }},


@@ -26,6 +26,7 @@ from mindspore.ops import operations as P
 def setup_module(module):
     context.set_context(mode = context.PYNATIVE_MODE, device_target = "Ascend")
 
 c1 = Tensor([2], mstype.int32)
 c2 = Tensor([14], mstype.int32)
 c3 = Tensor([1], mstype.int32)
@@ -149,6 +150,10 @@ def test_if_by_if():
     assert output == expect
 
 
+@pytest.mark.level0
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.env_onecard
 def test_if_in_if():
     output = if_in_if(c1, c2, c3)
     expect = Tensor([7], mstype.int32)
@@ -194,6 +199,7 @@ def test_while_by_while_in_while():
     expect = Tensor([350], mstype.int32)
     assert output == expect
 
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_ascend_training
 @pytest.mark.platform_arm_ascend_training