diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/symbol_resolver.cc b/mindspore/ccsrc/frontend/optimizer/irpass/symbol_resolver.cc index ef36766069f..fc7efbd1add 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/symbol_resolver.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass/symbol_resolver.cc @@ -36,45 +36,41 @@ AnfNodePtr Resolver::operator()(const OptimizerPtr &optimizer, const AnfNodePtr constexpr auto recursive_level = 3; MS_LOG(DEBUG) << "getattr_operand_node: " << getattr_operand_node->DebugString(recursive_level); - // {prim::GetAttr, {{prim::Resolve, ..., 'getitem'}, {prim::Resolve, ...}, ...}} + // {prim::GetAttr, {{prim::Resolve, ..., 'getitem'}, {prim::Resolve, ...}, index}, attr} auto getitem_cnode = getattr_operand_node->cast(); - if (getitem_cnode != nullptr) { + constexpr size_t getitem_inputs_size = 3; + if (getitem_cnode != nullptr && getitem_cnode->size() == getitem_inputs_size) { constexpr size_t prim_index = 0; - auto primitive_node = getitem_cnode->input(prim_index); - auto resolved_getitem_node = primitive_node; - if (IsPrimitiveCNode(primitive_node, prim::kPrimResolve)) { - auto resolve_getitem_cnode = primitive_node->cast(); + auto resolve_getitem_node = getitem_cnode->input(prim_index); + constexpr size_t resolve_index = 1; + auto resolve_node = getitem_cnode->input(resolve_index); + if (IsPrimitiveCNode(resolve_getitem_node, prim::kPrimResolve) && + IsPrimitiveCNode(resolve_node, prim::kPrimResolve)) { + auto resolve_getitem_cnode = resolve_getitem_node->cast(); auto resolve_getitem_symbol = GetValueNode(resolve_getitem_cnode->input(2)); constexpr auto getitem_symbol = "getitem"; if (resolve_getitem_symbol->symbol() == getitem_symbol) { - auto resolve_getitem_name_space = GetValueNode(resolve_getitem_cnode->input(1)); - resolved_getitem_node = - ResolveSymbol(optimizer->manager(), resolve_getitem_name_space, resolve_getitem_symbol, node); - } - } - bool is_getattr_getitem = false; - auto do_signature = dyn_cast(GetValueNode(resolved_getitem_node)); - if (do_signature != nullptr) { - auto &func_value = do_signature->function(); - // The function 'func_value' must be the MultitypeFuncGraph of 'getitem'. - auto multitype_fg_value = dyn_cast(func_value); - constexpr auto getitem_symbol = "getitem"; - if (multitype_fg_value != nullptr && multitype_fg_value->name() == getitem_symbol) { - is_getattr_getitem = true; - } - } - if (IsPrimitiveCNode(getattr_operand_node, prim::kPrimTupleGetItem)) { - is_getattr_getitem = true; - } - if (is_getattr_getitem) { - constexpr size_t resolve_index = 1; - auto resolve_node = getitem_cnode->input(resolve_index); - constexpr size_t position_index = 2; - auto index_node = getitem_cnode->input(position_index); - if (IsPrimitiveCNode(resolve_node, prim::kPrimResolve) && index_node->isa()) { + constexpr size_t position_index = 2; + auto index_node = getitem_cnode->input(position_index); auto [name_space, symbol] = parse::GetNamespaceAndSymbol(resolve_node); - auto py_item = parse::GetItemObjectFromSequence(name_space, symbol, resolve_node, index_node); - return parse::ResolveCellWithAttr(optimizer->manager(), py_item, resolve_node, attr); + auto obj = parse::GetObjectFromSequence(name_space, symbol, resolve_node, index_node); + if (py::isinstance(obj) || py::isinstance(obj)) { + std::vector inputs; + inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); + auto sequence = obj.cast(); + for (size_t i = 0; i < sequence.size(); ++i) { + auto res = parse::ResolveCellWithAttr(optimizer->manager(), sequence[i], resolve_node, attr); + inputs.emplace_back(res); + } + auto make_tuple_node = getitem_cnode->func_graph()->NewCNodeInOrder(inputs); + auto resolve_getitem_name_space = GetValueNode(resolve_getitem_cnode->input(1)); + auto resolved_getitem_node = + ResolveSymbol(optimizer->manager(), resolve_getitem_name_space, resolve_getitem_symbol, node); + auto out = + getitem_cnode->func_graph()->NewCNodeInOrder({resolved_getitem_node, make_tuple_node, index_node}); + return out; + } + return parse::ResolveCellWithAttr(optimizer->manager(), obj, resolve_node, attr); } } } diff --git a/mindspore/ccsrc/pipeline/jit/parse/resolve.cc b/mindspore/ccsrc/pipeline/jit/parse/resolve.cc index 70a66ecc5af..0376015b25c 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/resolve.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/resolve.cc @@ -311,9 +311,9 @@ AnfNodePtr ResolveObjectAndAddToManager(const FuncGraphManagerPtr &manager, cons } } // namespace -// Get python object with index from a list. -py::object GetItemObjectFromSequence(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node, - const AnfNodePtr &index_node) { +// Get python object with index from a list or the whole list if the index is not fixed. +py::object GetObjectFromSequence(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node, + const AnfNodePtr &index_node) { MS_EXCEPTION_IF_NULL(node); TraceGuard trace_guard(std::make_shared(node->debug_info())); if (node->func_graph() == nullptr) { @@ -329,15 +329,18 @@ py::object GetItemObjectFromSequence(const NameSpacePtr &name_space, const Symbo MS_LOG(EXCEPTION) << "Should not get item from non-sequence type, obj: " << py::str(obj); } - const std::string fn = PYTHON_MOD_GET_ITEM_FROM_SEQUENCE; - const std::string module = "mindspore._extends.parse.parser"; + MS_LOG(DEBUG) << "obj: " << py::str(obj) << ", index_node: " << index_node->ToString(); auto imm_value = GetValueNode(index_node); if (imm_value == nullptr) { - MS_LOG(EXCEPTION) << "Expect an int64 value node, node: " << node->DebugString() - << ", index_node: " << index_node->DebugString(); + MS_LOG(DEBUG) << "The index is not a value node, so we return the whole list, node: " << node->DebugString() + << ", index_node: " << index_node->DebugString(); + // Index is not fixed, return the whole list. + return obj; } + // It index is a value node, get the item of index directly. + const std::string fn = PYTHON_MOD_GET_ITEM_FROM_SEQUENCE; + const std::string module = "mindspore._extends.parse.parser"; int index = imm_value->value(); - MS_LOG(DEBUG) << "obj: " << py::str(obj) << ", index: " << index; py::object item_obj = parse::python_adapter::GetPyFn(module, fn)(obj, py::int_(index)); return item_obj; } diff --git a/mindspore/ccsrc/pipeline/jit/parse/resolve.h b/mindspore/ccsrc/pipeline/jit/parse/resolve.h index 1e7f4c621ae..c5baee0c38b 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/resolve.h +++ b/mindspore/ccsrc/pipeline/jit/parse/resolve.h @@ -179,9 +179,9 @@ class SymbolResolver { }; using SymbolResolverPtr = std::shared_ptr; -// Get python object with index from a list. -py::object GetItemObjectFromSequence(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node, - const AnfNodePtr &index_node); +// Get python object with index from a list or the whole list if the index is not fixed. +py::object GetObjectFromSequence(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node, + const AnfNodePtr &index_node); std::pair GetNamespaceAndSymbol(const AnfNodePtr &node); // Get resolved python object by namespace and symbol. diff --git a/tests/ut/python/pipeline/parse/test_cell_list_getattr.py b/tests/st/rl/test_cell_list_getattr.py similarity index 53% rename from tests/ut/python/pipeline/parse/test_cell_list_getattr.py rename to tests/st/rl/test_cell_list_getattr.py index 9fd4d31af48..ddd9daefc8f 100644 --- a/tests/ut/python/pipeline/parse/test_cell_list_getattr.py +++ b/tests/st/rl/test_cell_list_getattr.py @@ -13,13 +13,12 @@ # limitations under the License. # ============================================================================ """ test a list of cell, and getattr by its item """ +import numpy as np from mindspore import context, nn, dtype, Tensor +from mindspore.ops import operations as P class Actor(nn.Cell): - def __init__(self): - super(Actor, self).__init__() - def act(self, x, y): return x + y @@ -44,4 +43,40 @@ def test_list_item_getattr(): trainer = Trainer(actor_list) x = Tensor([3], dtype=dtype.float32) y = Tensor([6], dtype=dtype.float32) - print(trainer(x, y)) + res = trainer(x, y) + print(f'res: {res}') + expect_res = Tensor([9], dtype=dtype.float32) + assert np.array_equal(res.asnumpy(), expect_res.asnumpy()) + + +class Trainer2(nn.Cell): + def __init__(self, net_list): + super(Trainer2, self).__init__() + self.net_list = net_list + self.less = P.Less() + self.zero_float = Tensor(0, dtype=dtype.float32) + + def construct(self, x, y): + sum_value = self.zero_float + num_actor = 0 + while num_actor < 3: + sum_value += self.net_list[num_actor].act(x, y) + num_actor += 1 + return sum_value + + +def test_list_item_getattr2(): + """ + Feature: getattr by the item from list of cell with a Tensor variable. + Description: Support RL use method in graph mode. + Expectation: No exception. + """ + context.set_context(mode=context.GRAPH_MODE) + actor_list = [Actor(), Actor(), Actor()] + trainer = Trainer2(actor_list) + x = Tensor([3], dtype=dtype.float32) + y = Tensor([6], dtype=dtype.float32) + res = trainer(x, y) + print(f'res: {res}') + expect_res = Tensor([27], dtype=dtype.float32) + assert np.array_equal(res.asnumpy(), expect_res.asnumpy())