Support variable index of getitem during handling cell list getattr.

This commit is contained in:
Zhang Qinghua 2021-12-14 21:51:28 +08:00
parent 9558ba49d8
commit ea8c47e981
4 changed files with 82 additions and 48 deletions

View File

@ -36,45 +36,41 @@ AnfNodePtr Resolver::operator()(const OptimizerPtr &optimizer, const AnfNodePtr
constexpr auto recursive_level = 3;
MS_LOG(DEBUG) << "getattr_operand_node: " << getattr_operand_node->DebugString(recursive_level);
// {prim::GetAttr, {{prim::Resolve, ..., 'getitem'}, {prim::Resolve, ...}, ...}}
// {prim::GetAttr, {{prim::Resolve, ..., 'getitem'}, {prim::Resolve, ...}, index}, attr}
auto getitem_cnode = getattr_operand_node->cast<CNodePtr>();
if (getitem_cnode != nullptr) {
constexpr size_t getitem_inputs_size = 3;
if (getitem_cnode != nullptr && getitem_cnode->size() == getitem_inputs_size) {
constexpr size_t prim_index = 0;
auto primitive_node = getitem_cnode->input(prim_index);
auto resolved_getitem_node = primitive_node;
if (IsPrimitiveCNode(primitive_node, prim::kPrimResolve)) {
auto resolve_getitem_cnode = primitive_node->cast<CNodePtr>();
auto resolve_getitem_node = getitem_cnode->input(prim_index);
constexpr size_t resolve_index = 1;
auto resolve_node = getitem_cnode->input(resolve_index);
if (IsPrimitiveCNode(resolve_getitem_node, prim::kPrimResolve) &&
IsPrimitiveCNode(resolve_node, prim::kPrimResolve)) {
auto resolve_getitem_cnode = resolve_getitem_node->cast<CNodePtr>();
auto resolve_getitem_symbol = GetValueNode<parse::SymbolPtr>(resolve_getitem_cnode->input(2));
constexpr auto getitem_symbol = "getitem";
if (resolve_getitem_symbol->symbol() == getitem_symbol) {
auto resolve_getitem_name_space = GetValueNode<parse::NameSpacePtr>(resolve_getitem_cnode->input(1));
resolved_getitem_node =
ResolveSymbol(optimizer->manager(), resolve_getitem_name_space, resolve_getitem_symbol, node);
}
}
bool is_getattr_getitem = false;
auto do_signature = dyn_cast<prim::DoSignaturePrimitive>(GetValueNode(resolved_getitem_node));
if (do_signature != nullptr) {
auto &func_value = do_signature->function();
// The function 'func_value' must be the MultitypeFuncGraph of 'getitem'.
auto multitype_fg_value = dyn_cast<prim::MultitypeFuncGraph>(func_value);
constexpr auto getitem_symbol = "getitem";
if (multitype_fg_value != nullptr && multitype_fg_value->name() == getitem_symbol) {
is_getattr_getitem = true;
}
}
if (IsPrimitiveCNode(getattr_operand_node, prim::kPrimTupleGetItem)) {
is_getattr_getitem = true;
}
if (is_getattr_getitem) {
constexpr size_t resolve_index = 1;
auto resolve_node = getitem_cnode->input(resolve_index);
constexpr size_t position_index = 2;
auto index_node = getitem_cnode->input(position_index);
if (IsPrimitiveCNode(resolve_node, prim::kPrimResolve) && index_node->isa<ValueNode>()) {
auto [name_space, symbol] = parse::GetNamespaceAndSymbol(resolve_node);
auto py_item = parse::GetItemObjectFromSequence(name_space, symbol, resolve_node, index_node);
return parse::ResolveCellWithAttr(optimizer->manager(), py_item, resolve_node, attr);
auto obj = parse::GetObjectFromSequence(name_space, symbol, resolve_node, index_node);
if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
std::vector<AnfNodePtr> inputs;
inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
auto sequence = obj.cast<py::sequence>();
for (size_t i = 0; i < sequence.size(); ++i) {
auto res = parse::ResolveCellWithAttr(optimizer->manager(), sequence[i], resolve_node, attr);
inputs.emplace_back(res);
}
auto make_tuple_node = getitem_cnode->func_graph()->NewCNodeInOrder(inputs);
auto resolve_getitem_name_space = GetValueNode<parse::NameSpacePtr>(resolve_getitem_cnode->input(1));
auto resolved_getitem_node =
ResolveSymbol(optimizer->manager(), resolve_getitem_name_space, resolve_getitem_symbol, node);
auto out =
getitem_cnode->func_graph()->NewCNodeInOrder({resolved_getitem_node, make_tuple_node, index_node});
return out;
}
return parse::ResolveCellWithAttr(optimizer->manager(), obj, resolve_node, attr);
}
}
}

View File

@ -311,8 +311,8 @@ AnfNodePtr ResolveObjectAndAddToManager(const FuncGraphManagerPtr &manager, cons
}
} // namespace
// Get python object with index from a list.
py::object GetItemObjectFromSequence(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node,
// Get python object with index from a list or the whole list if the index is not fixed.
py::object GetObjectFromSequence(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node,
const AnfNodePtr &index_node) {
MS_EXCEPTION_IF_NULL(node);
TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info()));
@ -329,15 +329,18 @@ py::object GetItemObjectFromSequence(const NameSpacePtr &name_space, const Symbo
MS_LOG(EXCEPTION) << "Should not get item from non-sequence type, obj: " << py::str(obj);
}
const std::string fn = PYTHON_MOD_GET_ITEM_FROM_SEQUENCE;
const std::string module = "mindspore._extends.parse.parser";
MS_LOG(DEBUG) << "obj: " << py::str(obj) << ", index_node: " << index_node->ToString();
auto imm_value = GetValueNode<Int64ImmPtr>(index_node);
if (imm_value == nullptr) {
MS_LOG(EXCEPTION) << "Expect an int64 value node, node: " << node->DebugString()
MS_LOG(DEBUG) << "The index is not a value node, so we return the whole list, node: " << node->DebugString()
<< ", index_node: " << index_node->DebugString();
// Index is not fixed, return the whole list.
return obj;
}
// It index is a value node, get the item of index directly.
const std::string fn = PYTHON_MOD_GET_ITEM_FROM_SEQUENCE;
const std::string module = "mindspore._extends.parse.parser";
int index = imm_value->value();
MS_LOG(DEBUG) << "obj: " << py::str(obj) << ", index: " << index;
py::object item_obj = parse::python_adapter::GetPyFn(module, fn)(obj, py::int_(index));
return item_obj;
}

View File

@ -179,8 +179,8 @@ class SymbolResolver {
};
using SymbolResolverPtr = std::shared_ptr<SymbolResolver>;
// Get python object with index from a list.
py::object GetItemObjectFromSequence(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node,
// Get python object with index from a list or the whole list if the index is not fixed.
py::object GetObjectFromSequence(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node,
const AnfNodePtr &index_node);
std::pair<parse::NameSpacePtr, parse::SymbolPtr> GetNamespaceAndSymbol(const AnfNodePtr &node);

View File

@ -13,13 +13,12 @@
# limitations under the License.
# ============================================================================
""" test a list of cell, and getattr by its item """
import numpy as np
from mindspore import context, nn, dtype, Tensor
from mindspore.ops import operations as P
class Actor(nn.Cell):
def __init__(self):
super(Actor, self).__init__()
def act(self, x, y):
return x + y
@ -44,4 +43,40 @@ def test_list_item_getattr():
trainer = Trainer(actor_list)
x = Tensor([3], dtype=dtype.float32)
y = Tensor([6], dtype=dtype.float32)
print(trainer(x, y))
res = trainer(x, y)
print(f'res: {res}')
expect_res = Tensor([9], dtype=dtype.float32)
assert np.array_equal(res.asnumpy(), expect_res.asnumpy())
class Trainer2(nn.Cell):
def __init__(self, net_list):
super(Trainer2, self).__init__()
self.net_list = net_list
self.less = P.Less()
self.zero_float = Tensor(0, dtype=dtype.float32)
def construct(self, x, y):
sum_value = self.zero_float
num_actor = 0
while num_actor < 3:
sum_value += self.net_list[num_actor].act(x, y)
num_actor += 1
return sum_value
def test_list_item_getattr2():
"""
Feature: getattr by the item from list of cell with a Tensor variable.
Description: Support RL use method in graph mode.
Expectation: No exception.
"""
context.set_context(mode=context.GRAPH_MODE)
actor_list = [Actor(), Actor(), Actor()]
trainer = Trainer2(actor_list)
x = Tensor([3], dtype=dtype.float32)
y = Tensor([6], dtype=dtype.float32)
res = trainer(x, y)
print(f'res: {res}')
expect_res = Tensor([27], dtype=dtype.float32)
assert np.array_equal(res.asnumpy(), expect_res.asnumpy())