!40842 Revert getattr_getitem pass

Merge pull request !40842 from huangbingjian/revert_pass
This commit is contained in:
i-robot 2022-08-25 03:26:18 +00:00 committed by Gitee
commit b66f06efd6
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
9 changed files with 181 additions and 15 deletions

View File

@ -30,6 +30,11 @@ AnfNodePtr Resolver::operator()(const OptimizerPtr &optimizer, const AnfNodePtr
auto object_node = object.GetNode(node);
auto attr_node = attr.GetNode(node);
// {prim::kPrimGetAttr, {getitem, {prim::kPrimResolve, namespace, symbol}, index}, attr}
// {prim::kPrimGetAttr, {getitem, {prim::kPrimGetAttr, ResolveNode, member}, index}, attr}
if (parse::IsGetItemCNode(object_node)) {
return parse::ResolveGetItemWithAttr(optimizer->manager(), object_node, attr_node, node);
}
// {prim::kPrimGetAttr, {prim::kPrimResolve, namespace, symbol}, attr}
if (IsPrimitiveCNode(object_node, prim::kPrimResolve)) {
// node is get_attr node

View File

@ -34,6 +34,7 @@ namespace irpass {
// pattern. After matching GetAttr pattern, there may be new nodes that can match GetAttr pattern and Resolve pattern.
// The same is true for matching Resolve pattern.
//
// {prim::kPrimGetAttr, {prim::kPrimTupleGetItem, {prim::kPrimResolve, namespace, symbol}, index}, attr}
// {prim::kPrimGetAttr, {prim::kPrimResolve, namespace, symbol}, attr}
// {prim::kPrimGetAttr, namespace, attr}
// {prim::kPrimGetAttr, MsClassObject, attr}

View File

@ -73,6 +73,9 @@ const char PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL[] = "get_class_member_namespac
const char PYTHON_MOD_GET_PARSE_METHOD[] = "get_parse_method_of_class";
const char PYTHON_MOD_GET_BPROP_METHOD[] = "get_bprop_method_of_class";
const char PYTHON_MOD_GET_OBJECT_DESCRIPTION[] = "get_object_description";
const char PYTHON_MOD_IS_CELL_LIST[] = "is_cell_list";
const char PYTHON_MOD_CONVERT_CELL_LIST_TO_SEQUENCE[] = "convert_cell_list_to_sequence";
const char PYTHON_MOD_GET_ITEM_FROM_SEQUENCE[] = "get_obj_from_sequence";
const char PYTHON_MOD_CONVERT_TO_MS_TENSOR[] = "convert_to_ms_tensor";
const char PYTHON_MOD_CONVERT_TO_MS_CSRTENSOR[] = "convert_to_ms_csrtensor";
const char PYTHON_MOD_CONVERT_TO_MS_COOTENSOR[] = "convert_to_ms_cootensor";

View File

@ -527,6 +527,104 @@ AnfNodePtr ResolveSymbolWithAttr(const FuncGraphManagerPtr &manager, const AnfNo
return nullptr;
}
// Get python object with index from a list or the whole list if the index is not fixed.
py::object GetObjectFromSequence(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node,
const AnfNodePtr &index_node) {
MS_EXCEPTION_IF_NULL(node);
TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info()));
py::object obj = GetSymbolObject(name_space, symbol, node);
// If obj is nn.CellList, convert it to sequence.
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
bool is_celllist = py::cast<bool>(python_adapter::CallPyModFn(mod, PYTHON_MOD_IS_CELL_LIST, obj));
if (is_celllist) {
obj = python_adapter::CallPyModFn(mod, PYTHON_MOD_CONVERT_CELL_LIST_TO_SEQUENCE, obj);
}
if (!py::isinstance<py::list>(obj) && !py::isinstance<py::tuple>(obj)) {
return py::none();
}
MS_LOG(DEBUG) << "obj: " << py::str(obj) << ", index_node: " << index_node->ToString();
auto imm_value = GetValueNode<Int64ImmPtr>(index_node);
if (imm_value == nullptr) {
MS_LOG(DEBUG) << "The index is not a value node, so we return the whole list, node: " << node->DebugString()
<< ", index_node: " << index_node->DebugString();
// Index is not fixed, return the whole list.
return obj;
}
// It index is a value node, get the item of index directly.
py::object item_obj =
python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_ITEM_FROM_SEQUENCE, obj, py::int_(imm_value->value()));
return item_obj;
}
bool IsResolveNodeWithGetItem(const AnfNodePtr &node) {
// Check if the node matches: {prim::kPrim::Resolve, ..., 'getitem'}.
if (IsPrimitiveCNode(node, prim::kPrimResolve)) {
constexpr size_t symbol_index = 2;
constexpr auto getitem_symbol = "getitem";
auto cnode = node->cast<CNodePtr>();
auto symbol = GetValueNode<SymbolPtr>(cnode->input(symbol_index));
return symbol->symbol() == getitem_symbol;
}
return false;
}
bool IsGetItemCNode(const AnfNodePtr &node) {
if (!node->isa<CNode>()) {
return false;
}
auto cnode = node->cast<CNodePtr>();
constexpr size_t getitem_inputs_size = 3;
if (cnode->size() != getitem_inputs_size) {
return false;
}
constexpr auto prim_index = 0;
return IsResolveNodeWithGetItem(cnode->input(prim_index));
}
AnfNodePtr ResolveGetItemWithAttr(const FuncGraphManagerPtr &manager, const AnfNodePtr &getitem_node,
const AnfNodePtr &attr_node, const AnfNodePtr &node) {
// {prim::kPrimGetAttr, {getitem, {prim::kPrimResolve, namespace, symbol}, index}, attr}
// {prim::kPrimGetAttr, {getitem, {prim::kPrimGetAttr, ResolveNode, member}, index}, attr}
constexpr auto data_index = 1;
constexpr auto index_index = 2;
auto getitem_cnode = getitem_node->cast<CNodePtr>();
auto data_node = getitem_cnode->input(data_index);
auto index_node = getitem_cnode->input(index_index);
if (IsPrimitiveCNode(data_node, prim::kPrimResolve)) {
auto [name_space, symbol] = GetNamespaceAndSymbol(data_node);
auto obj = GetObjectFromSequence(name_space, symbol, data_node, index_node);
if (py::isinstance<py::none>(obj)) {
return nullptr;
}
if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
return ResolveSequenceWithAttr(manager, obj, data_node, attr_node, getitem_cnode);
}
return ResolveCellWithAttr(manager, obj, data_node, attr_node, node);
}
if (IsPrimitiveCNode(data_node, prim::kPrimGetAttr)) {
auto getattr_cnode = data_node->cast<CNodePtr>();
auto resolve_node = getattr_cnode->input(data_index);
auto member_node = getattr_cnode->input(index_index);
if (IsPrimitiveCNode(resolve_node, prim::kPrimResolve)) {
// Check if the result is a new resolve node.
auto item_node = ResolveSymbolWithAttr(manager, resolve_node, member_node, node);
if (IsPrimitiveCNode(item_node, prim::kPrimResolve)) {
auto [name_space, symbol] = GetNamespaceAndSymbol(item_node);
auto obj = GetObjectFromSequence(name_space, symbol, item_node, index_node);
if (py::isinstance<py::none>(obj)) {
return nullptr;
}
if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
return ResolveSequenceWithAttr(manager, obj, item_node, attr_node, getitem_cnode);
}
return ResolveCellWithAttr(manager, obj, item_node, attr_node, node);
}
}
}
return nullptr;
}
AnfNodePtr ResolveMsClassWithAttr(const FuncGraphManagerPtr &manager, const py::object &cls_obj,
const std::string &attr, const AnfNodePtr &get_attr_node) {
// Get attribute or method from ms_class obj.

View File

@ -157,9 +157,14 @@ AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr
const AnfNodePtr &node);
AnfNodePtr ResolveSymbolWithAttr(const FuncGraphManagerPtr &manager, const AnfNodePtr &object_node,
const AnfNodePtr &attr_node, const AnfNodePtr &node);
AnfNodePtr ResolveGetItemWithAttr(const FuncGraphManagerPtr &manager, const AnfNodePtr &getitem_node,
const AnfNodePtr &attr_node, const AnfNodePtr &node);
AnfNodePtr ResolveMsClassWithAttr(const FuncGraphManagerPtr &manager, const py::object &cls_obj,
const std::string &attr, const AnfNodePtr &node);
// Check if node is cnode with getitem.
bool IsGetItemCNode(const AnfNodePtr &node);
// Resolve one graph which normally is the root graph. FuncGraph shall be managed by res->manager().
bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &res, bool use_profile = true);

View File

@ -1691,10 +1691,12 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt
}
}
// Get attribute or method of class object decorated by ms_class.
auto class_value = GetMsClassObject(data_args);
if (class_value != nullptr) {
return GetEvaluatedValueForMsClassAttrOrMethod(args_spec_list, class_value, out_conf);
}
// Get attribute or method of nn.Cell object.
auto data_func_graph = dyn_cast_ptr<FuncGraphAbstractClosure>(data_args);
if (data_func_graph != nullptr) {
auto res = GetEvaluatedValueForCellAttrOrMethod(args_spec_list, data_func_graph->func_graph(), out_conf);

View File

@ -25,7 +25,8 @@ from .parser import (Parser, create_instance, is_supported_create_instance_type,
eval_script, get_script_ids, expand_expr_statement, is_class_member, parse_cb, resolve_symbol,
convert_to_ms_tensor, get_object_description, get_class_attr_namespace_symbol, get_ms_class_name,
is_class_type, check_obj_bool, python_isinstance, ms_isinstance, convert_to_ms_csrtensor,
convert_to_ms_cootensor, convert_class_to_function)
convert_to_ms_cootensor, convert_class_to_function, convert_cell_list_to_sequence, is_cell_list,
get_obj_from_sequence)
__all__ = ['Parser', 'create_instance', 'is_supported_create_instance_type', 'generate_scope',
'get_bprop_method_of_class', 'get_class_instance_type', 'get_class_member_namespace_symbol',
@ -35,4 +36,5 @@ __all__ = ['Parser', 'create_instance', 'is_supported_create_instance_type', 'ge
'eval_script', 'get_script_ids', 'expand_expr_statement', 'is_class_member', 'parse_cb', 'resolve_symbol',
'convert_to_ms_tensor', 'get_object_description', 'get_class_attr_namespace_symbol', 'get_ms_class_name',
'is_class_type', 'check_obj_bool', 'python_isinstance', 'ms_isinstance', 'convert_to_ms_csrtensor',
'convert_to_ms_cootensor', 'convert_class_to_function']
'convert_to_ms_cootensor', 'convert_class_to_function', 'convert_cell_list_to_sequence', 'is_cell_list',
'get_obj_from_sequence']

View File

@ -459,6 +459,29 @@ def ms_isinstance(x, cmp_type):
return isinstance(x, pytype_to_mstype.get(cmp_type))
def is_cell_list(obj):
"""Check if obj is nn.CellList"""
return isinstance(obj, nn.CellList)
def convert_cell_list_to_sequence(obj):
"""Convert nn.CellList to sequence."""
if not isinstance(obj, nn.CellList):
raise TypeError(f"Obj should be nn.CellList, but got {obj}")
if not hasattr(obj, "_cells"):
raise AttributeError(f"nn.CellList is missing _cells property.")
cells = getattr(obj, "_cells")
return list(cells.values())
def get_obj_from_sequence(obj, index):
"""Implement `tuple_getitem`."""
if not isinstance(obj, (tuple, list)):
raise TypeError(f"Should not get item from a object that not sequence type, obj: {obj}")
# Not check index out of range by self.
return obj[index]
def get_module_namespace(obj):
"""Get the module's namespace."""
logger.debug("get module namespace, module: %r", obj)

View File

@ -17,6 +17,7 @@ import pytest
import numpy as np
from mindspore import context, nn, dtype, Tensor, ms_function, ms_class
from mindspore.ops import operations as P
from mindspore.ops import composite as C
class Actor(nn.Cell):
@ -39,13 +40,27 @@ class Trainer(nn.Cell):
return self.net_list[0].act(x, y)
def verify_list_item_getattr(trainer, expect_res):
class GradNet(nn.Cell):
def __init__(self, network, get_all=False, get_by_list=False, sens_param=False):
super(GradNet, self).__init__()
self.network = network
self.grad = C.GradOperation(get_all, get_by_list, sens_param)
def construct(self, *inputs):
grads = self.grad(self.network)(*inputs)
return grads
def verify_list_item_getattr(trainer, expect_res, expect_grad_res):
x = Tensor([3], dtype=dtype.float32)
y = Tensor([6], dtype=dtype.float32)
res = trainer(x, y)
print(f'res: {res}')
assert np.array_equal(res.asnumpy(), expect_res.asnumpy())
grad_net = GradNet(trainer)
res2 = grad_net(x, y)
assert np.array_equal(res2.asnumpy(), expect_grad_res.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@ -63,7 +78,8 @@ def test_list_item_getattr():
actor_list = [Actor()]
trainer = Trainer(actor_list)
expect_res = Tensor([9], dtype=dtype.float32)
verify_list_item_getattr(trainer, expect_res)
expect_grad_res = Tensor([1], dtype=dtype.float32)
verify_list_item_getattr(trainer, expect_res, expect_grad_res)
@pytest.mark.level0
@ -84,7 +100,8 @@ def test_cell_list_getattr():
actor_list.append(Actor())
trainer = Trainer(actor_list)
expect_res = Tensor([9], dtype=dtype.float32)
verify_list_item_getattr(trainer, expect_res)
expect_grad_res = Tensor([1], dtype=dtype.float32)
verify_list_item_getattr(trainer, expect_res, expect_grad_res)
@pytest.mark.level0
@ -103,7 +120,8 @@ def test_msclass_list_getattr():
actor_list = [Actor2()]
trainer = Trainer(actor_list)
expect_res = Tensor([9], dtype=dtype.float32)
verify_list_item_getattr(trainer, expect_res)
expect_grad_res = Tensor([1], dtype=dtype.float32)
verify_list_item_getattr(trainer, expect_res, expect_grad_res)
class Trainer2(nn.Cell):
@ -138,7 +156,8 @@ def test_list_item_getattr2():
actor_list = [Actor(), Actor(), Actor()]
trainer = Trainer2(actor_list)
expect_res = Tensor([27], dtype=dtype.float32)
verify_list_item_getattr(trainer, expect_res)
expect_grad_res = Tensor([3], dtype=dtype.float32)
verify_list_item_getattr(trainer, expect_res, expect_grad_res)
@pytest.mark.level0
@ -159,7 +178,8 @@ def test_cell_list_getattr2():
actor_list.append(Actor())
trainer = Trainer2(actor_list)
expect_res = Tensor([27], dtype=dtype.float32)
verify_list_item_getattr(trainer, expect_res)
expect_grad_res = Tensor([3], dtype=dtype.float32)
verify_list_item_getattr(trainer, expect_res, expect_grad_res)
@pytest.mark.level0
@ -178,7 +198,8 @@ def test_msclass_list_getattr2():
actor_list = [Actor2(), Actor2(), Actor2()]
trainer = Trainer2(actor_list)
expect_res = Tensor([27], dtype=dtype.float32)
verify_list_item_getattr(trainer, expect_res)
expect_grad_res = Tensor([3], dtype=dtype.float32)
verify_list_item_getattr(trainer, expect_res, expect_grad_res)
class MSRL(nn.Cell):
@ -212,13 +233,16 @@ class Trainer3(nn.Cell):
return output
def verify_list_item_getattr2(trainer, expect_res):
def verify_list_item_getattr2(trainer, expect_res, expect_grad_res):
x = Tensor([2], dtype=dtype.int32)
y = Tensor([3], dtype=dtype.int32)
res = trainer.test(x, y)
print(f'res: {res}')
assert np.array_equal(res.asnumpy(), expect_res.asnumpy())
grad_net = GradNet(trainer)
res2 = grad_net(x, y)
assert np.array_equal(res2.asnumpy(), expect_grad_res.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@ -240,7 +264,8 @@ def test_list_item_getattr3():
msrl = MSRL(agent_list)
trainer = Trainer3(msrl)
expect_res = Tensor([15], dtype=dtype.int32)
verify_list_item_getattr2(trainer, expect_res)
expect_grad_res = Tensor([3], dtype=dtype.int32)
verify_list_item_getattr2(trainer, expect_res, expect_grad_res)
@pytest.mark.level0
@ -263,7 +288,8 @@ def test_cell_list_getattr3():
msrl = MSRL(agent_list)
trainer = Trainer3(msrl)
expect_res = Tensor([15], dtype=dtype.int32)
verify_list_item_getattr2(trainer, expect_res)
expect_grad_res = Tensor([3], dtype=dtype.int32)
verify_list_item_getattr2(trainer, expect_res, expect_grad_res)
@pytest.mark.level0
@ -286,4 +312,5 @@ def test_msclass_list_getattr3():
msrl = MSRL(agent_list)
trainer = Trainer3(msrl)
expect_res = Tensor([15], dtype=dtype.int32)
verify_list_item_getattr2(trainer, expect_res)
expect_grad_res = Tensor([3], dtype=dtype.int32)
verify_list_item_getattr2(trainer, expect_res, expect_grad_res)