forked from mindspore-Ecosystem/mindspore
!40842 Revert getattr_getitem pass
Merge pull request !40842 from huangbingjian/revert_pass
This commit is contained in:
commit
b66f06efd6
|
@ -30,6 +30,11 @@ AnfNodePtr Resolver::operator()(const OptimizerPtr &optimizer, const AnfNodePtr
|
|||
auto object_node = object.GetNode(node);
|
||||
auto attr_node = attr.GetNode(node);
|
||||
|
||||
// {prim::kPrimGetAttr, {getitem, {prim::kPrimResolve, namespace, symbol}, index}, attr}
|
||||
// {prim::kPrimGetAttr, {getitem, {prim::kPrimGetAttr, ResolveNode, member}, index}, attr}
|
||||
if (parse::IsGetItemCNode(object_node)) {
|
||||
return parse::ResolveGetItemWithAttr(optimizer->manager(), object_node, attr_node, node);
|
||||
}
|
||||
// {prim::kPrimGetAttr, {prim::kPrimResolve, namespace, symbol}, attr}
|
||||
if (IsPrimitiveCNode(object_node, prim::kPrimResolve)) {
|
||||
// node is get_attr node
|
||||
|
|
|
@ -34,6 +34,7 @@ namespace irpass {
|
|||
// pattern. After matching GetAttr pattern, there may be new nodes that can match GetAttr pattern and Resolve pattern.
|
||||
// The same is true for matching Resolve pattern.
|
||||
//
|
||||
// {prim::kPrimGetAttr, {prim::kPrimTupleGetItem, {prim::kPrimResolve, namespace, symbol}, index}, attr}
|
||||
// {prim::kPrimGetAttr, {prim::kPrimResolve, namespace, symbol}, attr}
|
||||
// {prim::kPrimGetAttr, namespace, attr}
|
||||
// {prim::kPrimGetAttr, MsClassObject, attr}
|
||||
|
|
|
@ -73,6 +73,9 @@ const char PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL[] = "get_class_member_namespac
|
|||
const char PYTHON_MOD_GET_PARSE_METHOD[] = "get_parse_method_of_class";
|
||||
const char PYTHON_MOD_GET_BPROP_METHOD[] = "get_bprop_method_of_class";
|
||||
const char PYTHON_MOD_GET_OBJECT_DESCRIPTION[] = "get_object_description";
|
||||
const char PYTHON_MOD_IS_CELL_LIST[] = "is_cell_list";
|
||||
const char PYTHON_MOD_CONVERT_CELL_LIST_TO_SEQUENCE[] = "convert_cell_list_to_sequence";
|
||||
const char PYTHON_MOD_GET_ITEM_FROM_SEQUENCE[] = "get_obj_from_sequence";
|
||||
const char PYTHON_MOD_CONVERT_TO_MS_TENSOR[] = "convert_to_ms_tensor";
|
||||
const char PYTHON_MOD_CONVERT_TO_MS_CSRTENSOR[] = "convert_to_ms_csrtensor";
|
||||
const char PYTHON_MOD_CONVERT_TO_MS_COOTENSOR[] = "convert_to_ms_cootensor";
|
||||
|
|
|
@ -527,6 +527,104 @@ AnfNodePtr ResolveSymbolWithAttr(const FuncGraphManagerPtr &manager, const AnfNo
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
// Get python object with index from a list or the whole list if the index is not fixed.
|
||||
py::object GetObjectFromSequence(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node,
|
||||
const AnfNodePtr &index_node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info()));
|
||||
py::object obj = GetSymbolObject(name_space, symbol, node);
|
||||
// If obj is nn.CellList, convert it to sequence.
|
||||
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
|
||||
bool is_celllist = py::cast<bool>(python_adapter::CallPyModFn(mod, PYTHON_MOD_IS_CELL_LIST, obj));
|
||||
if (is_celllist) {
|
||||
obj = python_adapter::CallPyModFn(mod, PYTHON_MOD_CONVERT_CELL_LIST_TO_SEQUENCE, obj);
|
||||
}
|
||||
if (!py::isinstance<py::list>(obj) && !py::isinstance<py::tuple>(obj)) {
|
||||
return py::none();
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "obj: " << py::str(obj) << ", index_node: " << index_node->ToString();
|
||||
auto imm_value = GetValueNode<Int64ImmPtr>(index_node);
|
||||
if (imm_value == nullptr) {
|
||||
MS_LOG(DEBUG) << "The index is not a value node, so we return the whole list, node: " << node->DebugString()
|
||||
<< ", index_node: " << index_node->DebugString();
|
||||
// Index is not fixed, return the whole list.
|
||||
return obj;
|
||||
}
|
||||
// It index is a value node, get the item of index directly.
|
||||
py::object item_obj =
|
||||
python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_ITEM_FROM_SEQUENCE, obj, py::int_(imm_value->value()));
|
||||
return item_obj;
|
||||
}
|
||||
|
||||
bool IsResolveNodeWithGetItem(const AnfNodePtr &node) {
|
||||
// Check if the node matches: {prim::kPrim::Resolve, ..., 'getitem'}.
|
||||
if (IsPrimitiveCNode(node, prim::kPrimResolve)) {
|
||||
constexpr size_t symbol_index = 2;
|
||||
constexpr auto getitem_symbol = "getitem";
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto symbol = GetValueNode<SymbolPtr>(cnode->input(symbol_index));
|
||||
return symbol->symbol() == getitem_symbol;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool IsGetItemCNode(const AnfNodePtr &node) {
|
||||
if (!node->isa<CNode>()) {
|
||||
return false;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
constexpr size_t getitem_inputs_size = 3;
|
||||
if (cnode->size() != getitem_inputs_size) {
|
||||
return false;
|
||||
}
|
||||
constexpr auto prim_index = 0;
|
||||
return IsResolveNodeWithGetItem(cnode->input(prim_index));
|
||||
}
|
||||
|
||||
AnfNodePtr ResolveGetItemWithAttr(const FuncGraphManagerPtr &manager, const AnfNodePtr &getitem_node,
|
||||
const AnfNodePtr &attr_node, const AnfNodePtr &node) {
|
||||
// {prim::kPrimGetAttr, {getitem, {prim::kPrimResolve, namespace, symbol}, index}, attr}
|
||||
// {prim::kPrimGetAttr, {getitem, {prim::kPrimGetAttr, ResolveNode, member}, index}, attr}
|
||||
constexpr auto data_index = 1;
|
||||
constexpr auto index_index = 2;
|
||||
auto getitem_cnode = getitem_node->cast<CNodePtr>();
|
||||
auto data_node = getitem_cnode->input(data_index);
|
||||
auto index_node = getitem_cnode->input(index_index);
|
||||
if (IsPrimitiveCNode(data_node, prim::kPrimResolve)) {
|
||||
auto [name_space, symbol] = GetNamespaceAndSymbol(data_node);
|
||||
auto obj = GetObjectFromSequence(name_space, symbol, data_node, index_node);
|
||||
if (py::isinstance<py::none>(obj)) {
|
||||
return nullptr;
|
||||
}
|
||||
if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
|
||||
return ResolveSequenceWithAttr(manager, obj, data_node, attr_node, getitem_cnode);
|
||||
}
|
||||
return ResolveCellWithAttr(manager, obj, data_node, attr_node, node);
|
||||
}
|
||||
if (IsPrimitiveCNode(data_node, prim::kPrimGetAttr)) {
|
||||
auto getattr_cnode = data_node->cast<CNodePtr>();
|
||||
auto resolve_node = getattr_cnode->input(data_index);
|
||||
auto member_node = getattr_cnode->input(index_index);
|
||||
if (IsPrimitiveCNode(resolve_node, prim::kPrimResolve)) {
|
||||
// Check if the result is a new resolve node.
|
||||
auto item_node = ResolveSymbolWithAttr(manager, resolve_node, member_node, node);
|
||||
if (IsPrimitiveCNode(item_node, prim::kPrimResolve)) {
|
||||
auto [name_space, symbol] = GetNamespaceAndSymbol(item_node);
|
||||
auto obj = GetObjectFromSequence(name_space, symbol, item_node, index_node);
|
||||
if (py::isinstance<py::none>(obj)) {
|
||||
return nullptr;
|
||||
}
|
||||
if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
|
||||
return ResolveSequenceWithAttr(manager, obj, item_node, attr_node, getitem_cnode);
|
||||
}
|
||||
return ResolveCellWithAttr(manager, obj, item_node, attr_node, node);
|
||||
}
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
AnfNodePtr ResolveMsClassWithAttr(const FuncGraphManagerPtr &manager, const py::object &cls_obj,
|
||||
const std::string &attr, const AnfNodePtr &get_attr_node) {
|
||||
// Get attribute or method from ms_class obj.
|
||||
|
|
|
@ -157,9 +157,14 @@ AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr
|
|||
const AnfNodePtr &node);
|
||||
AnfNodePtr ResolveSymbolWithAttr(const FuncGraphManagerPtr &manager, const AnfNodePtr &object_node,
|
||||
const AnfNodePtr &attr_node, const AnfNodePtr &node);
|
||||
AnfNodePtr ResolveGetItemWithAttr(const FuncGraphManagerPtr &manager, const AnfNodePtr &getitem_node,
|
||||
const AnfNodePtr &attr_node, const AnfNodePtr &node);
|
||||
AnfNodePtr ResolveMsClassWithAttr(const FuncGraphManagerPtr &manager, const py::object &cls_obj,
|
||||
const std::string &attr, const AnfNodePtr &node);
|
||||
|
||||
// Check if node is cnode with getitem.
|
||||
bool IsGetItemCNode(const AnfNodePtr &node);
|
||||
|
||||
// Resolve one graph which normally is the root graph. FuncGraph shall be managed by res->manager().
|
||||
bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &res, bool use_profile = true);
|
||||
|
||||
|
|
|
@ -1691,10 +1691,12 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt
|
|||
}
|
||||
}
|
||||
|
||||
// Get attribute or method of class object decorated by ms_class.
|
||||
auto class_value = GetMsClassObject(data_args);
|
||||
if (class_value != nullptr) {
|
||||
return GetEvaluatedValueForMsClassAttrOrMethod(args_spec_list, class_value, out_conf);
|
||||
}
|
||||
// Get attribute or method of nn.Cell object.
|
||||
auto data_func_graph = dyn_cast_ptr<FuncGraphAbstractClosure>(data_args);
|
||||
if (data_func_graph != nullptr) {
|
||||
auto res = GetEvaluatedValueForCellAttrOrMethod(args_spec_list, data_func_graph->func_graph(), out_conf);
|
||||
|
|
|
@ -25,7 +25,8 @@ from .parser import (Parser, create_instance, is_supported_create_instance_type,
|
|||
eval_script, get_script_ids, expand_expr_statement, is_class_member, parse_cb, resolve_symbol,
|
||||
convert_to_ms_tensor, get_object_description, get_class_attr_namespace_symbol, get_ms_class_name,
|
||||
is_class_type, check_obj_bool, python_isinstance, ms_isinstance, convert_to_ms_csrtensor,
|
||||
convert_to_ms_cootensor, convert_class_to_function)
|
||||
convert_to_ms_cootensor, convert_class_to_function, convert_cell_list_to_sequence, is_cell_list,
|
||||
get_obj_from_sequence)
|
||||
|
||||
__all__ = ['Parser', 'create_instance', 'is_supported_create_instance_type', 'generate_scope',
|
||||
'get_bprop_method_of_class', 'get_class_instance_type', 'get_class_member_namespace_symbol',
|
||||
|
@ -35,4 +36,5 @@ __all__ = ['Parser', 'create_instance', 'is_supported_create_instance_type', 'ge
|
|||
'eval_script', 'get_script_ids', 'expand_expr_statement', 'is_class_member', 'parse_cb', 'resolve_symbol',
|
||||
'convert_to_ms_tensor', 'get_object_description', 'get_class_attr_namespace_symbol', 'get_ms_class_name',
|
||||
'is_class_type', 'check_obj_bool', 'python_isinstance', 'ms_isinstance', 'convert_to_ms_csrtensor',
|
||||
'convert_to_ms_cootensor', 'convert_class_to_function']
|
||||
'convert_to_ms_cootensor', 'convert_class_to_function', 'convert_cell_list_to_sequence', 'is_cell_list',
|
||||
'get_obj_from_sequence']
|
||||
|
|
|
@ -459,6 +459,29 @@ def ms_isinstance(x, cmp_type):
|
|||
return isinstance(x, pytype_to_mstype.get(cmp_type))
|
||||
|
||||
|
||||
def is_cell_list(obj):
|
||||
"""Check if obj is nn.CellList"""
|
||||
return isinstance(obj, nn.CellList)
|
||||
|
||||
|
||||
def convert_cell_list_to_sequence(obj):
|
||||
"""Convert nn.CellList to sequence."""
|
||||
if not isinstance(obj, nn.CellList):
|
||||
raise TypeError(f"Obj should be nn.CellList, but got {obj}")
|
||||
if not hasattr(obj, "_cells"):
|
||||
raise AttributeError(f"nn.CellList is missing _cells property.")
|
||||
cells = getattr(obj, "_cells")
|
||||
return list(cells.values())
|
||||
|
||||
|
||||
def get_obj_from_sequence(obj, index):
|
||||
"""Implement `tuple_getitem`."""
|
||||
if not isinstance(obj, (tuple, list)):
|
||||
raise TypeError(f"Should not get item from a object that not sequence type, obj: {obj}")
|
||||
# Not check index out of range by self.
|
||||
return obj[index]
|
||||
|
||||
|
||||
def get_module_namespace(obj):
|
||||
"""Get the module's namespace."""
|
||||
logger.debug("get module namespace, module: %r", obj)
|
||||
|
|
|
@ -17,6 +17,7 @@ import pytest
|
|||
import numpy as np
|
||||
from mindspore import context, nn, dtype, Tensor, ms_function, ms_class
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
|
||||
|
||||
class Actor(nn.Cell):
|
||||
|
@ -39,13 +40,27 @@ class Trainer(nn.Cell):
|
|||
return self.net_list[0].act(x, y)
|
||||
|
||||
|
||||
def verify_list_item_getattr(trainer, expect_res):
|
||||
class GradNet(nn.Cell):
|
||||
def __init__(self, network, get_all=False, get_by_list=False, sens_param=False):
|
||||
super(GradNet, self).__init__()
|
||||
self.network = network
|
||||
self.grad = C.GradOperation(get_all, get_by_list, sens_param)
|
||||
|
||||
def construct(self, *inputs):
|
||||
grads = self.grad(self.network)(*inputs)
|
||||
return grads
|
||||
|
||||
|
||||
def verify_list_item_getattr(trainer, expect_res, expect_grad_res):
|
||||
x = Tensor([3], dtype=dtype.float32)
|
||||
y = Tensor([6], dtype=dtype.float32)
|
||||
res = trainer(x, y)
|
||||
print(f'res: {res}')
|
||||
assert np.array_equal(res.asnumpy(), expect_res.asnumpy())
|
||||
|
||||
grad_net = GradNet(trainer)
|
||||
res2 = grad_net(x, y)
|
||||
assert np.array_equal(res2.asnumpy(), expect_grad_res.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
|
@ -63,7 +78,8 @@ def test_list_item_getattr():
|
|||
actor_list = [Actor()]
|
||||
trainer = Trainer(actor_list)
|
||||
expect_res = Tensor([9], dtype=dtype.float32)
|
||||
verify_list_item_getattr(trainer, expect_res)
|
||||
expect_grad_res = Tensor([1], dtype=dtype.float32)
|
||||
verify_list_item_getattr(trainer, expect_res, expect_grad_res)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
@ -84,7 +100,8 @@ def test_cell_list_getattr():
|
|||
actor_list.append(Actor())
|
||||
trainer = Trainer(actor_list)
|
||||
expect_res = Tensor([9], dtype=dtype.float32)
|
||||
verify_list_item_getattr(trainer, expect_res)
|
||||
expect_grad_res = Tensor([1], dtype=dtype.float32)
|
||||
verify_list_item_getattr(trainer, expect_res, expect_grad_res)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
@ -103,7 +120,8 @@ def test_msclass_list_getattr():
|
|||
actor_list = [Actor2()]
|
||||
trainer = Trainer(actor_list)
|
||||
expect_res = Tensor([9], dtype=dtype.float32)
|
||||
verify_list_item_getattr(trainer, expect_res)
|
||||
expect_grad_res = Tensor([1], dtype=dtype.float32)
|
||||
verify_list_item_getattr(trainer, expect_res, expect_grad_res)
|
||||
|
||||
|
||||
class Trainer2(nn.Cell):
|
||||
|
@ -138,7 +156,8 @@ def test_list_item_getattr2():
|
|||
actor_list = [Actor(), Actor(), Actor()]
|
||||
trainer = Trainer2(actor_list)
|
||||
expect_res = Tensor([27], dtype=dtype.float32)
|
||||
verify_list_item_getattr(trainer, expect_res)
|
||||
expect_grad_res = Tensor([3], dtype=dtype.float32)
|
||||
verify_list_item_getattr(trainer, expect_res, expect_grad_res)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
@ -159,7 +178,8 @@ def test_cell_list_getattr2():
|
|||
actor_list.append(Actor())
|
||||
trainer = Trainer2(actor_list)
|
||||
expect_res = Tensor([27], dtype=dtype.float32)
|
||||
verify_list_item_getattr(trainer, expect_res)
|
||||
expect_grad_res = Tensor([3], dtype=dtype.float32)
|
||||
verify_list_item_getattr(trainer, expect_res, expect_grad_res)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
@ -178,7 +198,8 @@ def test_msclass_list_getattr2():
|
|||
actor_list = [Actor2(), Actor2(), Actor2()]
|
||||
trainer = Trainer2(actor_list)
|
||||
expect_res = Tensor([27], dtype=dtype.float32)
|
||||
verify_list_item_getattr(trainer, expect_res)
|
||||
expect_grad_res = Tensor([3], dtype=dtype.float32)
|
||||
verify_list_item_getattr(trainer, expect_res, expect_grad_res)
|
||||
|
||||
|
||||
class MSRL(nn.Cell):
|
||||
|
@ -212,13 +233,16 @@ class Trainer3(nn.Cell):
|
|||
return output
|
||||
|
||||
|
||||
def verify_list_item_getattr2(trainer, expect_res):
|
||||
def verify_list_item_getattr2(trainer, expect_res, expect_grad_res):
|
||||
x = Tensor([2], dtype=dtype.int32)
|
||||
y = Tensor([3], dtype=dtype.int32)
|
||||
res = trainer.test(x, y)
|
||||
print(f'res: {res}')
|
||||
assert np.array_equal(res.asnumpy(), expect_res.asnumpy())
|
||||
|
||||
grad_net = GradNet(trainer)
|
||||
res2 = grad_net(x, y)
|
||||
assert np.array_equal(res2.asnumpy(), expect_grad_res.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
|
@ -240,7 +264,8 @@ def test_list_item_getattr3():
|
|||
msrl = MSRL(agent_list)
|
||||
trainer = Trainer3(msrl)
|
||||
expect_res = Tensor([15], dtype=dtype.int32)
|
||||
verify_list_item_getattr2(trainer, expect_res)
|
||||
expect_grad_res = Tensor([3], dtype=dtype.int32)
|
||||
verify_list_item_getattr2(trainer, expect_res, expect_grad_res)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
@ -263,7 +288,8 @@ def test_cell_list_getattr3():
|
|||
msrl = MSRL(agent_list)
|
||||
trainer = Trainer3(msrl)
|
||||
expect_res = Tensor([15], dtype=dtype.int32)
|
||||
verify_list_item_getattr2(trainer, expect_res)
|
||||
expect_grad_res = Tensor([3], dtype=dtype.int32)
|
||||
verify_list_item_getattr2(trainer, expect_res, expect_grad_res)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
@ -286,4 +312,5 @@ def test_msclass_list_getattr3():
|
|||
msrl = MSRL(agent_list)
|
||||
trainer = Trainer3(msrl)
|
||||
expect_res = Tensor([15], dtype=dtype.int32)
|
||||
verify_list_item_getattr2(trainer, expect_res)
|
||||
expect_grad_res = Tensor([3], dtype=dtype.int32)
|
||||
verify_list_item_getattr2(trainer, expect_res, expect_grad_res)
|
||||
|
|
Loading…
Reference in New Issue