1. Fix a bug in the backend common pass convert_tuple_output_to_maketuple

2. Attach the original inputs to the graph when replacing call and switch
nodes with labelgoto and labelswitch
This commit is contained in:
wenchunjiang 2020-08-21 11:22:26 +08:00
parent 270a7f2332
commit bde9c0c6a9
5 changed files with 44 additions and 13 deletions

View File

@ -30,6 +30,10 @@ namespace {
void ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr) {
MS_EXCEPTION_IF_NULL(cnode_ptr);
MS_EXCEPTION_IF_NULL(graph);
if (AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimCall) ||
AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimPartial)) {
return;
}
std::vector<AnfNodePtr> plant_inputs;
std::vector<int> dyn_input_sizes;
plant_inputs.push_back(AnfAlgo::GetCNodePrimitiveNode(cnode_ptr));

View File

@ -26,22 +26,19 @@
namespace mindspore {
namespace opt {
namespace {
AnfNodePtr ConvertTupleInputToMakeTuple(const FuncGraphPtr &graph, const AnfNodePtr &tuple_anf,
std::unordered_map<AnfNodePtr, AnfNodePtr> *transed_nodes) {
AnfNodePtr ConvertTupleInputToMakeTuple(const FuncGraphPtr &graph, const AnfNodePtr &tuple_anf) {
MS_EXCEPTION_IF_NULL(tuple_anf);
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(transed_nodes);
if (!AnfAlgo::IsTupleOutput(tuple_anf)) {
return tuple_anf;
}
auto transed_node_it = transed_nodes->find(tuple_anf);
if (transed_node_it != transed_nodes->end()) {
return transed_node_it->second;
}
auto kernel_graph = graph->cast<KernelGraphPtr>();
if (kernel_graph->FindTupleParameterToMakeTupleMap(tuple_anf)) {
return kernel_graph->FindTupleParameterToMakeTupleMap(tuple_anf);
}
auto make_tuple = kernel_graph->TransTupleToMakeTuple(tuple_anf);
(*transed_nodes)[tuple_anf] = make_tuple;
kernel_graph->InsertTupleParameterToMakeTupleMap(tuple_anf, make_tuple);
// replace graph inputs if input is a parameter
kernel_graph->ReplaceGraphInput(tuple_anf, make_tuple);
return make_tuple;
@ -61,7 +58,6 @@ const AnfNodePtr ConvertTupleOutputToMaketuple::Process(const FuncGraphPtr &func
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
std::unordered_map<AnfNodePtr, AnfNodePtr> transed_nodes;
if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) {
auto real_input = AnfAlgo::GetTupleGetItemRealInput(cnode);
MS_EXCEPTION_IF_NULL(real_input);
@ -77,7 +73,7 @@ const AnfNodePtr ConvertTupleOutputToMaketuple::Process(const FuncGraphPtr &func
const auto &input = cnode->inputs()[i];
if (input->Type() != nullptr && AnfAlgo::IsRealKernel(input) && AnfAlgo::IsTupleOutput(input) &&
!AnfAlgo::CheckPrimitiveType(input, prim::kPrimCall)) {
cnode->set_input(i, ConvertTupleInputToMakeTuple(func_graph, input, &transed_nodes));
cnode->set_input(i, ConvertTupleInputToMakeTuple(func_graph, input));
cnode_input_changed = true;
}
}

View File

@ -523,12 +523,22 @@ void AscendControlParser::LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNod
}
}
// Wrap orig_inputs in a MakeTuple node and attach it to |graph| through a Depend
// edge, so the original inputs of a rewritten call/switch node stay referenced
// by the graph after the node itself has been replaced by label instructions.
// NOTE(review): orig_inputs is taken by value to match the class declaration;
// a const reference would avoid the copy if the declaration is updated as well.
void AscendControlParser::AttachOriginalInputsToGraph(NotNull<KernelGraphPtr> graph,
                                                      const std::vector<AnfNodePtr> orig_inputs) {
  std::vector<AnfNodePtr> make_tuple_inputs = {
    mindspore::NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()))};
  // Reserve up front (primitive node + one slot per input) to avoid reallocations
  // while appending below.
  make_tuple_inputs.reserve(orig_inputs.size() + 1);
  (void)std::copy(orig_inputs.begin(), orig_inputs.end(), std::back_inserter(make_tuple_inputs));
  auto make_tuple = graph->NewCNode(make_tuple_inputs);
  InsertDependToGraph(graph, NOT_NULL(make_tuple));
}
void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node,
const NotNull<std::set<KernelGraphPtr> *> memo) {
MS_LOG(INFO) << "Process call func " << cur_node->DebugString();
// 1 get kernel graph
const std::vector<AnfNodePtr> &origin_inputs = cur_node->inputs();
std::vector<AnfNodePtr> origin_inputs = cur_node->inputs();
if (kCNodeCallArg >= origin_inputs.size()) {
MS_LOG(EXCEPTION) << "Index out of range,size:" << origin_inputs.size();
}
@ -555,6 +565,8 @@ void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodeP
cur_node->set_inputs(new_inputs);
cur_node->set_abstract(nullptr);
AnfAlgo::SetNodeAttr(kAttrChildGraph, MakeValue<std::vector<KernelGraphPtr>>({call_kg}), cur_node.get());
origin_inputs.assign(origin_inputs.begin() + kCNodeCallArg + 1, origin_inputs.end());
AttachOriginalInputsToGraph(kg, origin_inputs);
MS_LOG(INFO) << "Succeed processing call func " << cur_node->DebugString();
}
@ -587,11 +599,13 @@ void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNod
for (size_t i = kCNodeSwitchCond + 1; i < kCNodeSwitchLength; ++i) {
// 3.1 branch kernel graph and args
KernelGraphPtr branch_fg;
std::tie(branch_fg, std::ignore) = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
std::vector<AnfNodePtr> origin_inputs;
std::tie(branch_fg, origin_inputs) = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
child_graphs.push_back(branch_fg);
// 3.2 recurse sub graph
CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo);
new_switch_inputs.push_back(branch_label);
AttachOriginalInputsToGraph(kg, origin_inputs);
}
std::swap(new_switch_inputs[kCNodeSwitchTrue], new_switch_inputs[kCNodeSwitchFalse]);
@ -635,11 +649,13 @@ void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull
for (size_t i = 0; i < branch_partial.size(); ++i) {
// 3.1 branch kernel graph and args
KernelGraphPtr branch_fg;
std::tie(branch_fg, std::ignore) = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
std::vector<AnfNodePtr> origin_inputs;
std::tie(branch_fg, origin_inputs) = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
child_graphs.push_back(branch_fg);
// 3.2 recurse sub graph
CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo);
new_switch_inputs.push_back(branch_label);
AttachOriginalInputsToGraph(kg, origin_inputs);
}
new_switch_inputs.insert(new_switch_inputs.end(), branch_partial.begin(), branch_partial.end());
cur_node->set_inputs(new_switch_inputs);

View File

@ -76,6 +76,7 @@ class AscendControlParser {
static bool CheckLabelIndex(uint32_t index, uint32_t label_index, const CNodePtr &cnode);
static std::vector<CNodePtr> RecurseGraph(NotNull<KernelGraphPtr> graph,
const NotNull<std::set<KernelGraphPtr> *> memo);
static void AttachOriginalInputsToGraph(NotNull<KernelGraphPtr> graph, const std::vector<AnfNodePtr> orig_inputs);
};
class AscendControlParser::ReferenceCounter {
public:

View File

@ -162,6 +162,19 @@ class KernelGraph : public FuncGraph {
void set_child_graph_result(const std::vector<AnfNodePtr> &child_graph_result) {
child_graph_result_ = child_graph_result;
}
// Record the MakeTuple node created for a tuple parameter so that later passes
// reuse it instead of creating duplicates. First insertion wins: an existing
// entry for |param| is left unchanged.
void InsertTupleParameterToMakeTupleMap(const AnfNodePtr &param, const AnfNodePtr &make_tuple) {
  // emplace is a no-op when the key already exists, which preserves the
  // first-insertion-wins semantics with a single hash lookup instead of
  // find() followed by operator[].
  (void)tuple_parameter_to_make_tuple_map_.emplace(param, make_tuple);
}
// Return the MakeTuple node previously registered for |param| via
// InsertTupleParameterToMakeTupleMap, or nullptr when no entry exists.
AnfNodePtr FindTupleParameterToMakeTupleMap(const AnfNodePtr &param) {
  // Single lookup instead of find() for membership followed by operator[]
  // (which would hash and search the map a second time).
  auto iter = tuple_parameter_to_make_tuple_map_.find(param);
  return iter == tuple_parameter_to_make_tuple_map_.end() ? nullptr : iter->second;
}
private:
// remove value node form graph
@ -229,6 +242,7 @@ class KernelGraph : public FuncGraph {
std::unordered_map<AnfNodePtr, std::unordered_map<int, std::pair<AnfNodePtr, bool>>> internal_outputs_to_front_map_;
std::unordered_map<AnfNodePtr, std::unordered_map<int, tensor::TensorPtr>> internal_outputs_tensor_map_;
uint32_t current_epoch_;
std::unordered_map<AnfNodePtr, AnfNodePtr> tuple_parameter_to_make_tuple_map_;
};
} // namespace session
using KernelGraphPtr = std::shared_ptr<session::KernelGraph>;