forked from mindspore-Ecosystem/mindspore
refactor vm module for multigraph sink
This commit is contained in:
parent
2860fd9338
commit
6146424596
|
@ -564,42 +564,67 @@ AnfNodePtr AscendSession::CreateFakeOutput(GraphId fake_graph_id, const AnfNodeP
|
|||
return create_parameter_from_cnode(output_item_with_index.first, output_item_with_index.second);
|
||||
}
|
||||
|
||||
void AscendSession::SetFinalGraphOutput(const BaseRef &output) {
|
||||
auto final_graph = GetGraph(final_graph_id_);
|
||||
MS_EXCEPTION_IF_NULL(final_graph);
|
||||
if (!utils::isa<AnfNodePtr>(output)) {
|
||||
if (!utils::isa<ValuePtr>(output)) {
|
||||
MS_LOG(EXCEPTION) << "Unknown output type:" << output.ToString();
|
||||
}
|
||||
auto value_ptr = utils::cast<ValuePtr>(output);
|
||||
auto value_node = NewValueNode(value_ptr);
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
auto kernel_info = std::make_shared<device::KernelInfo>();
|
||||
value_node->set_kernel_info(kernel_info);
|
||||
value_node->set_abstract(abstract::FromValue(value_ptr));
|
||||
final_graph->set_output(final_graph->NewCNode({NewValueNode(prim::kPrimMakeTuple), value_node}));
|
||||
final_graph->set_executable(false);
|
||||
MS_LOG(INFO) << "Not anf output[" << output.ToString() << "]";
|
||||
return;
|
||||
}
|
||||
void AscendSession::SetFinalGraphOutput(const AnfNodePtr &node) {
|
||||
// get the backend anf node related to the output node of front
|
||||
auto output_anf_node = utils::cast<AnfNodePtr>(output);
|
||||
auto output_from_graph_id = GetGraphIdByNode(output_anf_node);
|
||||
auto output_from_graph_id = GetGraphIdByNode(node);
|
||||
auto output_from_graph = GetGraph(output_from_graph_id);
|
||||
MS_EXCEPTION_IF_NULL(output_anf_node);
|
||||
MS_LOG(INFO) << "Set the output[" << output_anf_node->DebugString() << "] of graph[" << output_from_graph_id
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_LOG(INFO) << "Set the output[" << node->DebugString() << "] of graph[" << output_from_graph_id
|
||||
<< "] to final graph";
|
||||
MS_EXCEPTION_IF_NULL(output_from_graph);
|
||||
auto final_graph = GetGraph(final_graph_id_);
|
||||
MS_EXCEPTION_IF_NULL(final_graph);
|
||||
// if output is from final graph,it remarks no child graph exist
|
||||
if (final_graph_id_ == output_from_graph_id) {
|
||||
MS_LOG(INFO) << "No child graph,output is " << output_anf_node->DebugString();
|
||||
final_graph->set_output(ConstructOutput({output_anf_node}, final_graph));
|
||||
MS_LOG(INFO) << "No child graph,output is " << node->DebugString();
|
||||
final_graph->set_output(ConstructOutput({node}, final_graph));
|
||||
final_graph->set_executable(false);
|
||||
return;
|
||||
}
|
||||
final_graph->set_output(output_from_graph->output());
|
||||
}
|
||||
|
||||
void AscendSession::SetFinalGraphOutput(const ValuePtr &value) {
|
||||
auto value_node = NewValueNode(value);
|
||||
auto kernel_info = std::make_shared<device::KernelInfo>();
|
||||
value_node->set_kernel_info(kernel_info);
|
||||
value_node->set_abstract(abstract::FromValue(value));
|
||||
auto final_graph = GetGraph(final_graph_id_);
|
||||
MS_EXCEPTION_IF_NULL(final_graph);
|
||||
final_graph->set_output(final_graph->NewCNode({NewValueNode(prim::kPrimMakeTuple), value_node}));
|
||||
final_graph->set_executable(false);
|
||||
MS_LOG(INFO) << "Not anf output[" << value->ToString() << "]";
|
||||
}
|
||||
|
||||
void AscendSession::SetFinalGraphOutput(const VectorRef &vec_output) {
|
||||
for (auto &output : vec_output) {
|
||||
if (utils::isa<AnfNodePtr>(output)) {
|
||||
auto output_anf_node = utils::cast<AnfNodePtr>(output);
|
||||
SetFinalGraphOutput(output_anf_node);
|
||||
} else if (utils::isa<ValuePtr>(output)) {
|
||||
auto value = utils::cast<ValuePtr>(output);
|
||||
SetFinalGraphOutput(value);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Unknown output type:" << output.ToString();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void AscendSession::SetFinalGraphOutput(const BaseRef &output) {
|
||||
if (utils::isa<AnfNodePtr>(output)) {
|
||||
auto output_anf_node = utils::cast<AnfNodePtr>(output);
|
||||
SetFinalGraphOutput(output_anf_node);
|
||||
} else if (utils::isa<ValuePtr>(output)) {
|
||||
auto value = utils::cast<ValuePtr>(output);
|
||||
SetFinalGraphOutput(value);
|
||||
} else if (utils::isa<VectorRef>(output)) {
|
||||
auto vec_output = utils::cast<VectorRef>(output);
|
||||
SetFinalGraphOutput(vec_output);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Unknown output type:" << output.ToString();
|
||||
}
|
||||
}
|
||||
|
||||
KernelGraphPtr AscendSession::GetGraph(mindspore::GraphId graph_id) {
|
||||
auto it = graphs_.find(graph_id);
|
||||
if (it == graphs_.end()) {
|
||||
|
|
|
@ -88,6 +88,10 @@ class AscendSession : public SessionBasic {
|
|||
size_t SetChildGraphInput(const KernelGraphPtr &graph, const ValuePtr &value, size_t input_index);
|
||||
size_t SetChildGraphInput(const KernelGraphPtr &graph, const VectorRef &vec_args, size_t input_index);
|
||||
|
||||
void SetFinalGraphOutput(const AnfNodePtr &node);
|
||||
void SetFinalGraphOutput(const ValuePtr &value);
|
||||
void SetFinalGraphOutput(const VectorRef &vec_output);
|
||||
|
||||
// merge execution order list of child graphs
|
||||
void MergeGraphExecOrder();
|
||||
// insert assion op to sync data bettween different graphs
|
||||
|
|
|
@ -243,7 +243,7 @@ void CompileGraph::AddSinkSwitch(const CNodePtr &node) {
|
|||
AddInst(Instruction::kCall, args);
|
||||
|
||||
args.clear();
|
||||
args.emplace_back(true);
|
||||
args.emplace_back(node->input(1));
|
||||
AddInst(Instruction::kSwitchReturn, args);
|
||||
|
||||
args.clear();
|
||||
|
|
|
@ -141,17 +141,31 @@ void FinalVM::Popsp() {
|
|||
}
|
||||
}
|
||||
|
||||
void FinalVM::PushStatus(bool is_switch_call) { ret_status_.push(is_switch_call); }
|
||||
|
||||
bool FinalVM::PopStatus() {
|
||||
if (ret_status_.empty()) {
|
||||
return false;
|
||||
}
|
||||
bool status = ret_status_.top();
|
||||
ret_status_.pop();
|
||||
return status;
|
||||
}
|
||||
|
||||
void FinalVM::DoJmp(const BaseRef &jmp_orig) {
|
||||
MS_LOG(DEBUG) << "Start";
|
||||
|
||||
BaseRef jmp = jmp_orig;
|
||||
if (backend_->simu_flag()) {
|
||||
bool is_switch_call = false;
|
||||
if (utils::isa<StructSimuSwitch>(jmp)) { // need to inherit from Base
|
||||
MS_LOG(DEBUG) << "Start jump StructSwitch";
|
||||
auto simu_value = utils::cast<std::shared_ptr<StructSimuSwitch>>(jmp);
|
||||
jmp = simu_value->fn_;
|
||||
backend_->set_curr_switch(simu_value->value_);
|
||||
is_switch_call = true;
|
||||
}
|
||||
PushStatus(is_switch_call);
|
||||
}
|
||||
|
||||
if (utils::isa<StructPartial>(jmp)) { // need to inherit from Base
|
||||
|
@ -255,6 +269,13 @@ void FinalVM::InstSwitchReturn(const VectorRef &args) {
|
|||
MS_LOG(ERROR) << __FUNCTION__ << " requires one parameter, while the input size is " << args.size() << ".";
|
||||
return;
|
||||
}
|
||||
|
||||
auto rv = Ref(-1);
|
||||
if (utils::isa<AnfNodePtr>(rv) || utils::isa<VectorRef>(rv)) {
|
||||
auto &c = args[0];
|
||||
cond_out_[c] = rv;
|
||||
}
|
||||
|
||||
Pop(1);
|
||||
Popsp();
|
||||
}
|
||||
|
@ -272,8 +293,20 @@ void FinalVM::InstReturn(const VectorRef &args) {
|
|||
int height = utils::cast<int>(args[1]);
|
||||
|
||||
auto rv = Ref(rpos);
|
||||
if (backend_->simu_flag() && backend_->is_switch_call()) {
|
||||
backend_->SetSwitchGraph();
|
||||
if (backend_->simu_flag()) {
|
||||
auto c = backend_->curr_switch();
|
||||
auto status = PopStatus();
|
||||
if (status) {
|
||||
auto iter = cond_out_.find(c);
|
||||
if (iter != cond_out_.end()) {
|
||||
rv = MergeArgs(rv, iter->second);
|
||||
cond_out_.erase(iter);
|
||||
}
|
||||
}
|
||||
|
||||
if (backend_->is_switch_call()) {
|
||||
backend_->SetSwitchGraph();
|
||||
}
|
||||
}
|
||||
|
||||
Pop(height);
|
||||
|
@ -383,23 +416,32 @@ void FinalVM::MergeJmpArgs(const BaseRef &jmp, const BaseRef &c) {
|
|||
for (size_t i = 0; i < new_args.size(); ++i) {
|
||||
auto &old_arg = old_args[i];
|
||||
auto &new_arg = new_args[i];
|
||||
if (utils::isa<VectorRef>(old_arg)) {
|
||||
auto old_vec_ref = utils::cast<VectorRef>(old_arg);
|
||||
if (utils::isa<VectorRef>(new_arg)) {
|
||||
auto new_vec_ref = utils::cast<VectorRef>(new_arg);
|
||||
std::copy(new_vec_ref.begin(), new_vec_ref.end(), std::back_inserter(old_vec_ref));
|
||||
}
|
||||
new_arg = old_vec_ref;
|
||||
} else if (utils::isa<VectorRef>(new_arg)) {
|
||||
auto new_vec_ref = utils::cast<VectorRef>(new_arg);
|
||||
new_vec_ref.push_back(old_arg);
|
||||
new_arg = new_vec_ref;
|
||||
} else {
|
||||
new_arg = VectorRef({new_arg, old_arg});
|
||||
}
|
||||
new_arg = MergeArgs(old_arg, new_arg);
|
||||
}
|
||||
}
|
||||
|
||||
BaseRef FinalVM::MergeArgs(const BaseRef &first, const BaseRef &second) {
|
||||
MS_LOG(DEBUG) << __FUNCTION__ << ": " << first.ToString() << ", " << second.ToString();
|
||||
if (utils::isa<VectorRef>(first)) {
|
||||
auto old_vec_ref = utils::cast<VectorRef>(first);
|
||||
if (utils::isa<VectorRef>(second)) {
|
||||
auto new_vec_ref = utils::cast<VectorRef>(second);
|
||||
std::copy(new_vec_ref.begin(), new_vec_ref.end(), std::back_inserter(old_vec_ref));
|
||||
} else {
|
||||
old_vec_ref.push_back(second);
|
||||
}
|
||||
return old_vec_ref;
|
||||
}
|
||||
|
||||
if (utils::isa<VectorRef>(second)) {
|
||||
auto new_vec_ref = utils::cast<VectorRef>(second);
|
||||
new_vec_ref.push_back(first);
|
||||
return new_vec_ref;
|
||||
}
|
||||
|
||||
return VectorRef({first, second});
|
||||
}
|
||||
|
||||
void FinalVM::InstRealSwitch(const VectorRef &args) {
|
||||
const size_t args_size = 3;
|
||||
if (args.size() != args_size) {
|
||||
|
|
|
@ -125,17 +125,22 @@ class FinalVM {
|
|||
void Popp();
|
||||
void Pushsp();
|
||||
void Popsp();
|
||||
void PushStatus(bool is_switch_call);
|
||||
bool PopStatus();
|
||||
void DoJmp(const BaseRef &jmp);
|
||||
void MergeJmpArgs(const BaseRef &jmp, const BaseRef &c);
|
||||
BaseRef MergeArgs(const BaseRef &first, const BaseRef &second);
|
||||
|
||||
private:
|
||||
InstSet insts_;
|
||||
std::deque<BaseRef> insts_stack_;
|
||||
std::stack<int> retp_;
|
||||
std::stack<int> retsp_;
|
||||
std::stack<bool> ret_status_;
|
||||
int pc_;
|
||||
int sp_;
|
||||
std::unordered_map<BaseRef, BaseRef, BaseRefHash> cond_jmp_;
|
||||
std::unordered_map<BaseRef, BaseRef, BaseRefHash> cond_out_;
|
||||
BackendPtr backend_;
|
||||
const InstFunctionMap inst_function_map = {
|
||||
{Instruction::kCall, [this](const VectorRef &args) { InstCall(args); }},
|
||||
|
|
|
@ -26,6 +26,7 @@ from mindspore.ops import operations as P
|
|||
def setup_module(module):
|
||||
context.set_context(mode = context.PYNATIVE_MODE, device_target = "Ascend")
|
||||
|
||||
|
||||
c1 = Tensor([2], mstype.int32)
|
||||
c2 = Tensor([14], mstype.int32)
|
||||
c3 = Tensor([1], mstype.int32)
|
||||
|
@ -149,6 +150,10 @@ def test_if_by_if():
|
|||
assert output == expect
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_if_in_if():
|
||||
output = if_in_if(c1, c2, c3)
|
||||
expect = Tensor([7], mstype.int32)
|
||||
|
@ -194,6 +199,7 @@ def test_while_by_while_in_while():
|
|||
expect = Tensor([350], mstype.int32)
|
||||
assert output == expect
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
|
|
Loading…
Reference in New Issue