forked from mindspore-Ecosystem/mindspore
extend inputs of update if which have multiple outputs
This commit is contained in:
parent
a2e9aadfbf
commit
62cc2990a6
|
@ -50,6 +50,32 @@ AnfNodePtrList SpreadTuples(const AnfNodePtrList &nodes, size_t begin_index) {
|
|||
return result;
|
||||
}
|
||||
|
||||
AnfNodePtrList SpreadUpdateState::ExtendInputsOfUpdate(const AnfNodePtrList &nodes, const FuncGraphPtr &func_graph) {
|
||||
AnfNodePtrList result;
|
||||
for (auto node : nodes) {
|
||||
if (node->abstract()->isa<abstract::AbstractTuple>()) {
|
||||
auto node_abstract = node->abstract()->cast<abstract::AbstractTuplePtr>()->elements();
|
||||
auto num = node_abstract.size();
|
||||
for (size_t i = 0; i < num; i++) {
|
||||
auto idx_val = SizeToLong(i);
|
||||
|
||||
auto idx = NewValueNode(idx_val);
|
||||
MS_EXCEPTION_IF_NULL(idx);
|
||||
idx->set_abstract(std::make_shared<abstract::AbstractScalar>(idx_val));
|
||||
|
||||
auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
|
||||
MS_EXCEPTION_IF_NULL(tuple_getitem);
|
||||
tuple_getitem->set_fullname_with_scope(node->fullname_with_scope() + "_TupleGetItem_" + std::to_string(i));
|
||||
tuple_getitem->set_abstract(node_abstract[i]);
|
||||
tuple_getitem->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||
result.push_back(tuple_getitem);
|
||||
}
|
||||
} else {
|
||||
result.push_back(node);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
bool SpreadUpdateState::Run(const FuncGraphPtr &func_graph) {
|
||||
auto todos = GetUpdateStateList(func_graph);
|
||||
bool changed = false;
|
||||
|
@ -58,6 +84,8 @@ bool SpreadUpdateState::Run(const FuncGraphPtr &func_graph) {
|
|||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (cnode->size() <= kUpdateStateRealInput) continue;
|
||||
auto inputs = SpreadTuples(cnode->inputs(), kUpdateStateRealInput);
|
||||
// extend inputs of update if which have multiple outputs
|
||||
inputs = ExtendInputsOfUpdate(inputs, func_graph);
|
||||
if (inputs.size() + 2 != cnode->size() || inputs[0] != cnode->input(2)) {
|
||||
AnfNodePtrList node_inputs = {cnode->input(0), cnode->input(1)};
|
||||
node_inputs.insert(node_inputs.end(), inputs.begin(), inputs.end());
|
||||
|
@ -81,7 +109,7 @@ bool ShrinkUpdateState::Run(const FuncGraphPtr &func_graph) {
|
|||
for (auto node : todos) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (cnode->size() <= kUpdateStateRealInput) continue;
|
||||
if (cnode->size() <= kUpdateStateRealInput + 1) continue;
|
||||
AnfNodePtrList mt_inputs = SpreadTuples(cnode->inputs(), kUpdateStateRealInput);
|
||||
AbstractBasePtrList abs_list;
|
||||
std::transform(mt_inputs.begin(), mt_inputs.end(), std::back_inserter(abs_list),
|
||||
|
|
|
@ -39,6 +39,7 @@ class SpreadUpdateState : public Pass {
|
|||
public:
|
||||
SpreadUpdateState() : Pass("spread_update_state") {}
|
||||
~SpreadUpdateState() override = default;
|
||||
AnfNodePtrList ExtendInputsOfUpdate(const AnfNodePtrList &nodes, const FuncGraphPtr &func_graph);
|
||||
bool Run(const FuncGraphPtr &func_graph) override;
|
||||
};
|
||||
|
||||
|
|
Loading…
Reference in New Issue