forked from mindspore-Ecosystem/mindspore
1.fix bug of backend common pass covert_tuple_output_to_maketuple
2.attach original inputs to graph when replace call and switch to labelgoto and labelswitch
This commit is contained in:
parent
270a7f2332
commit
bde9c0c6a9
|
@ -30,6 +30,10 @@ namespace {
|
|||
void ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr) {
|
||||
MS_EXCEPTION_IF_NULL(cnode_ptr);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
if (AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimCall) ||
|
||||
AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimPartial)) {
|
||||
return;
|
||||
}
|
||||
std::vector<AnfNodePtr> plant_inputs;
|
||||
std::vector<int> dyn_input_sizes;
|
||||
plant_inputs.push_back(AnfAlgo::GetCNodePrimitiveNode(cnode_ptr));
|
||||
|
|
|
@ -26,22 +26,19 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
AnfNodePtr ConvertTupleInputToMakeTuple(const FuncGraphPtr &graph, const AnfNodePtr &tuple_anf,
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> *transed_nodes) {
|
||||
AnfNodePtr ConvertTupleInputToMakeTuple(const FuncGraphPtr &graph, const AnfNodePtr &tuple_anf) {
|
||||
MS_EXCEPTION_IF_NULL(tuple_anf);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(transed_nodes);
|
||||
|
||||
if (!AnfAlgo::IsTupleOutput(tuple_anf)) {
|
||||
return tuple_anf;
|
||||
}
|
||||
auto transed_node_it = transed_nodes->find(tuple_anf);
|
||||
if (transed_node_it != transed_nodes->end()) {
|
||||
return transed_node_it->second;
|
||||
}
|
||||
auto kernel_graph = graph->cast<KernelGraphPtr>();
|
||||
if (kernel_graph->FindTupleParameterToMakeTupleMap(tuple_anf)) {
|
||||
return kernel_graph->FindTupleParameterToMakeTupleMap(tuple_anf);
|
||||
}
|
||||
auto make_tuple = kernel_graph->TransTupleToMakeTuple(tuple_anf);
|
||||
(*transed_nodes)[tuple_anf] = make_tuple;
|
||||
kernel_graph->InsertTupleParameterToMakeTupleMap(tuple_anf, make_tuple);
|
||||
// replace graph inputs if input is a parameter
|
||||
kernel_graph->ReplaceGraphInput(tuple_anf, make_tuple);
|
||||
return make_tuple;
|
||||
|
@ -61,7 +58,6 @@ const AnfNodePtr ConvertTupleOutputToMaketuple::Process(const FuncGraphPtr &func
|
|||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> transed_nodes;
|
||||
if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) {
|
||||
auto real_input = AnfAlgo::GetTupleGetItemRealInput(cnode);
|
||||
MS_EXCEPTION_IF_NULL(real_input);
|
||||
|
@ -77,7 +73,7 @@ const AnfNodePtr ConvertTupleOutputToMaketuple::Process(const FuncGraphPtr &func
|
|||
const auto &input = cnode->inputs()[i];
|
||||
if (input->Type() != nullptr && AnfAlgo::IsRealKernel(input) && AnfAlgo::IsTupleOutput(input) &&
|
||||
!AnfAlgo::CheckPrimitiveType(input, prim::kPrimCall)) {
|
||||
cnode->set_input(i, ConvertTupleInputToMakeTuple(func_graph, input, &transed_nodes));
|
||||
cnode->set_input(i, ConvertTupleInputToMakeTuple(func_graph, input));
|
||||
cnode_input_changed = true;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -523,12 +523,22 @@ void AscendControlParser::LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNod
|
|||
}
|
||||
}
|
||||
|
||||
void AscendControlParser::AttachOriginalInputsToGraph(NotNull<KernelGraphPtr> graph,
|
||||
const std::vector<AnfNodePtr> orig_inputs) {
|
||||
std::vector<AnfNodePtr> make_tuple_inputs = {
|
||||
mindspore::NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()))};
|
||||
std::copy(orig_inputs.begin(), orig_inputs.end(), std::back_inserter(make_tuple_inputs));
|
||||
auto make_tuple = graph->NewCNode(make_tuple_inputs);
|
||||
|
||||
InsertDependToGraph(graph, NOT_NULL(make_tuple));
|
||||
}
|
||||
|
||||
void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node,
|
||||
const NotNull<std::set<KernelGraphPtr> *> memo) {
|
||||
MS_LOG(INFO) << "Process call func " << cur_node->DebugString();
|
||||
|
||||
// 1 get kernel graph
|
||||
const std::vector<AnfNodePtr> &origin_inputs = cur_node->inputs();
|
||||
std::vector<AnfNodePtr> origin_inputs = cur_node->inputs();
|
||||
if (kCNodeCallArg >= origin_inputs.size()) {
|
||||
MS_LOG(EXCEPTION) << "Index out of range,size:" << origin_inputs.size();
|
||||
}
|
||||
|
@ -555,6 +565,8 @@ void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodeP
|
|||
cur_node->set_inputs(new_inputs);
|
||||
cur_node->set_abstract(nullptr);
|
||||
AnfAlgo::SetNodeAttr(kAttrChildGraph, MakeValue<std::vector<KernelGraphPtr>>({call_kg}), cur_node.get());
|
||||
origin_inputs.assign(origin_inputs.begin() + kCNodeCallArg + 1, origin_inputs.end());
|
||||
AttachOriginalInputsToGraph(kg, origin_inputs);
|
||||
MS_LOG(INFO) << "Succeed processing call func " << cur_node->DebugString();
|
||||
}
|
||||
|
||||
|
@ -587,11 +599,13 @@ void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNod
|
|||
for (size_t i = kCNodeSwitchCond + 1; i < kCNodeSwitchLength; ++i) {
|
||||
// 3.1 branch kernel graph and args
|
||||
KernelGraphPtr branch_fg;
|
||||
std::tie(branch_fg, std::ignore) = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
|
||||
std::vector<AnfNodePtr> origin_inputs;
|
||||
std::tie(branch_fg, origin_inputs) = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
|
||||
child_graphs.push_back(branch_fg);
|
||||
// 3.2 recurse sub graph
|
||||
CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo);
|
||||
new_switch_inputs.push_back(branch_label);
|
||||
AttachOriginalInputsToGraph(kg, origin_inputs);
|
||||
}
|
||||
std::swap(new_switch_inputs[kCNodeSwitchTrue], new_switch_inputs[kCNodeSwitchFalse]);
|
||||
|
||||
|
@ -635,11 +649,13 @@ void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull
|
|||
for (size_t i = 0; i < branch_partial.size(); ++i) {
|
||||
// 3.1 branch kernel graph and args
|
||||
KernelGraphPtr branch_fg;
|
||||
std::tie(branch_fg, std::ignore) = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
|
||||
std::vector<AnfNodePtr> origin_inputs;
|
||||
std::tie(branch_fg, origin_inputs) = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
|
||||
child_graphs.push_back(branch_fg);
|
||||
// 3.2 recurse sub graph
|
||||
CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo);
|
||||
new_switch_inputs.push_back(branch_label);
|
||||
AttachOriginalInputsToGraph(kg, origin_inputs);
|
||||
}
|
||||
new_switch_inputs.insert(new_switch_inputs.end(), branch_partial.begin(), branch_partial.end());
|
||||
cur_node->set_inputs(new_switch_inputs);
|
||||
|
|
|
@ -76,6 +76,7 @@ class AscendControlParser {
|
|||
static bool CheckLabelIndex(uint32_t index, uint32_t label_index, const CNodePtr &cnode);
|
||||
static std::vector<CNodePtr> RecurseGraph(NotNull<KernelGraphPtr> graph,
|
||||
const NotNull<std::set<KernelGraphPtr> *> memo);
|
||||
static void AttachOriginalInputsToGraph(NotNull<KernelGraphPtr> graph, const std::vector<AnfNodePtr> orig_inputs);
|
||||
};
|
||||
class AscendControlParser::ReferenceCounter {
|
||||
public:
|
||||
|
|
|
@ -162,6 +162,19 @@ class KernelGraph : public FuncGraph {
|
|||
void set_child_graph_result(const std::vector<AnfNodePtr> &child_graph_result) {
|
||||
child_graph_result_ = child_graph_result;
|
||||
}
|
||||
void InsertTupleParameterToMakeTupleMap(const AnfNodePtr ¶m, const AnfNodePtr &make_tuple) {
|
||||
if (tuple_parameter_to_make_tuple_map_.find(param) != tuple_parameter_to_make_tuple_map_.end()) {
|
||||
return;
|
||||
}
|
||||
tuple_parameter_to_make_tuple_map_[param] = make_tuple;
|
||||
}
|
||||
AnfNodePtr FindTupleParameterToMakeTupleMap(const AnfNodePtr ¶m) {
|
||||
if (tuple_parameter_to_make_tuple_map_.find(param) != tuple_parameter_to_make_tuple_map_.end()) {
|
||||
return tuple_parameter_to_make_tuple_map_[param];
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
// remove value node form graph
|
||||
|
@ -229,6 +242,7 @@ class KernelGraph : public FuncGraph {
|
|||
std::unordered_map<AnfNodePtr, std::unordered_map<int, std::pair<AnfNodePtr, bool>>> internal_outputs_to_front_map_;
|
||||
std::unordered_map<AnfNodePtr, std::unordered_map<int, tensor::TensorPtr>> internal_outputs_tensor_map_;
|
||||
uint32_t current_epoch_;
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> tuple_parameter_to_make_tuple_map_;
|
||||
};
|
||||
} // namespace session
|
||||
using KernelGraphPtr = std::shared_ptr<session::KernelGraph>;
|
||||
|
|
Loading…
Reference in New Issue