Support variable index of getitem during handling cell list getattr.

This commit is contained in:
Zhang Qinghua 2021-12-14 21:51:28 +08:00
parent 9558ba49d8
commit ea8c47e981
4 changed files with 82 additions and 48 deletions

View File

@ -36,45 +36,41 @@ AnfNodePtr Resolver::operator()(const OptimizerPtr &optimizer, const AnfNodePtr
constexpr auto recursive_level = 3; constexpr auto recursive_level = 3;
MS_LOG(DEBUG) << "getattr_operand_node: " << getattr_operand_node->DebugString(recursive_level); MS_LOG(DEBUG) << "getattr_operand_node: " << getattr_operand_node->DebugString(recursive_level);
// {prim::GetAttr, {{prim::Resolve, ..., 'getitem'}, {prim::Resolve, ...}, ...}} // {prim::GetAttr, {{prim::Resolve, ..., 'getitem'}, {prim::Resolve, ...}, index}, attr}
auto getitem_cnode = getattr_operand_node->cast<CNodePtr>(); auto getitem_cnode = getattr_operand_node->cast<CNodePtr>();
if (getitem_cnode != nullptr) { constexpr size_t getitem_inputs_size = 3;
if (getitem_cnode != nullptr && getitem_cnode->size() == getitem_inputs_size) {
constexpr size_t prim_index = 0; constexpr size_t prim_index = 0;
auto primitive_node = getitem_cnode->input(prim_index); auto resolve_getitem_node = getitem_cnode->input(prim_index);
auto resolved_getitem_node = primitive_node; constexpr size_t resolve_index = 1;
if (IsPrimitiveCNode(primitive_node, prim::kPrimResolve)) { auto resolve_node = getitem_cnode->input(resolve_index);
auto resolve_getitem_cnode = primitive_node->cast<CNodePtr>(); if (IsPrimitiveCNode(resolve_getitem_node, prim::kPrimResolve) &&
IsPrimitiveCNode(resolve_node, prim::kPrimResolve)) {
auto resolve_getitem_cnode = resolve_getitem_node->cast<CNodePtr>();
auto resolve_getitem_symbol = GetValueNode<parse::SymbolPtr>(resolve_getitem_cnode->input(2)); auto resolve_getitem_symbol = GetValueNode<parse::SymbolPtr>(resolve_getitem_cnode->input(2));
constexpr auto getitem_symbol = "getitem"; constexpr auto getitem_symbol = "getitem";
if (resolve_getitem_symbol->symbol() == getitem_symbol) { if (resolve_getitem_symbol->symbol() == getitem_symbol) {
auto resolve_getitem_name_space = GetValueNode<parse::NameSpacePtr>(resolve_getitem_cnode->input(1));
resolved_getitem_node =
ResolveSymbol(optimizer->manager(), resolve_getitem_name_space, resolve_getitem_symbol, node);
}
}
bool is_getattr_getitem = false;
auto do_signature = dyn_cast<prim::DoSignaturePrimitive>(GetValueNode(resolved_getitem_node));
if (do_signature != nullptr) {
auto &func_value = do_signature->function();
// The function 'func_value' must be the MultitypeFuncGraph of 'getitem'.
auto multitype_fg_value = dyn_cast<prim::MultitypeFuncGraph>(func_value);
constexpr auto getitem_symbol = "getitem";
if (multitype_fg_value != nullptr && multitype_fg_value->name() == getitem_symbol) {
is_getattr_getitem = true;
}
}
if (IsPrimitiveCNode(getattr_operand_node, prim::kPrimTupleGetItem)) {
is_getattr_getitem = true;
}
if (is_getattr_getitem) {
constexpr size_t resolve_index = 1;
auto resolve_node = getitem_cnode->input(resolve_index);
constexpr size_t position_index = 2; constexpr size_t position_index = 2;
auto index_node = getitem_cnode->input(position_index); auto index_node = getitem_cnode->input(position_index);
if (IsPrimitiveCNode(resolve_node, prim::kPrimResolve) && index_node->isa<ValueNode>()) {
auto [name_space, symbol] = parse::GetNamespaceAndSymbol(resolve_node); auto [name_space, symbol] = parse::GetNamespaceAndSymbol(resolve_node);
auto py_item = parse::GetItemObjectFromSequence(name_space, symbol, resolve_node, index_node); auto obj = parse::GetObjectFromSequence(name_space, symbol, resolve_node, index_node);
return parse::ResolveCellWithAttr(optimizer->manager(), py_item, resolve_node, attr); if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
std::vector<AnfNodePtr> inputs;
inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
auto sequence = obj.cast<py::sequence>();
for (size_t i = 0; i < sequence.size(); ++i) {
auto res = parse::ResolveCellWithAttr(optimizer->manager(), sequence[i], resolve_node, attr);
inputs.emplace_back(res);
}
auto make_tuple_node = getitem_cnode->func_graph()->NewCNodeInOrder(inputs);
auto resolve_getitem_name_space = GetValueNode<parse::NameSpacePtr>(resolve_getitem_cnode->input(1));
auto resolved_getitem_node =
ResolveSymbol(optimizer->manager(), resolve_getitem_name_space, resolve_getitem_symbol, node);
auto out =
getitem_cnode->func_graph()->NewCNodeInOrder({resolved_getitem_node, make_tuple_node, index_node});
return out;
}
return parse::ResolveCellWithAttr(optimizer->manager(), obj, resolve_node, attr);
} }
} }
} }

View File

@ -311,8 +311,8 @@ AnfNodePtr ResolveObjectAndAddToManager(const FuncGraphManagerPtr &manager, cons
} }
} // namespace } // namespace
// Get python object with index from a list. // Get python object with index from a list or the whole list if the index is not fixed.
py::object GetItemObjectFromSequence(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node, py::object GetObjectFromSequence(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node,
const AnfNodePtr &index_node) { const AnfNodePtr &index_node) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info())); TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info()));
@ -329,15 +329,18 @@ py::object GetItemObjectFromSequence(const NameSpacePtr &name_space, const Symbo
MS_LOG(EXCEPTION) << "Should not get item from non-sequence type, obj: " << py::str(obj); MS_LOG(EXCEPTION) << "Should not get item from non-sequence type, obj: " << py::str(obj);
} }
const std::string fn = PYTHON_MOD_GET_ITEM_FROM_SEQUENCE; MS_LOG(DEBUG) << "obj: " << py::str(obj) << ", index_node: " << index_node->ToString();
const std::string module = "mindspore._extends.parse.parser";
auto imm_value = GetValueNode<Int64ImmPtr>(index_node); auto imm_value = GetValueNode<Int64ImmPtr>(index_node);
if (imm_value == nullptr) { if (imm_value == nullptr) {
MS_LOG(EXCEPTION) << "Expect an int64 value node, node: " << node->DebugString() MS_LOG(DEBUG) << "The index is not a value node, so we return the whole list, node: " << node->DebugString()
<< ", index_node: " << index_node->DebugString(); << ", index_node: " << index_node->DebugString();
// Index is not fixed, return the whole list.
return obj;
} }
// It index is a value node, get the item of index directly.
const std::string fn = PYTHON_MOD_GET_ITEM_FROM_SEQUENCE;
const std::string module = "mindspore._extends.parse.parser";
int index = imm_value->value(); int index = imm_value->value();
MS_LOG(DEBUG) << "obj: " << py::str(obj) << ", index: " << index;
py::object item_obj = parse::python_adapter::GetPyFn(module, fn)(obj, py::int_(index)); py::object item_obj = parse::python_adapter::GetPyFn(module, fn)(obj, py::int_(index));
return item_obj; return item_obj;
} }

View File

@ -179,8 +179,8 @@ class SymbolResolver {
}; };
using SymbolResolverPtr = std::shared_ptr<SymbolResolver>; using SymbolResolverPtr = std::shared_ptr<SymbolResolver>;
// Get python object with index from a list. // Get python object with index from a list or the whole list if the index is not fixed.
py::object GetItemObjectFromSequence(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node, py::object GetObjectFromSequence(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node,
const AnfNodePtr &index_node); const AnfNodePtr &index_node);
std::pair<parse::NameSpacePtr, parse::SymbolPtr> GetNamespaceAndSymbol(const AnfNodePtr &node); std::pair<parse::NameSpacePtr, parse::SymbolPtr> GetNamespaceAndSymbol(const AnfNodePtr &node);

View File

@ -13,13 +13,12 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
""" test a list of cell, and getattr by its item """ """ test a list of cell, and getattr by its item """
import numpy as np
from mindspore import context, nn, dtype, Tensor from mindspore import context, nn, dtype, Tensor
from mindspore.ops import operations as P
class Actor(nn.Cell): class Actor(nn.Cell):
def __init__(self):
super(Actor, self).__init__()
def act(self, x, y): def act(self, x, y):
return x + y return x + y
@ -44,4 +43,40 @@ def test_list_item_getattr():
trainer = Trainer(actor_list) trainer = Trainer(actor_list)
x = Tensor([3], dtype=dtype.float32) x = Tensor([3], dtype=dtype.float32)
y = Tensor([6], dtype=dtype.float32) y = Tensor([6], dtype=dtype.float32)
print(trainer(x, y)) res = trainer(x, y)
print(f'res: {res}')
expect_res = Tensor([9], dtype=dtype.float32)
assert np.array_equal(res.asnumpy(), expect_res.asnumpy())
class Trainer2(nn.Cell):
def __init__(self, net_list):
super(Trainer2, self).__init__()
self.net_list = net_list
self.less = P.Less()
self.zero_float = Tensor(0, dtype=dtype.float32)
def construct(self, x, y):
sum_value = self.zero_float
num_actor = 0
while num_actor < 3:
sum_value += self.net_list[num_actor].act(x, y)
num_actor += 1
return sum_value
def test_list_item_getattr2():
"""
Feature: getattr by the item from list of cell with a Tensor variable.
Description: Support RL use method in graph mode.
Expectation: No exception.
"""
context.set_context(mode=context.GRAPH_MODE)
actor_list = [Actor(), Actor(), Actor()]
trainer = Trainer2(actor_list)
x = Tensor([3], dtype=dtype.float32)
y = Tensor([6], dtype=dtype.float32)
res = trainer(x, y)
print(f'res: {res}')
expect_res = Tensor([27], dtype=dtype.float32)
assert np.array_equal(res.asnumpy(), expect_res.asnumpy())