forked from mindspore-Ecosystem/mindspore
1.fix bug of backend common pass covert_tuple_output_to_maketuple
2.attach original inputs to graph when replace call and switch to labelgoto and labelswitch
This commit is contained in:
parent
270a7f2332
commit
bde9c0c6a9
|
@ -30,6 +30,10 @@ namespace {
|
||||||
void ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr) {
|
void ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr) {
|
||||||
MS_EXCEPTION_IF_NULL(cnode_ptr);
|
MS_EXCEPTION_IF_NULL(cnode_ptr);
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
|
if (AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimCall) ||
|
||||||
|
AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimPartial)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
std::vector<AnfNodePtr> plant_inputs;
|
std::vector<AnfNodePtr> plant_inputs;
|
||||||
std::vector<int> dyn_input_sizes;
|
std::vector<int> dyn_input_sizes;
|
||||||
plant_inputs.push_back(AnfAlgo::GetCNodePrimitiveNode(cnode_ptr));
|
plant_inputs.push_back(AnfAlgo::GetCNodePrimitiveNode(cnode_ptr));
|
||||||
|
|
|
@ -26,22 +26,19 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
namespace {
|
namespace {
|
||||||
AnfNodePtr ConvertTupleInputToMakeTuple(const FuncGraphPtr &graph, const AnfNodePtr &tuple_anf,
|
AnfNodePtr ConvertTupleInputToMakeTuple(const FuncGraphPtr &graph, const AnfNodePtr &tuple_anf) {
|
||||||
std::unordered_map<AnfNodePtr, AnfNodePtr> *transed_nodes) {
|
|
||||||
MS_EXCEPTION_IF_NULL(tuple_anf);
|
MS_EXCEPTION_IF_NULL(tuple_anf);
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
MS_EXCEPTION_IF_NULL(transed_nodes);
|
|
||||||
|
|
||||||
if (!AnfAlgo::IsTupleOutput(tuple_anf)) {
|
if (!AnfAlgo::IsTupleOutput(tuple_anf)) {
|
||||||
return tuple_anf;
|
return tuple_anf;
|
||||||
}
|
}
|
||||||
auto transed_node_it = transed_nodes->find(tuple_anf);
|
|
||||||
if (transed_node_it != transed_nodes->end()) {
|
|
||||||
return transed_node_it->second;
|
|
||||||
}
|
|
||||||
auto kernel_graph = graph->cast<KernelGraphPtr>();
|
auto kernel_graph = graph->cast<KernelGraphPtr>();
|
||||||
|
if (kernel_graph->FindTupleParameterToMakeTupleMap(tuple_anf)) {
|
||||||
|
return kernel_graph->FindTupleParameterToMakeTupleMap(tuple_anf);
|
||||||
|
}
|
||||||
auto make_tuple = kernel_graph->TransTupleToMakeTuple(tuple_anf);
|
auto make_tuple = kernel_graph->TransTupleToMakeTuple(tuple_anf);
|
||||||
(*transed_nodes)[tuple_anf] = make_tuple;
|
kernel_graph->InsertTupleParameterToMakeTupleMap(tuple_anf, make_tuple);
|
||||||
// replace graph inputs if input is a parameter
|
// replace graph inputs if input is a parameter
|
||||||
kernel_graph->ReplaceGraphInput(tuple_anf, make_tuple);
|
kernel_graph->ReplaceGraphInput(tuple_anf, make_tuple);
|
||||||
return make_tuple;
|
return make_tuple;
|
||||||
|
@ -61,7 +58,6 @@ const AnfNodePtr ConvertTupleOutputToMaketuple::Process(const FuncGraphPtr &func
|
||||||
}
|
}
|
||||||
auto cnode = node->cast<CNodePtr>();
|
auto cnode = node->cast<CNodePtr>();
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
std::unordered_map<AnfNodePtr, AnfNodePtr> transed_nodes;
|
|
||||||
if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) {
|
if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) {
|
||||||
auto real_input = AnfAlgo::GetTupleGetItemRealInput(cnode);
|
auto real_input = AnfAlgo::GetTupleGetItemRealInput(cnode);
|
||||||
MS_EXCEPTION_IF_NULL(real_input);
|
MS_EXCEPTION_IF_NULL(real_input);
|
||||||
|
@ -77,7 +73,7 @@ const AnfNodePtr ConvertTupleOutputToMaketuple::Process(const FuncGraphPtr &func
|
||||||
const auto &input = cnode->inputs()[i];
|
const auto &input = cnode->inputs()[i];
|
||||||
if (input->Type() != nullptr && AnfAlgo::IsRealKernel(input) && AnfAlgo::IsTupleOutput(input) &&
|
if (input->Type() != nullptr && AnfAlgo::IsRealKernel(input) && AnfAlgo::IsTupleOutput(input) &&
|
||||||
!AnfAlgo::CheckPrimitiveType(input, prim::kPrimCall)) {
|
!AnfAlgo::CheckPrimitiveType(input, prim::kPrimCall)) {
|
||||||
cnode->set_input(i, ConvertTupleInputToMakeTuple(func_graph, input, &transed_nodes));
|
cnode->set_input(i, ConvertTupleInputToMakeTuple(func_graph, input));
|
||||||
cnode_input_changed = true;
|
cnode_input_changed = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -523,12 +523,22 @@ void AscendControlParser::LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNod
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void AscendControlParser::AttachOriginalInputsToGraph(NotNull<KernelGraphPtr> graph,
|
||||||
|
const std::vector<AnfNodePtr> orig_inputs) {
|
||||||
|
std::vector<AnfNodePtr> make_tuple_inputs = {
|
||||||
|
mindspore::NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()))};
|
||||||
|
std::copy(orig_inputs.begin(), orig_inputs.end(), std::back_inserter(make_tuple_inputs));
|
||||||
|
auto make_tuple = graph->NewCNode(make_tuple_inputs);
|
||||||
|
|
||||||
|
InsertDependToGraph(graph, NOT_NULL(make_tuple));
|
||||||
|
}
|
||||||
|
|
||||||
void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node,
|
void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node,
|
||||||
const NotNull<std::set<KernelGraphPtr> *> memo) {
|
const NotNull<std::set<KernelGraphPtr> *> memo) {
|
||||||
MS_LOG(INFO) << "Process call func " << cur_node->DebugString();
|
MS_LOG(INFO) << "Process call func " << cur_node->DebugString();
|
||||||
|
|
||||||
// 1 get kernel graph
|
// 1 get kernel graph
|
||||||
const std::vector<AnfNodePtr> &origin_inputs = cur_node->inputs();
|
std::vector<AnfNodePtr> origin_inputs = cur_node->inputs();
|
||||||
if (kCNodeCallArg >= origin_inputs.size()) {
|
if (kCNodeCallArg >= origin_inputs.size()) {
|
||||||
MS_LOG(EXCEPTION) << "Index out of range,size:" << origin_inputs.size();
|
MS_LOG(EXCEPTION) << "Index out of range,size:" << origin_inputs.size();
|
||||||
}
|
}
|
||||||
|
@ -555,6 +565,8 @@ void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodeP
|
||||||
cur_node->set_inputs(new_inputs);
|
cur_node->set_inputs(new_inputs);
|
||||||
cur_node->set_abstract(nullptr);
|
cur_node->set_abstract(nullptr);
|
||||||
AnfAlgo::SetNodeAttr(kAttrChildGraph, MakeValue<std::vector<KernelGraphPtr>>({call_kg}), cur_node.get());
|
AnfAlgo::SetNodeAttr(kAttrChildGraph, MakeValue<std::vector<KernelGraphPtr>>({call_kg}), cur_node.get());
|
||||||
|
origin_inputs.assign(origin_inputs.begin() + kCNodeCallArg + 1, origin_inputs.end());
|
||||||
|
AttachOriginalInputsToGraph(kg, origin_inputs);
|
||||||
MS_LOG(INFO) << "Succeed processing call func " << cur_node->DebugString();
|
MS_LOG(INFO) << "Succeed processing call func " << cur_node->DebugString();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -587,11 +599,13 @@ void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNod
|
||||||
for (size_t i = kCNodeSwitchCond + 1; i < kCNodeSwitchLength; ++i) {
|
for (size_t i = kCNodeSwitchCond + 1; i < kCNodeSwitchLength; ++i) {
|
||||||
// 3.1 branch kernel graph and args
|
// 3.1 branch kernel graph and args
|
||||||
KernelGraphPtr branch_fg;
|
KernelGraphPtr branch_fg;
|
||||||
std::tie(branch_fg, std::ignore) = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
|
std::vector<AnfNodePtr> origin_inputs;
|
||||||
|
std::tie(branch_fg, origin_inputs) = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
|
||||||
child_graphs.push_back(branch_fg);
|
child_graphs.push_back(branch_fg);
|
||||||
// 3.2 recurse sub graph
|
// 3.2 recurse sub graph
|
||||||
CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo);
|
CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo);
|
||||||
new_switch_inputs.push_back(branch_label);
|
new_switch_inputs.push_back(branch_label);
|
||||||
|
AttachOriginalInputsToGraph(kg, origin_inputs);
|
||||||
}
|
}
|
||||||
std::swap(new_switch_inputs[kCNodeSwitchTrue], new_switch_inputs[kCNodeSwitchFalse]);
|
std::swap(new_switch_inputs[kCNodeSwitchTrue], new_switch_inputs[kCNodeSwitchFalse]);
|
||||||
|
|
||||||
|
@ -635,11 +649,13 @@ void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull
|
||||||
for (size_t i = 0; i < branch_partial.size(); ++i) {
|
for (size_t i = 0; i < branch_partial.size(); ++i) {
|
||||||
// 3.1 branch kernel graph and args
|
// 3.1 branch kernel graph and args
|
||||||
KernelGraphPtr branch_fg;
|
KernelGraphPtr branch_fg;
|
||||||
std::tie(branch_fg, std::ignore) = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
|
std::vector<AnfNodePtr> origin_inputs;
|
||||||
|
std::tie(branch_fg, origin_inputs) = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
|
||||||
child_graphs.push_back(branch_fg);
|
child_graphs.push_back(branch_fg);
|
||||||
// 3.2 recurse sub graph
|
// 3.2 recurse sub graph
|
||||||
CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo);
|
CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo);
|
||||||
new_switch_inputs.push_back(branch_label);
|
new_switch_inputs.push_back(branch_label);
|
||||||
|
AttachOriginalInputsToGraph(kg, origin_inputs);
|
||||||
}
|
}
|
||||||
new_switch_inputs.insert(new_switch_inputs.end(), branch_partial.begin(), branch_partial.end());
|
new_switch_inputs.insert(new_switch_inputs.end(), branch_partial.begin(), branch_partial.end());
|
||||||
cur_node->set_inputs(new_switch_inputs);
|
cur_node->set_inputs(new_switch_inputs);
|
||||||
|
|
|
@ -76,6 +76,7 @@ class AscendControlParser {
|
||||||
static bool CheckLabelIndex(uint32_t index, uint32_t label_index, const CNodePtr &cnode);
|
static bool CheckLabelIndex(uint32_t index, uint32_t label_index, const CNodePtr &cnode);
|
||||||
static std::vector<CNodePtr> RecurseGraph(NotNull<KernelGraphPtr> graph,
|
static std::vector<CNodePtr> RecurseGraph(NotNull<KernelGraphPtr> graph,
|
||||||
const NotNull<std::set<KernelGraphPtr> *> memo);
|
const NotNull<std::set<KernelGraphPtr> *> memo);
|
||||||
|
static void AttachOriginalInputsToGraph(NotNull<KernelGraphPtr> graph, const std::vector<AnfNodePtr> orig_inputs);
|
||||||
};
|
};
|
||||||
class AscendControlParser::ReferenceCounter {
|
class AscendControlParser::ReferenceCounter {
|
||||||
public:
|
public:
|
||||||
|
|
|
@ -162,6 +162,19 @@ class KernelGraph : public FuncGraph {
|
||||||
void set_child_graph_result(const std::vector<AnfNodePtr> &child_graph_result) {
|
void set_child_graph_result(const std::vector<AnfNodePtr> &child_graph_result) {
|
||||||
child_graph_result_ = child_graph_result;
|
child_graph_result_ = child_graph_result;
|
||||||
}
|
}
|
||||||
|
void InsertTupleParameterToMakeTupleMap(const AnfNodePtr ¶m, const AnfNodePtr &make_tuple) {
|
||||||
|
if (tuple_parameter_to_make_tuple_map_.find(param) != tuple_parameter_to_make_tuple_map_.end()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
tuple_parameter_to_make_tuple_map_[param] = make_tuple;
|
||||||
|
}
|
||||||
|
AnfNodePtr FindTupleParameterToMakeTupleMap(const AnfNodePtr ¶m) {
|
||||||
|
if (tuple_parameter_to_make_tuple_map_.find(param) != tuple_parameter_to_make_tuple_map_.end()) {
|
||||||
|
return tuple_parameter_to_make_tuple_map_[param];
|
||||||
|
} else {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// remove value node form graph
|
// remove value node form graph
|
||||||
|
@ -229,6 +242,7 @@ class KernelGraph : public FuncGraph {
|
||||||
std::unordered_map<AnfNodePtr, std::unordered_map<int, std::pair<AnfNodePtr, bool>>> internal_outputs_to_front_map_;
|
std::unordered_map<AnfNodePtr, std::unordered_map<int, std::pair<AnfNodePtr, bool>>> internal_outputs_to_front_map_;
|
||||||
std::unordered_map<AnfNodePtr, std::unordered_map<int, tensor::TensorPtr>> internal_outputs_tensor_map_;
|
std::unordered_map<AnfNodePtr, std::unordered_map<int, tensor::TensorPtr>> internal_outputs_tensor_map_;
|
||||||
uint32_t current_epoch_;
|
uint32_t current_epoch_;
|
||||||
|
std::unordered_map<AnfNodePtr, AnfNodePtr> tuple_parameter_to_make_tuple_map_;
|
||||||
};
|
};
|
||||||
} // namespace session
|
} // namespace session
|
||||||
using KernelGraphPtr = std::shared_ptr<session::KernelGraph>;
|
using KernelGraphPtr = std::shared_ptr<session::KernelGraph>;
|
||||||
|
|
Loading…
Reference in New Issue