forked from mindspore-Ecosystem/mindspore
refactor vm module for multigraph sink
This commit is contained in:
parent
2860fd9338
commit
6146424596
|
@ -564,42 +564,67 @@ AnfNodePtr AscendSession::CreateFakeOutput(GraphId fake_graph_id, const AnfNodeP
|
||||||
return create_parameter_from_cnode(output_item_with_index.first, output_item_with_index.second);
|
return create_parameter_from_cnode(output_item_with_index.first, output_item_with_index.second);
|
||||||
}
|
}
|
||||||
|
|
||||||
void AscendSession::SetFinalGraphOutput(const BaseRef &output) {
|
void AscendSession::SetFinalGraphOutput(const AnfNodePtr &node) {
|
||||||
auto final_graph = GetGraph(final_graph_id_);
|
|
||||||
MS_EXCEPTION_IF_NULL(final_graph);
|
|
||||||
if (!utils::isa<AnfNodePtr>(output)) {
|
|
||||||
if (!utils::isa<ValuePtr>(output)) {
|
|
||||||
MS_LOG(EXCEPTION) << "Unknown output type:" << output.ToString();
|
|
||||||
}
|
|
||||||
auto value_ptr = utils::cast<ValuePtr>(output);
|
|
||||||
auto value_node = NewValueNode(value_ptr);
|
|
||||||
MS_EXCEPTION_IF_NULL(value_node);
|
|
||||||
auto kernel_info = std::make_shared<device::KernelInfo>();
|
|
||||||
value_node->set_kernel_info(kernel_info);
|
|
||||||
value_node->set_abstract(abstract::FromValue(value_ptr));
|
|
||||||
final_graph->set_output(final_graph->NewCNode({NewValueNode(prim::kPrimMakeTuple), value_node}));
|
|
||||||
final_graph->set_executable(false);
|
|
||||||
MS_LOG(INFO) << "Not anf output[" << output.ToString() << "]";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
// get the backend anf node related to the output node of front
|
// get the backend anf node related to the output node of front
|
||||||
auto output_anf_node = utils::cast<AnfNodePtr>(output);
|
auto output_from_graph_id = GetGraphIdByNode(node);
|
||||||
auto output_from_graph_id = GetGraphIdByNode(output_anf_node);
|
|
||||||
auto output_from_graph = GetGraph(output_from_graph_id);
|
auto output_from_graph = GetGraph(output_from_graph_id);
|
||||||
MS_EXCEPTION_IF_NULL(output_anf_node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
MS_LOG(INFO) << "Set the output[" << output_anf_node->DebugString() << "] of graph[" << output_from_graph_id
|
MS_LOG(INFO) << "Set the output[" << node->DebugString() << "] of graph[" << output_from_graph_id
|
||||||
<< "] to final graph";
|
<< "] to final graph";
|
||||||
MS_EXCEPTION_IF_NULL(output_from_graph);
|
MS_EXCEPTION_IF_NULL(output_from_graph);
|
||||||
|
auto final_graph = GetGraph(final_graph_id_);
|
||||||
|
MS_EXCEPTION_IF_NULL(final_graph);
|
||||||
// if output is from final graph,it remarks no child graph exist
|
// if output is from final graph,it remarks no child graph exist
|
||||||
if (final_graph_id_ == output_from_graph_id) {
|
if (final_graph_id_ == output_from_graph_id) {
|
||||||
MS_LOG(INFO) << "No child graph,output is " << output_anf_node->DebugString();
|
MS_LOG(INFO) << "No child graph,output is " << node->DebugString();
|
||||||
final_graph->set_output(ConstructOutput({output_anf_node}, final_graph));
|
final_graph->set_output(ConstructOutput({node}, final_graph));
|
||||||
final_graph->set_executable(false);
|
final_graph->set_executable(false);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
final_graph->set_output(output_from_graph->output());
|
final_graph->set_output(output_from_graph->output());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void AscendSession::SetFinalGraphOutput(const ValuePtr &value) {
|
||||||
|
auto value_node = NewValueNode(value);
|
||||||
|
auto kernel_info = std::make_shared<device::KernelInfo>();
|
||||||
|
value_node->set_kernel_info(kernel_info);
|
||||||
|
value_node->set_abstract(abstract::FromValue(value));
|
||||||
|
auto final_graph = GetGraph(final_graph_id_);
|
||||||
|
MS_EXCEPTION_IF_NULL(final_graph);
|
||||||
|
final_graph->set_output(final_graph->NewCNode({NewValueNode(prim::kPrimMakeTuple), value_node}));
|
||||||
|
final_graph->set_executable(false);
|
||||||
|
MS_LOG(INFO) << "Not anf output[" << value->ToString() << "]";
|
||||||
|
}
|
||||||
|
|
||||||
|
void AscendSession::SetFinalGraphOutput(const VectorRef &vec_output) {
|
||||||
|
for (auto &output : vec_output) {
|
||||||
|
if (utils::isa<AnfNodePtr>(output)) {
|
||||||
|
auto output_anf_node = utils::cast<AnfNodePtr>(output);
|
||||||
|
SetFinalGraphOutput(output_anf_node);
|
||||||
|
} else if (utils::isa<ValuePtr>(output)) {
|
||||||
|
auto value = utils::cast<ValuePtr>(output);
|
||||||
|
SetFinalGraphOutput(value);
|
||||||
|
} else {
|
||||||
|
MS_LOG(EXCEPTION) << "Unknown output type:" << output.ToString();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void AscendSession::SetFinalGraphOutput(const BaseRef &output) {
|
||||||
|
if (utils::isa<AnfNodePtr>(output)) {
|
||||||
|
auto output_anf_node = utils::cast<AnfNodePtr>(output);
|
||||||
|
SetFinalGraphOutput(output_anf_node);
|
||||||
|
} else if (utils::isa<ValuePtr>(output)) {
|
||||||
|
auto value = utils::cast<ValuePtr>(output);
|
||||||
|
SetFinalGraphOutput(value);
|
||||||
|
} else if (utils::isa<VectorRef>(output)) {
|
||||||
|
auto vec_output = utils::cast<VectorRef>(output);
|
||||||
|
SetFinalGraphOutput(vec_output);
|
||||||
|
} else {
|
||||||
|
MS_LOG(EXCEPTION) << "Unknown output type:" << output.ToString();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
KernelGraphPtr AscendSession::GetGraph(mindspore::GraphId graph_id) {
|
KernelGraphPtr AscendSession::GetGraph(mindspore::GraphId graph_id) {
|
||||||
auto it = graphs_.find(graph_id);
|
auto it = graphs_.find(graph_id);
|
||||||
if (it == graphs_.end()) {
|
if (it == graphs_.end()) {
|
||||||
|
|
|
@ -88,6 +88,10 @@ class AscendSession : public SessionBasic {
|
||||||
size_t SetChildGraphInput(const KernelGraphPtr &graph, const ValuePtr &value, size_t input_index);
|
size_t SetChildGraphInput(const KernelGraphPtr &graph, const ValuePtr &value, size_t input_index);
|
||||||
size_t SetChildGraphInput(const KernelGraphPtr &graph, const VectorRef &vec_args, size_t input_index);
|
size_t SetChildGraphInput(const KernelGraphPtr &graph, const VectorRef &vec_args, size_t input_index);
|
||||||
|
|
||||||
|
void SetFinalGraphOutput(const AnfNodePtr &node);
|
||||||
|
void SetFinalGraphOutput(const ValuePtr &value);
|
||||||
|
void SetFinalGraphOutput(const VectorRef &vec_output);
|
||||||
|
|
||||||
// merge execution order list of child graphs
|
// merge execution order list of child graphs
|
||||||
void MergeGraphExecOrder();
|
void MergeGraphExecOrder();
|
||||||
// insert assion op to sync data bettween different graphs
|
// insert assion op to sync data bettween different graphs
|
||||||
|
|
|
@ -243,7 +243,7 @@ void CompileGraph::AddSinkSwitch(const CNodePtr &node) {
|
||||||
AddInst(Instruction::kCall, args);
|
AddInst(Instruction::kCall, args);
|
||||||
|
|
||||||
args.clear();
|
args.clear();
|
||||||
args.emplace_back(true);
|
args.emplace_back(node->input(1));
|
||||||
AddInst(Instruction::kSwitchReturn, args);
|
AddInst(Instruction::kSwitchReturn, args);
|
||||||
|
|
||||||
args.clear();
|
args.clear();
|
||||||
|
|
|
@ -141,17 +141,31 @@ void FinalVM::Popsp() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void FinalVM::PushStatus(bool is_switch_call) { ret_status_.push(is_switch_call); }
|
||||||
|
|
||||||
|
bool FinalVM::PopStatus() {
|
||||||
|
if (ret_status_.empty()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
bool status = ret_status_.top();
|
||||||
|
ret_status_.pop();
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
|
||||||
void FinalVM::DoJmp(const BaseRef &jmp_orig) {
|
void FinalVM::DoJmp(const BaseRef &jmp_orig) {
|
||||||
MS_LOG(DEBUG) << "Start";
|
MS_LOG(DEBUG) << "Start";
|
||||||
|
|
||||||
BaseRef jmp = jmp_orig;
|
BaseRef jmp = jmp_orig;
|
||||||
if (backend_->simu_flag()) {
|
if (backend_->simu_flag()) {
|
||||||
|
bool is_switch_call = false;
|
||||||
if (utils::isa<StructSimuSwitch>(jmp)) { // need to inherit from Base
|
if (utils::isa<StructSimuSwitch>(jmp)) { // need to inherit from Base
|
||||||
MS_LOG(DEBUG) << "Start jump StructSwitch";
|
MS_LOG(DEBUG) << "Start jump StructSwitch";
|
||||||
auto simu_value = utils::cast<std::shared_ptr<StructSimuSwitch>>(jmp);
|
auto simu_value = utils::cast<std::shared_ptr<StructSimuSwitch>>(jmp);
|
||||||
jmp = simu_value->fn_;
|
jmp = simu_value->fn_;
|
||||||
backend_->set_curr_switch(simu_value->value_);
|
backend_->set_curr_switch(simu_value->value_);
|
||||||
|
is_switch_call = true;
|
||||||
}
|
}
|
||||||
|
PushStatus(is_switch_call);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (utils::isa<StructPartial>(jmp)) { // need to inherit from Base
|
if (utils::isa<StructPartial>(jmp)) { // need to inherit from Base
|
||||||
|
@ -255,6 +269,13 @@ void FinalVM::InstSwitchReturn(const VectorRef &args) {
|
||||||
MS_LOG(ERROR) << __FUNCTION__ << " requires one parameter, while the input size is " << args.size() << ".";
|
MS_LOG(ERROR) << __FUNCTION__ << " requires one parameter, while the input size is " << args.size() << ".";
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto rv = Ref(-1);
|
||||||
|
if (utils::isa<AnfNodePtr>(rv) || utils::isa<VectorRef>(rv)) {
|
||||||
|
auto &c = args[0];
|
||||||
|
cond_out_[c] = rv;
|
||||||
|
}
|
||||||
|
|
||||||
Pop(1);
|
Pop(1);
|
||||||
Popsp();
|
Popsp();
|
||||||
}
|
}
|
||||||
|
@ -272,8 +293,20 @@ void FinalVM::InstReturn(const VectorRef &args) {
|
||||||
int height = utils::cast<int>(args[1]);
|
int height = utils::cast<int>(args[1]);
|
||||||
|
|
||||||
auto rv = Ref(rpos);
|
auto rv = Ref(rpos);
|
||||||
if (backend_->simu_flag() && backend_->is_switch_call()) {
|
if (backend_->simu_flag()) {
|
||||||
backend_->SetSwitchGraph();
|
auto c = backend_->curr_switch();
|
||||||
|
auto status = PopStatus();
|
||||||
|
if (status) {
|
||||||
|
auto iter = cond_out_.find(c);
|
||||||
|
if (iter != cond_out_.end()) {
|
||||||
|
rv = MergeArgs(rv, iter->second);
|
||||||
|
cond_out_.erase(iter);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (backend_->is_switch_call()) {
|
||||||
|
backend_->SetSwitchGraph();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Pop(height);
|
Pop(height);
|
||||||
|
@ -383,23 +416,32 @@ void FinalVM::MergeJmpArgs(const BaseRef &jmp, const BaseRef &c) {
|
||||||
for (size_t i = 0; i < new_args.size(); ++i) {
|
for (size_t i = 0; i < new_args.size(); ++i) {
|
||||||
auto &old_arg = old_args[i];
|
auto &old_arg = old_args[i];
|
||||||
auto &new_arg = new_args[i];
|
auto &new_arg = new_args[i];
|
||||||
if (utils::isa<VectorRef>(old_arg)) {
|
new_arg = MergeArgs(old_arg, new_arg);
|
||||||
auto old_vec_ref = utils::cast<VectorRef>(old_arg);
|
|
||||||
if (utils::isa<VectorRef>(new_arg)) {
|
|
||||||
auto new_vec_ref = utils::cast<VectorRef>(new_arg);
|
|
||||||
std::copy(new_vec_ref.begin(), new_vec_ref.end(), std::back_inserter(old_vec_ref));
|
|
||||||
}
|
|
||||||
new_arg = old_vec_ref;
|
|
||||||
} else if (utils::isa<VectorRef>(new_arg)) {
|
|
||||||
auto new_vec_ref = utils::cast<VectorRef>(new_arg);
|
|
||||||
new_vec_ref.push_back(old_arg);
|
|
||||||
new_arg = new_vec_ref;
|
|
||||||
} else {
|
|
||||||
new_arg = VectorRef({new_arg, old_arg});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
BaseRef FinalVM::MergeArgs(const BaseRef &first, const BaseRef &second) {
|
||||||
|
MS_LOG(DEBUG) << __FUNCTION__ << ": " << first.ToString() << ", " << second.ToString();
|
||||||
|
if (utils::isa<VectorRef>(first)) {
|
||||||
|
auto old_vec_ref = utils::cast<VectorRef>(first);
|
||||||
|
if (utils::isa<VectorRef>(second)) {
|
||||||
|
auto new_vec_ref = utils::cast<VectorRef>(second);
|
||||||
|
std::copy(new_vec_ref.begin(), new_vec_ref.end(), std::back_inserter(old_vec_ref));
|
||||||
|
} else {
|
||||||
|
old_vec_ref.push_back(second);
|
||||||
|
}
|
||||||
|
return old_vec_ref;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (utils::isa<VectorRef>(second)) {
|
||||||
|
auto new_vec_ref = utils::cast<VectorRef>(second);
|
||||||
|
new_vec_ref.push_back(first);
|
||||||
|
return new_vec_ref;
|
||||||
|
}
|
||||||
|
|
||||||
|
return VectorRef({first, second});
|
||||||
|
}
|
||||||
|
|
||||||
void FinalVM::InstRealSwitch(const VectorRef &args) {
|
void FinalVM::InstRealSwitch(const VectorRef &args) {
|
||||||
const size_t args_size = 3;
|
const size_t args_size = 3;
|
||||||
if (args.size() != args_size) {
|
if (args.size() != args_size) {
|
||||||
|
|
|
@ -125,17 +125,22 @@ class FinalVM {
|
||||||
void Popp();
|
void Popp();
|
||||||
void Pushsp();
|
void Pushsp();
|
||||||
void Popsp();
|
void Popsp();
|
||||||
|
void PushStatus(bool is_switch_call);
|
||||||
|
bool PopStatus();
|
||||||
void DoJmp(const BaseRef &jmp);
|
void DoJmp(const BaseRef &jmp);
|
||||||
void MergeJmpArgs(const BaseRef &jmp, const BaseRef &c);
|
void MergeJmpArgs(const BaseRef &jmp, const BaseRef &c);
|
||||||
|
BaseRef MergeArgs(const BaseRef &first, const BaseRef &second);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
InstSet insts_;
|
InstSet insts_;
|
||||||
std::deque<BaseRef> insts_stack_;
|
std::deque<BaseRef> insts_stack_;
|
||||||
std::stack<int> retp_;
|
std::stack<int> retp_;
|
||||||
std::stack<int> retsp_;
|
std::stack<int> retsp_;
|
||||||
|
std::stack<bool> ret_status_;
|
||||||
int pc_;
|
int pc_;
|
||||||
int sp_;
|
int sp_;
|
||||||
std::unordered_map<BaseRef, BaseRef, BaseRefHash> cond_jmp_;
|
std::unordered_map<BaseRef, BaseRef, BaseRefHash> cond_jmp_;
|
||||||
|
std::unordered_map<BaseRef, BaseRef, BaseRefHash> cond_out_;
|
||||||
BackendPtr backend_;
|
BackendPtr backend_;
|
||||||
const InstFunctionMap inst_function_map = {
|
const InstFunctionMap inst_function_map = {
|
||||||
{Instruction::kCall, [this](const VectorRef &args) { InstCall(args); }},
|
{Instruction::kCall, [this](const VectorRef &args) { InstCall(args); }},
|
||||||
|
|
|
@ -26,6 +26,7 @@ from mindspore.ops import operations as P
|
||||||
def setup_module(module):
|
def setup_module(module):
|
||||||
context.set_context(mode = context.PYNATIVE_MODE, device_target = "Ascend")
|
context.set_context(mode = context.PYNATIVE_MODE, device_target = "Ascend")
|
||||||
|
|
||||||
|
|
||||||
c1 = Tensor([2], mstype.int32)
|
c1 = Tensor([2], mstype.int32)
|
||||||
c2 = Tensor([14], mstype.int32)
|
c2 = Tensor([14], mstype.int32)
|
||||||
c3 = Tensor([1], mstype.int32)
|
c3 = Tensor([1], mstype.int32)
|
||||||
|
@ -149,6 +150,10 @@ def test_if_by_if():
|
||||||
assert output == expect
|
assert output == expect
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
def test_if_in_if():
|
def test_if_in_if():
|
||||||
output = if_in_if(c1, c2, c3)
|
output = if_in_if(c1, c2, c3)
|
||||||
expect = Tensor([7], mstype.int32)
|
expect = Tensor([7], mstype.int32)
|
||||||
|
@ -194,6 +199,7 @@ def test_while_by_while_in_while():
|
||||||
expect = Tensor([350], mstype.int32)
|
expect = Tensor([350], mstype.int32)
|
||||||
assert output == expect
|
assert output == expect
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.level0
|
@pytest.mark.level0
|
||||||
@pytest.mark.platform_x86_ascend_training
|
@pytest.mark.platform_x86_ascend_training
|
||||||
@pytest.mark.platform_arm_ascend_training
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
|
Loading…
Reference in New Issue