extend inputs of update if which have multiple outputs

This commit is contained in:
wenfangpei 2021-04-30 11:03:23 +08:00
parent a2e9aadfbf
commit 62cc2990a6
2 changed files with 30 additions and 1 deletions

View File

@ -50,6 +50,32 @@ AnfNodePtrList SpreadTuples(const AnfNodePtrList &nodes, size_t begin_index) {
return result;
}
AnfNodePtrList SpreadUpdateState::ExtendInputsOfUpdate(const AnfNodePtrList &nodes, const FuncGraphPtr &func_graph) {
AnfNodePtrList result;
for (auto node : nodes) {
if (node->abstract()->isa<abstract::AbstractTuple>()) {
auto node_abstract = node->abstract()->cast<abstract::AbstractTuplePtr>()->elements();
auto num = node_abstract.size();
for (size_t i = 0; i < num; i++) {
auto idx_val = SizeToLong(i);
auto idx = NewValueNode(idx_val);
MS_EXCEPTION_IF_NULL(idx);
idx->set_abstract(std::make_shared<abstract::AbstractScalar>(idx_val));
auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
MS_EXCEPTION_IF_NULL(tuple_getitem);
tuple_getitem->set_fullname_with_scope(node->fullname_with_scope() + "_TupleGetItem_" + std::to_string(i));
tuple_getitem->set_abstract(node_abstract[i]);
tuple_getitem->set_kernel_info(std::make_shared<device::KernelInfo>());
result.push_back(tuple_getitem);
}
} else {
result.push_back(node);
}
}
return result;
}
bool SpreadUpdateState::Run(const FuncGraphPtr &func_graph) {
auto todos = GetUpdateStateList(func_graph);
bool changed = false;
@ -58,6 +84,8 @@ bool SpreadUpdateState::Run(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->size() <= kUpdateStateRealInput) continue;
auto inputs = SpreadTuples(cnode->inputs(), kUpdateStateRealInput);
// extend inputs of update if which have multiple outputs
inputs = ExtendInputsOfUpdate(inputs, func_graph);
if (inputs.size() + 2 != cnode->size() || inputs[0] != cnode->input(2)) {
AnfNodePtrList node_inputs = {cnode->input(0), cnode->input(1)};
node_inputs.insert(node_inputs.end(), inputs.begin(), inputs.end());
@ -81,7 +109,7 @@ bool ShrinkUpdateState::Run(const FuncGraphPtr &func_graph) {
for (auto node : todos) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->size() <= kUpdateStateRealInput) continue;
if (cnode->size() <= kUpdateStateRealInput + 1) continue;
AnfNodePtrList mt_inputs = SpreadTuples(cnode->inputs(), kUpdateStateRealInput);
AbstractBasePtrList abs_list;
std::transform(mt_inputs.begin(), mt_inputs.end(), std::back_inserter(abs_list),

View File

@ -39,6 +39,7 @@ class SpreadUpdateState : public Pass {
public:
SpreadUpdateState() : Pass("spread_update_state") {}
~SpreadUpdateState() override = default;
AnfNodePtrList ExtendInputsOfUpdate(const AnfNodePtrList &nodes, const FuncGraphPtr &func_graph);
bool Run(const FuncGraphPtr &func_graph) override;
};