forked from mindspore-Ecosystem/mindspore
Support variable index of getitem during handling cell list getattr.
This commit is contained in:
parent
9558ba49d8
commit
ea8c47e981
|
@ -36,45 +36,41 @@ AnfNodePtr Resolver::operator()(const OptimizerPtr &optimizer, const AnfNodePtr
|
|||
constexpr auto recursive_level = 3;
|
||||
MS_LOG(DEBUG) << "getattr_operand_node: " << getattr_operand_node->DebugString(recursive_level);
|
||||
|
||||
// {prim::GetAttr, {{prim::Resolve, ..., 'getitem'}, {prim::Resolve, ...}, ...}}
|
||||
// {prim::GetAttr, {{prim::Resolve, ..., 'getitem'}, {prim::Resolve, ...}, index}, attr}
|
||||
auto getitem_cnode = getattr_operand_node->cast<CNodePtr>();
|
||||
if (getitem_cnode != nullptr) {
|
||||
constexpr size_t getitem_inputs_size = 3;
|
||||
if (getitem_cnode != nullptr && getitem_cnode->size() == getitem_inputs_size) {
|
||||
constexpr size_t prim_index = 0;
|
||||
auto primitive_node = getitem_cnode->input(prim_index);
|
||||
auto resolved_getitem_node = primitive_node;
|
||||
if (IsPrimitiveCNode(primitive_node, prim::kPrimResolve)) {
|
||||
auto resolve_getitem_cnode = primitive_node->cast<CNodePtr>();
|
||||
auto resolve_getitem_node = getitem_cnode->input(prim_index);
|
||||
constexpr size_t resolve_index = 1;
|
||||
auto resolve_node = getitem_cnode->input(resolve_index);
|
||||
if (IsPrimitiveCNode(resolve_getitem_node, prim::kPrimResolve) &&
|
||||
IsPrimitiveCNode(resolve_node, prim::kPrimResolve)) {
|
||||
auto resolve_getitem_cnode = resolve_getitem_node->cast<CNodePtr>();
|
||||
auto resolve_getitem_symbol = GetValueNode<parse::SymbolPtr>(resolve_getitem_cnode->input(2));
|
||||
constexpr auto getitem_symbol = "getitem";
|
||||
if (resolve_getitem_symbol->symbol() == getitem_symbol) {
|
||||
auto resolve_getitem_name_space = GetValueNode<parse::NameSpacePtr>(resolve_getitem_cnode->input(1));
|
||||
resolved_getitem_node =
|
||||
ResolveSymbol(optimizer->manager(), resolve_getitem_name_space, resolve_getitem_symbol, node);
|
||||
}
|
||||
}
|
||||
bool is_getattr_getitem = false;
|
||||
auto do_signature = dyn_cast<prim::DoSignaturePrimitive>(GetValueNode(resolved_getitem_node));
|
||||
if (do_signature != nullptr) {
|
||||
auto &func_value = do_signature->function();
|
||||
// The function 'func_value' must be the MultitypeFuncGraph of 'getitem'.
|
||||
auto multitype_fg_value = dyn_cast<prim::MultitypeFuncGraph>(func_value);
|
||||
constexpr auto getitem_symbol = "getitem";
|
||||
if (multitype_fg_value != nullptr && multitype_fg_value->name() == getitem_symbol) {
|
||||
is_getattr_getitem = true;
|
||||
}
|
||||
}
|
||||
if (IsPrimitiveCNode(getattr_operand_node, prim::kPrimTupleGetItem)) {
|
||||
is_getattr_getitem = true;
|
||||
}
|
||||
if (is_getattr_getitem) {
|
||||
constexpr size_t resolve_index = 1;
|
||||
auto resolve_node = getitem_cnode->input(resolve_index);
|
||||
constexpr size_t position_index = 2;
|
||||
auto index_node = getitem_cnode->input(position_index);
|
||||
if (IsPrimitiveCNode(resolve_node, prim::kPrimResolve) && index_node->isa<ValueNode>()) {
|
||||
constexpr size_t position_index = 2;
|
||||
auto index_node = getitem_cnode->input(position_index);
|
||||
auto [name_space, symbol] = parse::GetNamespaceAndSymbol(resolve_node);
|
||||
auto py_item = parse::GetItemObjectFromSequence(name_space, symbol, resolve_node, index_node);
|
||||
return parse::ResolveCellWithAttr(optimizer->manager(), py_item, resolve_node, attr);
|
||||
auto obj = parse::GetObjectFromSequence(name_space, symbol, resolve_node, index_node);
|
||||
if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
auto sequence = obj.cast<py::sequence>();
|
||||
for (size_t i = 0; i < sequence.size(); ++i) {
|
||||
auto res = parse::ResolveCellWithAttr(optimizer->manager(), sequence[i], resolve_node, attr);
|
||||
inputs.emplace_back(res);
|
||||
}
|
||||
auto make_tuple_node = getitem_cnode->func_graph()->NewCNodeInOrder(inputs);
|
||||
auto resolve_getitem_name_space = GetValueNode<parse::NameSpacePtr>(resolve_getitem_cnode->input(1));
|
||||
auto resolved_getitem_node =
|
||||
ResolveSymbol(optimizer->manager(), resolve_getitem_name_space, resolve_getitem_symbol, node);
|
||||
auto out =
|
||||
getitem_cnode->func_graph()->NewCNodeInOrder({resolved_getitem_node, make_tuple_node, index_node});
|
||||
return out;
|
||||
}
|
||||
return parse::ResolveCellWithAttr(optimizer->manager(), obj, resolve_node, attr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -311,9 +311,9 @@ AnfNodePtr ResolveObjectAndAddToManager(const FuncGraphManagerPtr &manager, cons
|
|||
}
|
||||
} // namespace
|
||||
|
||||
// Get python object with index from a list.
|
||||
py::object GetItemObjectFromSequence(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node,
|
||||
const AnfNodePtr &index_node) {
|
||||
// Get python object with index from a list or the whole list if the index is not fixed.
|
||||
py::object GetObjectFromSequence(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node,
|
||||
const AnfNodePtr &index_node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info()));
|
||||
if (node->func_graph() == nullptr) {
|
||||
|
@ -329,15 +329,18 @@ py::object GetItemObjectFromSequence(const NameSpacePtr &name_space, const Symbo
|
|||
MS_LOG(EXCEPTION) << "Should not get item from non-sequence type, obj: " << py::str(obj);
|
||||
}
|
||||
|
||||
const std::string fn = PYTHON_MOD_GET_ITEM_FROM_SEQUENCE;
|
||||
const std::string module = "mindspore._extends.parse.parser";
|
||||
MS_LOG(DEBUG) << "obj: " << py::str(obj) << ", index_node: " << index_node->ToString();
|
||||
auto imm_value = GetValueNode<Int64ImmPtr>(index_node);
|
||||
if (imm_value == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Expect an int64 value node, node: " << node->DebugString()
|
||||
<< ", index_node: " << index_node->DebugString();
|
||||
MS_LOG(DEBUG) << "The index is not a value node, so we return the whole list, node: " << node->DebugString()
|
||||
<< ", index_node: " << index_node->DebugString();
|
||||
// Index is not fixed, return the whole list.
|
||||
return obj;
|
||||
}
|
||||
// It index is a value node, get the item of index directly.
|
||||
const std::string fn = PYTHON_MOD_GET_ITEM_FROM_SEQUENCE;
|
||||
const std::string module = "mindspore._extends.parse.parser";
|
||||
int index = imm_value->value();
|
||||
MS_LOG(DEBUG) << "obj: " << py::str(obj) << ", index: " << index;
|
||||
py::object item_obj = parse::python_adapter::GetPyFn(module, fn)(obj, py::int_(index));
|
||||
return item_obj;
|
||||
}
|
||||
|
|
|
@ -179,9 +179,9 @@ class SymbolResolver {
|
|||
};
|
||||
using SymbolResolverPtr = std::shared_ptr<SymbolResolver>;
|
||||
|
||||
// Get python object with index from a list.
|
||||
py::object GetItemObjectFromSequence(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node,
|
||||
const AnfNodePtr &index_node);
|
||||
// Get python object with index from a list or the whole list if the index is not fixed.
|
||||
py::object GetObjectFromSequence(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node,
|
||||
const AnfNodePtr &index_node);
|
||||
std::pair<parse::NameSpacePtr, parse::SymbolPtr> GetNamespaceAndSymbol(const AnfNodePtr &node);
|
||||
|
||||
// Get resolved python object by namespace and symbol.
|
||||
|
|
|
@ -13,13 +13,12 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test a list of cell, and getattr by its item """
|
||||
import numpy as np
|
||||
from mindspore import context, nn, dtype, Tensor
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class Actor(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Actor, self).__init__()
|
||||
|
||||
def act(self, x, y):
|
||||
return x + y
|
||||
|
||||
|
@ -44,4 +43,40 @@ def test_list_item_getattr():
|
|||
trainer = Trainer(actor_list)
|
||||
x = Tensor([3], dtype=dtype.float32)
|
||||
y = Tensor([6], dtype=dtype.float32)
|
||||
print(trainer(x, y))
|
||||
res = trainer(x, y)
|
||||
print(f'res: {res}')
|
||||
expect_res = Tensor([9], dtype=dtype.float32)
|
||||
assert np.array_equal(res.asnumpy(), expect_res.asnumpy())
|
||||
|
||||
|
||||
class Trainer2(nn.Cell):
|
||||
def __init__(self, net_list):
|
||||
super(Trainer2, self).__init__()
|
||||
self.net_list = net_list
|
||||
self.less = P.Less()
|
||||
self.zero_float = Tensor(0, dtype=dtype.float32)
|
||||
|
||||
def construct(self, x, y):
|
||||
sum_value = self.zero_float
|
||||
num_actor = 0
|
||||
while num_actor < 3:
|
||||
sum_value += self.net_list[num_actor].act(x, y)
|
||||
num_actor += 1
|
||||
return sum_value
|
||||
|
||||
|
||||
def test_list_item_getattr2():
|
||||
"""
|
||||
Feature: getattr by the item from list of cell with a Tensor variable.
|
||||
Description: Support RL use method in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
actor_list = [Actor(), Actor(), Actor()]
|
||||
trainer = Trainer2(actor_list)
|
||||
x = Tensor([3], dtype=dtype.float32)
|
||||
y = Tensor([6], dtype=dtype.float32)
|
||||
res = trainer(x, y)
|
||||
print(f'res: {res}')
|
||||
expect_res = Tensor([27], dtype=dtype.float32)
|
||||
assert np.array_equal(res.asnumpy(), expect_res.asnumpy())
|
Loading…
Reference in New Issue