forked from mindspore-Ecosystem/mindspore
Support variable index of getitem during handling cell list getattr.
This commit is contained in:
parent
9558ba49d8
commit
ea8c47e981
|
@ -36,45 +36,41 @@ AnfNodePtr Resolver::operator()(const OptimizerPtr &optimizer, const AnfNodePtr
|
||||||
constexpr auto recursive_level = 3;
|
constexpr auto recursive_level = 3;
|
||||||
MS_LOG(DEBUG) << "getattr_operand_node: " << getattr_operand_node->DebugString(recursive_level);
|
MS_LOG(DEBUG) << "getattr_operand_node: " << getattr_operand_node->DebugString(recursive_level);
|
||||||
|
|
||||||
// {prim::GetAttr, {{prim::Resolve, ..., 'getitem'}, {prim::Resolve, ...}, ...}}
|
// {prim::GetAttr, {{prim::Resolve, ..., 'getitem'}, {prim::Resolve, ...}, index}, attr}
|
||||||
auto getitem_cnode = getattr_operand_node->cast<CNodePtr>();
|
auto getitem_cnode = getattr_operand_node->cast<CNodePtr>();
|
||||||
if (getitem_cnode != nullptr) {
|
constexpr size_t getitem_inputs_size = 3;
|
||||||
|
if (getitem_cnode != nullptr && getitem_cnode->size() == getitem_inputs_size) {
|
||||||
constexpr size_t prim_index = 0;
|
constexpr size_t prim_index = 0;
|
||||||
auto primitive_node = getitem_cnode->input(prim_index);
|
auto resolve_getitem_node = getitem_cnode->input(prim_index);
|
||||||
auto resolved_getitem_node = primitive_node;
|
constexpr size_t resolve_index = 1;
|
||||||
if (IsPrimitiveCNode(primitive_node, prim::kPrimResolve)) {
|
auto resolve_node = getitem_cnode->input(resolve_index);
|
||||||
auto resolve_getitem_cnode = primitive_node->cast<CNodePtr>();
|
if (IsPrimitiveCNode(resolve_getitem_node, prim::kPrimResolve) &&
|
||||||
|
IsPrimitiveCNode(resolve_node, prim::kPrimResolve)) {
|
||||||
|
auto resolve_getitem_cnode = resolve_getitem_node->cast<CNodePtr>();
|
||||||
auto resolve_getitem_symbol = GetValueNode<parse::SymbolPtr>(resolve_getitem_cnode->input(2));
|
auto resolve_getitem_symbol = GetValueNode<parse::SymbolPtr>(resolve_getitem_cnode->input(2));
|
||||||
constexpr auto getitem_symbol = "getitem";
|
constexpr auto getitem_symbol = "getitem";
|
||||||
if (resolve_getitem_symbol->symbol() == getitem_symbol) {
|
if (resolve_getitem_symbol->symbol() == getitem_symbol) {
|
||||||
auto resolve_getitem_name_space = GetValueNode<parse::NameSpacePtr>(resolve_getitem_cnode->input(1));
|
|
||||||
resolved_getitem_node =
|
|
||||||
ResolveSymbol(optimizer->manager(), resolve_getitem_name_space, resolve_getitem_symbol, node);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
bool is_getattr_getitem = false;
|
|
||||||
auto do_signature = dyn_cast<prim::DoSignaturePrimitive>(GetValueNode(resolved_getitem_node));
|
|
||||||
if (do_signature != nullptr) {
|
|
||||||
auto &func_value = do_signature->function();
|
|
||||||
// The function 'func_value' must be the MultitypeFuncGraph of 'getitem'.
|
|
||||||
auto multitype_fg_value = dyn_cast<prim::MultitypeFuncGraph>(func_value);
|
|
||||||
constexpr auto getitem_symbol = "getitem";
|
|
||||||
if (multitype_fg_value != nullptr && multitype_fg_value->name() == getitem_symbol) {
|
|
||||||
is_getattr_getitem = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (IsPrimitiveCNode(getattr_operand_node, prim::kPrimTupleGetItem)) {
|
|
||||||
is_getattr_getitem = true;
|
|
||||||
}
|
|
||||||
if (is_getattr_getitem) {
|
|
||||||
constexpr size_t resolve_index = 1;
|
|
||||||
auto resolve_node = getitem_cnode->input(resolve_index);
|
|
||||||
constexpr size_t position_index = 2;
|
constexpr size_t position_index = 2;
|
||||||
auto index_node = getitem_cnode->input(position_index);
|
auto index_node = getitem_cnode->input(position_index);
|
||||||
if (IsPrimitiveCNode(resolve_node, prim::kPrimResolve) && index_node->isa<ValueNode>()) {
|
|
||||||
auto [name_space, symbol] = parse::GetNamespaceAndSymbol(resolve_node);
|
auto [name_space, symbol] = parse::GetNamespaceAndSymbol(resolve_node);
|
||||||
auto py_item = parse::GetItemObjectFromSequence(name_space, symbol, resolve_node, index_node);
|
auto obj = parse::GetObjectFromSequence(name_space, symbol, resolve_node, index_node);
|
||||||
return parse::ResolveCellWithAttr(optimizer->manager(), py_item, resolve_node, attr);
|
if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
|
||||||
|
std::vector<AnfNodePtr> inputs;
|
||||||
|
inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
|
||||||
|
auto sequence = obj.cast<py::sequence>();
|
||||||
|
for (size_t i = 0; i < sequence.size(); ++i) {
|
||||||
|
auto res = parse::ResolveCellWithAttr(optimizer->manager(), sequence[i], resolve_node, attr);
|
||||||
|
inputs.emplace_back(res);
|
||||||
|
}
|
||||||
|
auto make_tuple_node = getitem_cnode->func_graph()->NewCNodeInOrder(inputs);
|
||||||
|
auto resolve_getitem_name_space = GetValueNode<parse::NameSpacePtr>(resolve_getitem_cnode->input(1));
|
||||||
|
auto resolved_getitem_node =
|
||||||
|
ResolveSymbol(optimizer->manager(), resolve_getitem_name_space, resolve_getitem_symbol, node);
|
||||||
|
auto out =
|
||||||
|
getitem_cnode->func_graph()->NewCNodeInOrder({resolved_getitem_node, make_tuple_node, index_node});
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
return parse::ResolveCellWithAttr(optimizer->manager(), obj, resolve_node, attr);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -311,8 +311,8 @@ AnfNodePtr ResolveObjectAndAddToManager(const FuncGraphManagerPtr &manager, cons
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// Get python object with index from a list.
|
// Get python object with index from a list or the whole list if the index is not fixed.
|
||||||
py::object GetItemObjectFromSequence(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node,
|
py::object GetObjectFromSequence(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node,
|
||||||
const AnfNodePtr &index_node) {
|
const AnfNodePtr &index_node) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info()));
|
TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info()));
|
||||||
|
@ -329,15 +329,18 @@ py::object GetItemObjectFromSequence(const NameSpacePtr &name_space, const Symbo
|
||||||
MS_LOG(EXCEPTION) << "Should not get item from non-sequence type, obj: " << py::str(obj);
|
MS_LOG(EXCEPTION) << "Should not get item from non-sequence type, obj: " << py::str(obj);
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::string fn = PYTHON_MOD_GET_ITEM_FROM_SEQUENCE;
|
MS_LOG(DEBUG) << "obj: " << py::str(obj) << ", index_node: " << index_node->ToString();
|
||||||
const std::string module = "mindspore._extends.parse.parser";
|
|
||||||
auto imm_value = GetValueNode<Int64ImmPtr>(index_node);
|
auto imm_value = GetValueNode<Int64ImmPtr>(index_node);
|
||||||
if (imm_value == nullptr) {
|
if (imm_value == nullptr) {
|
||||||
MS_LOG(EXCEPTION) << "Expect an int64 value node, node: " << node->DebugString()
|
MS_LOG(DEBUG) << "The index is not a value node, so we return the whole list, node: " << node->DebugString()
|
||||||
<< ", index_node: " << index_node->DebugString();
|
<< ", index_node: " << index_node->DebugString();
|
||||||
|
// Index is not fixed, return the whole list.
|
||||||
|
return obj;
|
||||||
}
|
}
|
||||||
|
// It index is a value node, get the item of index directly.
|
||||||
|
const std::string fn = PYTHON_MOD_GET_ITEM_FROM_SEQUENCE;
|
||||||
|
const std::string module = "mindspore._extends.parse.parser";
|
||||||
int index = imm_value->value();
|
int index = imm_value->value();
|
||||||
MS_LOG(DEBUG) << "obj: " << py::str(obj) << ", index: " << index;
|
|
||||||
py::object item_obj = parse::python_adapter::GetPyFn(module, fn)(obj, py::int_(index));
|
py::object item_obj = parse::python_adapter::GetPyFn(module, fn)(obj, py::int_(index));
|
||||||
return item_obj;
|
return item_obj;
|
||||||
}
|
}
|
||||||
|
|
|
@ -179,8 +179,8 @@ class SymbolResolver {
|
||||||
};
|
};
|
||||||
using SymbolResolverPtr = std::shared_ptr<SymbolResolver>;
|
using SymbolResolverPtr = std::shared_ptr<SymbolResolver>;
|
||||||
|
|
||||||
// Get python object with index from a list.
|
// Get python object with index from a list or the whole list if the index is not fixed.
|
||||||
py::object GetItemObjectFromSequence(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node,
|
py::object GetObjectFromSequence(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node,
|
||||||
const AnfNodePtr &index_node);
|
const AnfNodePtr &index_node);
|
||||||
std::pair<parse::NameSpacePtr, parse::SymbolPtr> GetNamespaceAndSymbol(const AnfNodePtr &node);
|
std::pair<parse::NameSpacePtr, parse::SymbolPtr> GetNamespaceAndSymbol(const AnfNodePtr &node);
|
||||||
|
|
||||||
|
|
|
@ -13,13 +13,12 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
""" test a list of cell, and getattr by its item """
|
""" test a list of cell, and getattr by its item """
|
||||||
|
import numpy as np
|
||||||
from mindspore import context, nn, dtype, Tensor
|
from mindspore import context, nn, dtype, Tensor
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
|
||||||
|
|
||||||
class Actor(nn.Cell):
|
class Actor(nn.Cell):
|
||||||
def __init__(self):
|
|
||||||
super(Actor, self).__init__()
|
|
||||||
|
|
||||||
def act(self, x, y):
|
def act(self, x, y):
|
||||||
return x + y
|
return x + y
|
||||||
|
|
||||||
|
@ -44,4 +43,40 @@ def test_list_item_getattr():
|
||||||
trainer = Trainer(actor_list)
|
trainer = Trainer(actor_list)
|
||||||
x = Tensor([3], dtype=dtype.float32)
|
x = Tensor([3], dtype=dtype.float32)
|
||||||
y = Tensor([6], dtype=dtype.float32)
|
y = Tensor([6], dtype=dtype.float32)
|
||||||
print(trainer(x, y))
|
res = trainer(x, y)
|
||||||
|
print(f'res: {res}')
|
||||||
|
expect_res = Tensor([9], dtype=dtype.float32)
|
||||||
|
assert np.array_equal(res.asnumpy(), expect_res.asnumpy())
|
||||||
|
|
||||||
|
|
||||||
|
class Trainer2(nn.Cell):
|
||||||
|
def __init__(self, net_list):
|
||||||
|
super(Trainer2, self).__init__()
|
||||||
|
self.net_list = net_list
|
||||||
|
self.less = P.Less()
|
||||||
|
self.zero_float = Tensor(0, dtype=dtype.float32)
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
sum_value = self.zero_float
|
||||||
|
num_actor = 0
|
||||||
|
while num_actor < 3:
|
||||||
|
sum_value += self.net_list[num_actor].act(x, y)
|
||||||
|
num_actor += 1
|
||||||
|
return sum_value
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_item_getattr2():
|
||||||
|
"""
|
||||||
|
Feature: getattr by the item from list of cell with a Tensor variable.
|
||||||
|
Description: Support RL use method in graph mode.
|
||||||
|
Expectation: No exception.
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
actor_list = [Actor(), Actor(), Actor()]
|
||||||
|
trainer = Trainer2(actor_list)
|
||||||
|
x = Tensor([3], dtype=dtype.float32)
|
||||||
|
y = Tensor([6], dtype=dtype.float32)
|
||||||
|
res = trainer(x, y)
|
||||||
|
print(f'res: {res}')
|
||||||
|
expect_res = Tensor([27], dtype=dtype.float32)
|
||||||
|
assert np.array_equal(res.asnumpy(), expect_res.asnumpy())
|
Loading…
Reference in New Issue