Support getattr of list of CellList or MsClass list.

This commit is contained in:
huangbingjian 2022-04-11 11:34:08 +08:00
parent 517ece4ac8
commit 198cabb999
10 changed files with 299 additions and 146 deletions

View File

@ -31,35 +31,15 @@ AnfNodePtr Resolver::operator()(const OptimizerPtr &optimizer, const AnfNodePtr
auto object_node = object.GetNode(node);
auto attr_node = attr.GetNode(node);
// {prim::kPrimGetAttr, {getitem, {prim::kPrimResolve, namespace, symbol}, index}, attr}
// {prim::kPrimGetAttr, {getitem, {prim::kPrimGetAttr, ResolveNode, member}, index}, attr}
if (parse::IsGetItemCNode(object_node)) {
return parse::ResolveGetItemWithAttr(optimizer->manager(), object_node, attr_node, node);
}
// {prim::kPrimGetAttr, {prim::kPrimResolve, namespace, symbol}, attr}
if (IsPrimitiveCNode(object_node, prim::kPrimResolve)) {
auto [name_space, symbol] = parse::GetNamespaceAndSymbol(object_node);
auto module_name = name_space->module();
constexpr std::string_view parse_super_name = "namespace";
if (module_name.find(parse::RESOLVE_NAMESPACE_NAME_CLASS_MEMBER) != std::string::npos &&
symbol->symbol() != parse_super_name) {
auto symbol_obj = parse::GetSymbolObject(name_space, symbol, node);
return parse::ResolveCellWithAttr(optimizer->manager(), symbol_obj, object_node, attr_node);
}
return parse::ResolveSymbolWithAttr(optimizer->manager(), object_node, attr_node, node);
}
// {prim::kPrimGetAttr, {getitem, {prim::kPrimResolve, namespace, symbol}, index}, attr}
if (parse::IsGetItemCNode(object_node)) {
auto getitem_cnode = object_node->cast<CNodePtr>();
constexpr auto resolve_index = 1;
constexpr auto index_index = 2;
auto resolve_node = getitem_cnode->input(resolve_index);
auto index_node = getitem_cnode->input(index_index);
if (IsPrimitiveCNode(resolve_node, prim::kPrimResolve)) {
auto [name_space, symbol] = parse::GetNamespaceAndSymbol(resolve_node);
auto obj = parse::GetObjectFromSequence(name_space, symbol, resolve_node, index_node);
if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
return parse::ResolveSequenceWithAttr(optimizer->manager(), obj, resolve_node, attr_node, getitem_cnode);
}
return parse::ResolveCellWithAttr(optimizer->manager(), obj, resolve_node, attr_node);
}
}
// {prim::kPrimGetAttr, namespace, attr}
if (IsValueNode<parse::NameSpace>(object_node)) {
auto name_space = GetValueNode<parse::NameSpacePtr>(object_node);
@ -67,14 +47,12 @@ AnfNodePtr Resolver::operator()(const OptimizerPtr &optimizer, const AnfNodePtr
parse::SymbolPtr symbol = std::make_shared<parse::Symbol>(attr_str);
return parse::ResolveSymbol(optimizer->manager(), name_space, symbol, node);
}
// {prim::kPrimGetAttr, MsClassObject, attr}
if (IsValueNode<parse::MsClassObject>(object_node)) {
auto ms_class = GetValueNode<parse::MsClassObjectPtr>(object_node);
auto ms_class = GetValueNode<parse::MsClassObjectPtr>(object_node)->obj();
auto attr_str = GetValue<std::string>(GetValueNode(attr_node));
return parse::ResolveMsClassWithAttr(optimizer->manager(), ms_class, attr_str, node);
}
// {prim::kPrimGetAttr, bool, attr}
if (IsValueNode<BoolImm>(object_node)) {
return object_node;

View File

@ -668,6 +668,9 @@ bool IsCellInstance(const py::object &obj) {
return is_cell;
}
// Check if the object is MsClass instance.
bool IsMsClassInstance(const py::object &obj) { return py::hasattr(obj, PYTHON_MS_CLASS); }
// Check if the object is class type.
bool IsClassType(const py::object &obj) {
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);

View File

@ -44,6 +44,7 @@ ResolveTypeDef GetObjType(const py::object &obj);
ClassInstanceTypeDef GetClassInstanceType(const py::object &obj);
bool IsCellInstance(const py::object &obj);
bool IsMsClassInstance(const py::object &obj);
bool IsClassType(const py::object &obj);
py::object CreatePythonObject(const py::object &type, const py::tuple &args_kwargs);
py::object CallPythonScript(const py::object &script, const py::tuple &args_kwargs);

View File

@ -75,6 +75,8 @@ const char PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL[] = "get_class_member_namespac
const char PYTHON_MOD_GET_PARSE_METHOD[] = "get_parse_method_of_class";
const char PYTHON_MOD_GET_BPROP_METHOD[] = "get_bprop_method_of_class";
const char PYTHON_MOD_GET_OBJECT_DESCRIPTION[] = "get_object_description";
const char PYTHON_MOD_IS_CELL_LIST[] = "is_cell_list";
const char PYTHON_MOD_CONVERT_CELL_LIST_TO_SEQUENCE[] = "convert_cell_list_to_sequence";
const char PYTHON_MOD_GET_ITEM_FROM_SEQUENCE[] = "get_obj_from_sequence";
const char PYTHON_MOD_CONVERT_TO_MS_TENSOR[] = "convert_to_ms_tensor";
const char PYTHON_MOD_EVAL_PY_SCRIPT[] = "eval_script";

View File

@ -376,71 +376,21 @@ AnfNodePtr ResolveObjectAndAddToManager(const FuncGraphManagerPtr &manager, cons
}
} // namespace
// Get python object with index from a list or the whole list if the index is not fixed.
py::object GetObjectFromSequence(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node,
const AnfNodePtr &index_node) {
MS_EXCEPTION_IF_NULL(node);
TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info()));
py::object obj = GetSymbolObject(name_space, symbol, node);
if (!py::isinstance<py::list>(obj) && !py::isinstance<py::tuple>(obj)) {
MS_LOG(EXCEPTION) << "Should not get item from non-sequence type, obj: " << py::str(obj);
}
MS_LOG(DEBUG) << "obj: " << py::str(obj) << ", index_node: " << index_node->ToString();
auto imm_value = GetValueNode<Int64ImmPtr>(index_node);
if (imm_value == nullptr) {
MS_LOG(DEBUG) << "The index is not a value node, so we return the whole list, node: " << node->DebugString()
<< ", index_node: " << index_node->DebugString();
// Index is not fixed, return the whole list.
return obj;
}
// It index is a value node, get the item of index directly.
const std::string fn = PYTHON_MOD_GET_ITEM_FROM_SEQUENCE;
const std::string module = "mindspore._extends.parse.parser";
auto index = imm_value->value();
py::object item_obj = python_adapter::GetPyFn(module, fn)(obj, py::int_(index));
return item_obj;
}
AnfNodePtr ResolveSequenceWithAttr(const FuncGraphManagerPtr &manager, const py::object &obj,
const AnfNodePtr &resolve_node, const AnfNodePtr &attr,
const CNodePtr &operand_cnode) {
std::vector<AnfNodePtr> inputs;
inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
auto sequence = obj.cast<py::sequence>();
// Incorporate if all elements of the sequence are Cell instances.
for (size_t i = 0; i < sequence.size(); ++i) {
if (!parse::data_converter::IsCellInstance(sequence[i])) {
return nullptr;
}
// Resolve Cell instance.
auto res = parse::ResolveCellWithAttr(manager, sequence[i], resolve_node, attr);
inputs.emplace_back(res);
}
constexpr auto prim_index = 0;
constexpr auto index_index = 2;
auto fg = operand_cnode->func_graph();
MS_EXCEPTION_IF_NULL(fg);
auto make_tuple_node = fg->NewCNodeInOrder(inputs);
return fg->NewCNodeInOrder({operand_cnode->input(prim_index), make_tuple_node, operand_cnode->input(index_index)});
}
std::pair<parse::NameSpacePtr, parse::SymbolPtr> GetNamespaceAndSymbol(const AnfNodePtr &node) {
std::pair<NameSpacePtr, SymbolPtr> GetNamespaceAndSymbol(const AnfNodePtr &node) {
if (IsPrimitiveCNode(node, prim::kPrimResolve)) {
auto resolve_cnode = node->cast<CNodePtr>();
constexpr size_t namespace_index = 1;
auto namespace_node = resolve_cnode->input(namespace_index);
constexpr size_t symbol_index = 2;
auto symbol_node = resolve_cnode->input(symbol_index);
if (!IsValueNode<parse::NameSpace>(namespace_node) || !IsValueNode<parse::Symbol>(symbol_node)) {
if (!IsValueNode<NameSpace>(namespace_node) || !IsValueNode<Symbol>(symbol_node)) {
MS_LOG(EXCEPTION) << "Unexpected type, namespace: " << namespace_node->ToString()
<< ", symbol: " << symbol_node->ToString();
}
// Deal with the case of GetAttr from a class member,
// and avoid the case of GetAttr from self (the result of ParseSuper).
auto name_space = GetValueNode<parse::NameSpacePtr>(namespace_node);
auto symbol = GetValueNode<parse::SymbolPtr>(symbol_node);
auto name_space = GetValueNode<NameSpacePtr>(namespace_node);
auto symbol = GetValueNode<SymbolPtr>(symbol_node);
return {name_space, symbol};
}
constexpr auto recursive_level = 2;
@ -491,9 +441,8 @@ AnfNodePtr ResolveCellWithAttr(const FuncGraphManagerPtr &manager, const py::obj
return res_node;
}
const std::string fn = PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL;
const std::string module = "mindspore._extends.parse.parser";
py::object namespace_obj = python_adapter::GetPyFn(module, fn)(obj);
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
py::object namespace_obj = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, obj);
auto new_namespace = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_obj);
std::string attr_as_string = GetValueNode<StringImmPtr>(attr)->value();
auto new_symbol = std::make_shared<Symbol>(attr_as_string);
@ -506,13 +455,99 @@ AnfNodePtr ResolveCellWithAttr(const FuncGraphManagerPtr &manager, const py::obj
return resolved_node;
}
AnfNodePtr ResolveSequenceWithAttr(const FuncGraphManagerPtr &manager, const py::object &obj,
const AnfNodePtr &resolve_node, const AnfNodePtr &attr,
const CNodePtr &operand_cnode) {
std::vector<AnfNodePtr> inputs;
inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
auto sequence = obj.cast<py::sequence>();
// Incorporate if all elements of the sequence are Cell instances or MsClass instances.
size_t count_cell = 0;
size_t count_msclass = 0;
size_t sequence_size = sequence.size();
for (size_t i = 0; i < sequence_size; ++i) {
if (data_converter::IsCellInstance(sequence[i])) {
++count_cell;
} else if (data_converter::IsMsClassInstance(sequence[i])) {
++count_msclass;
}
}
if (count_cell == sequence_size) {
// Resolve Cell instances.
for (size_t i = 0; i < sequence_size; ++i) {
auto res = ResolveCellWithAttr(manager, sequence[i], resolve_node, attr);
inputs.emplace_back(res);
}
} else if (count_msclass == sequence_size) {
// Resolve MsClass instances.
for (size_t i = 0; i < sequence_size; ++i) {
auto attr_str = GetValue<std::string>(GetValueNode(attr));
auto res = ResolveMsClassWithAttr(manager, sequence[i], attr_str, operand_cnode);
inputs.emplace_back(res);
}
} else {
return nullptr;
}
constexpr auto prim_index = 0;
constexpr auto index_index = 2;
auto fg = operand_cnode->func_graph();
MS_EXCEPTION_IF_NULL(fg);
auto make_tuple_node = fg->NewCNodeInOrder(inputs);
return fg->NewCNodeInOrder({operand_cnode->input(prim_index), make_tuple_node, operand_cnode->input(index_index)});
}
AnfNodePtr ResolveSymbolWithAttr(const FuncGraphManagerPtr &manager, const AnfNodePtr &object_node,
const AnfNodePtr &attr_node, const AnfNodePtr &node) {
// {prim::kPrimGetAttr, {prim::kPrimResolve, namespace, symbol}, attr}
auto [name_space, symbol] = GetNamespaceAndSymbol(object_node);
auto module_name = name_space->module();
constexpr std::string_view parse_super_name = "namespace";
if (module_name.find(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER) != std::string::npos &&
symbol->symbol() != parse_super_name) {
auto symbol_obj = GetSymbolObject(name_space, symbol, node);
return ResolveCellWithAttr(manager, symbol_obj, object_node, attr_node);
}
return nullptr;
}
// Get python object with index from a list or the whole list if the index is not fixed.
py::object GetObjectFromSequence(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node,
const AnfNodePtr &index_node) {
MS_EXCEPTION_IF_NULL(node);
TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info()));
py::object obj = GetSymbolObject(name_space, symbol, node);
// If obj is nn.CellList, convert it to sequence.
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
bool is_celllist = py::cast<bool>(python_adapter::CallPyModFn(mod, PYTHON_MOD_IS_CELL_LIST, obj));
if (is_celllist) {
obj = python_adapter::CallPyModFn(mod, PYTHON_MOD_CONVERT_CELL_LIST_TO_SEQUENCE, obj);
}
if (!py::isinstance<py::list>(obj) && !py::isinstance<py::tuple>(obj)) {
MS_LOG(EXCEPTION) << "Should not get item from non-sequence type, obj: " << py::str(obj);
}
MS_LOG(DEBUG) << "obj: " << py::str(obj) << ", index_node: " << index_node->ToString();
auto imm_value = GetValueNode<Int64ImmPtr>(index_node);
if (imm_value == nullptr) {
MS_LOG(DEBUG) << "The index is not a value node, so we return the whole list, node: " << node->DebugString()
<< ", index_node: " << index_node->DebugString();
// Index is not fixed, return the whole list.
return obj;
}
// It index is a value node, get the item of index directly.
py::object item_obj =
python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_ITEM_FROM_SEQUENCE, obj, py::int_(imm_value->value()));
return item_obj;
}
bool IsResolveNodeWithGetItem(const AnfNodePtr &node) {
// Check if the node matches: {prim::kPrim::Resolve, ..., 'getitem'}.
if (IsPrimitiveCNode(node, prim::kPrimResolve)) {
constexpr size_t symbol_index = 2;
constexpr auto getitem_symbol = "getitem";
auto cnode = node->cast<CNodePtr>();
auto symbol = GetValueNode<parse::SymbolPtr>(cnode->input(symbol_index));
auto symbol = GetValueNode<SymbolPtr>(cnode->input(symbol_index));
return symbol->symbol() == getitem_symbol;
}
return false;
@ -531,21 +566,58 @@ bool IsGetItemCNode(const AnfNodePtr &node) {
return IsResolveNodeWithGetItem(cnode->input(prim_index));
}
AnfNodePtr ResolveMsClassWithAttr(const FuncGraphManagerPtr &manager, const MsClassObjectPtr &ms_class,
AnfNodePtr ResolveGetItemInner(const FuncGraphManagerPtr &manager, const AnfNodePtr &data_node,
const AnfNodePtr &index_node, const CNodePtr &getitem_cnode,
const AnfNodePtr &attr_node) {
auto [name_space, symbol] = GetNamespaceAndSymbol(data_node);
auto obj = GetObjectFromSequence(name_space, symbol, data_node, index_node);
if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
return ResolveSequenceWithAttr(manager, obj, data_node, attr_node, getitem_cnode);
}
return ResolveCellWithAttr(manager, obj, data_node, attr_node);
}
AnfNodePtr ResolveGetItemWithAttr(const FuncGraphManagerPtr &manager, const AnfNodePtr &getitem_node,
const AnfNodePtr &attr_node, const AnfNodePtr &node) {
// {prim::kPrimGetAttr, {getitem, {prim::kPrimResolve, namespace, symbol}, index}, attr}
// {prim::kPrimGetAttr, {getitem, {prim::kPrimGetAttr, ResolveNode, member}, index}, attr}
constexpr auto data_index = 1;
constexpr auto index_index = 2;
auto getitem_cnode = getitem_node->cast<CNodePtr>();
auto data_node = getitem_cnode->input(data_index);
auto index_node = getitem_cnode->input(index_index);
if (IsPrimitiveCNode(data_node, prim::kPrimResolve)) {
return ResolveGetItemInner(manager, data_node, index_node, getitem_cnode, attr_node);
}
if (IsPrimitiveCNode(data_node, prim::kPrimGetAttr)) {
auto getattr_cnode = data_node->cast<CNodePtr>();
auto resolve_node = getattr_cnode->input(data_index);
auto member_node = getattr_cnode->input(index_index);
if (IsPrimitiveCNode(resolve_node, prim::kPrimResolve)) {
// Check if the result is a new resolve node.
auto item_node = ResolveSymbolWithAttr(manager, resolve_node, member_node, node);
if (IsPrimitiveCNode(item_node, prim::kPrimResolve)) {
return ResolveGetItemInner(manager, item_node, index_node, getitem_cnode, attr_node);
}
}
}
return nullptr;
}
AnfNodePtr ResolveMsClassWithAttr(const FuncGraphManagerPtr &manager, const py::object &cls_obj,
const std::string &attr, const AnfNodePtr &node) {
// Get attribute or method from ms_class obj.
MS_EXCEPTION_IF_NULL(manager);
MS_EXCEPTION_IF_NULL(node);
MS_LOG(DEBUG) << "Resolve ms_class obj (" << ms_class->name() << ") with attr " << attr << ".";
MS_LOG(DEBUG) << "Resolve ms_class obj (" << py::str(cls_obj) << ") with attr " << attr << ".";
TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info()));
constexpr size_t prefix_index = 0;
if (attr.size() > 0 && attr[prefix_index] == '_') {
MS_LOG(EXCEPTION) << attr << " is a private variable or magic method, which is not supported.";
}
py::object cls_obj = ms_class->obj();
if (!py::hasattr(cls_obj, common::SafeCStr(attr))) {
MS_LOG(EXCEPTION) << ms_class->name() << " has not attribute: " << attr << ".";
MS_LOG(EXCEPTION) << py::str(cls_obj) << " has not attribute: " << attr << ".";
}
py::object attr_obj = py::getattr(cls_obj, common::SafeCStr(attr));
AnfNodePtr res_node = ResolveObjectAndAddToManager(manager, attr_obj, node);

View File

@ -163,23 +163,14 @@ class ClassType final : public PyObjectWrapper {
};
using ClassTypePtr = std::shared_ptr<ClassType>;
// Get python object with index from a list or the whole list if the index is not fixed.
py::object GetObjectFromSequence(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node,
const AnfNodePtr &index_node);
std::pair<parse::NameSpacePtr, parse::SymbolPtr> GetNamespaceAndSymbol(const AnfNodePtr &node);
// Get resolved python object by namespace and symbol.
py::object GetSymbolObject(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node);
// Resolve symbol in namespace.
AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol,
const AnfNodePtr &node);
// Resolve Cell with attr name.
AnfNodePtr ResolveCellWithAttr(const FuncGraphManagerPtr &manager, const py::object &obj, const AnfNodePtr &node,
const AnfNodePtr &attr);
AnfNodePtr ResolveSequenceWithAttr(const FuncGraphManagerPtr &manager, const py::object &obj,
const AnfNodePtr &resolve_node, const AnfNodePtr &attr,
const CNodePtr &operand_cnode);
AnfNodePtr ResolveMsClassWithAttr(const FuncGraphManagerPtr &manager, const MsClassObjectPtr &ms_class,
AnfNodePtr ResolveSymbolWithAttr(const FuncGraphManagerPtr &manager, const AnfNodePtr &object_node,
const AnfNodePtr &attr_node, const AnfNodePtr &node);
AnfNodePtr ResolveGetItemWithAttr(const FuncGraphManagerPtr &manager, const AnfNodePtr &getitem_node,
const AnfNodePtr &attr_node, const AnfNodePtr &node);
AnfNodePtr ResolveMsClassWithAttr(const FuncGraphManagerPtr &manager, const py::object &ms_class,
const std::string &attr, const AnfNodePtr &node);
// Check if node is cnode with getitem.

View File

@ -1355,7 +1355,7 @@ EvalResultPtr GetEvaluatedValueForMsClassAttrOrMethod(const AnalysisEnginePtr &e
// Get the attr/method of ms_class object.
auto out_node = out_conf->node();
FuncGraphPtr func_graph = out_node->func_graph();
auto new_node = ResolveMsClassWithAttr(func_graph->manager(), ms_class, item_name, out_node);
auto new_node = parse::ResolveMsClassWithAttr(func_graph->manager(), ms_class->obj(), item_name, out_node);
// Replace old node with the resolved new node in order list.
func_graph->ReplaceInOrder(out_node, new_node);
AnalysisEnginePtr eng = out_conf->engine();

View File

@ -23,8 +23,8 @@ from .parser import (Parser, create_instance, is_supported_create_instance_type,
get_operation_symbol, get_operation_namespace_symbol, get_parse_method_of_class, get_scope_name,
eval_script, expand_expr_statement, is_class_member, parse_cb, resolve_symbol,
convert_to_ms_tensor, get_object_description, get_class_attr_namespace_symbol, get_ms_class_name,
is_class_type, get_dataclass_attributes, get_dataclass_methods, check_obj_bool,
python_isinstance, ms_isinstance)
is_class_type, get_dataclass_attributes, get_dataclass_methods, check_obj_bool, is_cell_list,
python_isinstance, ms_isinstance, convert_cell_list_to_sequence, get_obj_from_sequence)
__all__ = ['Parser', 'create_instance', 'is_supported_create_instance_type', 'generate_scope',
'get_bprop_method_of_class', 'get_class_instance_type', 'get_class_member_namespace_symbol',
@ -33,5 +33,5 @@ __all__ = ['Parser', 'create_instance', 'is_supported_create_instance_type', 'ge
'get_operation_symbol', 'get_operation_namespace_symbol', 'get_parse_method_of_class', 'get_scope_name',
'eval_script', 'expand_expr_statement', 'is_class_member', 'parse_cb', 'resolve_symbol',
'convert_to_ms_tensor', 'get_object_description', 'get_class_attr_namespace_symbol', 'get_ms_class_name',
'is_class_type', 'get_dataclass_attributes', 'get_dataclass_methods', 'check_obj_bool', 'python_isinstance',
'ms_isinstance']
'is_class_type', 'get_dataclass_attributes', 'get_dataclass_methods', 'check_obj_bool', 'is_cell_list',
'python_isinstance', 'ms_isinstance', 'convert_cell_list_to_sequence', 'get_obj_from_sequence']

View File

@ -426,6 +426,21 @@ def ms_isinstance(x, cmp_type):
return isinstance(x, pytype_to_mstype.get(cmp_type))
def is_cell_list(obj):
"""Check if obj is nn.CellList"""
return isinstance(obj, nn.CellList)
def convert_cell_list_to_sequence(obj):
"""Convert nn.CellList to sequence."""
if not isinstance(obj, nn.CellList):
raise TypeError(f"Obj should be nn.CellList, but got {obj}")
if not hasattr(obj, "_cells"):
raise AttributeError(f"nn.CellList is missing _cells property.")
cells = getattr(obj, "_cells")
return list(cells.values())
def get_obj_from_sequence(obj, index):
"""Implement `tuple_getitem`."""
if not isinstance(obj, (tuple, list)):

View File

@ -15,7 +15,7 @@
""" test a list of cell, and getattr by its item """
import pytest
import numpy as np
from mindspore import context, nn, dtype, Tensor, ms_function
from mindspore import context, nn, dtype, Tensor, ms_function, ms_class
from mindspore.ops import operations as P
@ -24,6 +24,12 @@ class Actor(nn.Cell):
return x + y
@ms_class
class Actor2:
def act(self, x, y):
return x + y
class Trainer(nn.Cell):
def __init__(self, net_list):
super(Trainer, self).__init__()
@ -33,6 +39,20 @@ class Trainer(nn.Cell):
return self.net_list[0].act(x, y)
def verify_list_item_getattr(trainer, expect_res):
x = Tensor([3], dtype=dtype.float32)
y = Tensor([6], dtype=dtype.float32)
res = trainer(x, y)
print(f'res: {res}')
assert np.array_equal(res.asnumpy(), expect_res.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_list_item_getattr():
"""
Feature: getattr by the item from list of cell.
@ -42,15 +62,16 @@ def test_list_item_getattr():
context.set_context(mode=context.GRAPH_MODE)
actor_list = [Actor()]
trainer = Trainer(actor_list)
x = Tensor([3], dtype=dtype.float32)
y = Tensor([6], dtype=dtype.float32)
res = trainer(x, y)
print(f'res: {res}')
expect_res = Tensor([9], dtype=dtype.float32)
assert np.array_equal(res.asnumpy(), expect_res.asnumpy())
verify_list_item_getattr(trainer, expect_res)
@pytest.mark.skip(reason='Not support in graph mode yet')
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cell_list_getattr():
"""
Feature: getattr by the item from nn.CellList.
@ -62,12 +83,27 @@ def test_cell_list_getattr():
for _ in range(3):
actor_list.append(Actor())
trainer = Trainer(actor_list)
x = Tensor([3], dtype=dtype.float32)
y = Tensor([6], dtype=dtype.float32)
res = trainer(x, y)
print(f'res: {res}')
expect_res = Tensor([9], dtype=dtype.float32)
assert np.array_equal(res.asnumpy(), expect_res.asnumpy())
verify_list_item_getattr(trainer, expect_res)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_msclass_list_getattr():
"""
Feature: getattr by the item from list of ms_class.
Description: Support RL use method in graph mode.
Expectation: No exception.
"""
context.set_context(mode=context.GRAPH_MODE)
actor_list = [Actor2()]
trainer = Trainer(actor_list)
expect_res = Tensor([9], dtype=dtype.float32)
verify_list_item_getattr(trainer, expect_res)
class Trainer2(nn.Cell):
@ -86,6 +122,12 @@ class Trainer2(nn.Cell):
return sum_value
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_list_item_getattr2():
"""
Feature: getattr by the item from list of cell with a Tensor variable.
@ -95,15 +137,16 @@ def test_list_item_getattr2():
context.set_context(mode=context.GRAPH_MODE)
actor_list = [Actor(), Actor(), Actor()]
trainer = Trainer2(actor_list)
x = Tensor([3], dtype=dtype.float32)
y = Tensor([6], dtype=dtype.float32)
res = trainer(x, y)
print(f'res: {res}')
expect_res = Tensor([27], dtype=dtype.float32)
assert np.array_equal(res.asnumpy(), expect_res.asnumpy())
verify_list_item_getattr(trainer, expect_res)
@pytest.mark.skip(reason='Not support in graph mode yet')
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cell_list_getattr2():
"""
Feature: getattr by the item from nn.CellList.
@ -115,12 +158,27 @@ def test_cell_list_getattr2():
for _ in range(3):
actor_list.append(Actor())
trainer = Trainer2(actor_list)
x = Tensor([3], dtype=dtype.float32)
y = Tensor([6], dtype=dtype.float32)
res = trainer(x, y)
print(f'res: {res}')
expect_res = Tensor([27], dtype=dtype.float32)
assert np.array_equal(res.asnumpy(), expect_res.asnumpy())
verify_list_item_getattr(trainer, expect_res)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_msclass_list_getattr2():
"""
Feature: getattr by the item from list of ms_class with a Tensor variable.
Description: Support RL use method in graph mode.
Expectation: No exception.
"""
context.set_context(mode=context.GRAPH_MODE)
actor_list = [Actor2(), Actor2(), Actor2()]
trainer = Trainer2(actor_list)
expect_res = Tensor([27], dtype=dtype.float32)
verify_list_item_getattr(trainer, expect_res)
class MSRL(nn.Cell):
@ -154,7 +212,20 @@ class Trainer3(nn.Cell):
return output
@pytest.mark.skip(reason='Not support in graph mode yet')
def verify_list_item_getattr2(trainer, expect_res):
x = Tensor([2], dtype=dtype.int32)
y = Tensor([3], dtype=dtype.int32)
res = trainer.test(x, y)
print(f'res: {res}')
assert np.array_equal(res.asnumpy(), expect_res.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_list_item_getattr3():
"""
Feature: getattr by the item from list of cell.
@ -168,15 +239,16 @@ def test_list_item_getattr3():
agent_list.append(Agent(actor))
msrl = MSRL(agent_list)
trainer = Trainer3(msrl)
x = Tensor([2], dtype=dtype.int32)
y = Tensor([3], dtype=dtype.int32)
res = trainer.test(x, y)
print(f'res: {res}')
expect_res = Tensor([15], dtype=dtype.int32)
assert np.array_equal(res.asnumpy(), expect_res.asnumpy())
verify_list_item_getattr2(trainer, expect_res)
@pytest.mark.skip(reason='Not support in graph mode yet')
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cell_list_getattr3():
"""
Feature: getattr by the item from list of cell.
@ -190,9 +262,28 @@ def test_cell_list_getattr3():
agent_list.append(Agent(actor))
msrl = MSRL(agent_list)
trainer = Trainer3(msrl)
x = Tensor([2], dtype=dtype.float32)
y = Tensor([3], dtype=dtype.float32)
res = trainer.test(x, y)
print(f'res: {res}')
expect_res = Tensor([15], dtype=dtype.float32)
assert np.array_equal(res.asnumpy(), expect_res.asnumpy())
expect_res = Tensor([15], dtype=dtype.int32)
verify_list_item_getattr2(trainer, expect_res)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_msclass_list_getattr3():
"""
Feature: getattr by the item from list of ms_class.
Description: Support RL use method in graph mode.
Expectation: No exception.
"""
context.set_context(mode=context.GRAPH_MODE)
agent_list = []
for _ in range(3):
actor = Actor2()
agent_list.append(Agent(actor))
msrl = MSRL(agent_list)
trainer = Trainer3(msrl)
expect_res = Tensor([15], dtype=dtype.int32)
verify_list_item_getattr2(trainer, expect_res)