forked from mindspore-Ecosystem/mindspore
Support getattr of list of CellList or MsClass list.
This commit is contained in:
parent
517ece4ac8
commit
198cabb999
|
@ -31,35 +31,15 @@ AnfNodePtr Resolver::operator()(const OptimizerPtr &optimizer, const AnfNodePtr
|
|||
auto object_node = object.GetNode(node);
|
||||
auto attr_node = attr.GetNode(node);
|
||||
|
||||
// {prim::kPrimGetAttr, {getitem, {prim::kPrimResolve, namespace, symbol}, index}, attr}
|
||||
// {prim::kPrimGetAttr, {getitem, {prim::kPrimGetAttr, ResolveNode, member}, index}, attr}
|
||||
if (parse::IsGetItemCNode(object_node)) {
|
||||
return parse::ResolveGetItemWithAttr(optimizer->manager(), object_node, attr_node, node);
|
||||
}
|
||||
// {prim::kPrimGetAttr, {prim::kPrimResolve, namespace, symbol}, attr}
|
||||
if (IsPrimitiveCNode(object_node, prim::kPrimResolve)) {
|
||||
auto [name_space, symbol] = parse::GetNamespaceAndSymbol(object_node);
|
||||
auto module_name = name_space->module();
|
||||
constexpr std::string_view parse_super_name = "namespace";
|
||||
if (module_name.find(parse::RESOLVE_NAMESPACE_NAME_CLASS_MEMBER) != std::string::npos &&
|
||||
symbol->symbol() != parse_super_name) {
|
||||
auto symbol_obj = parse::GetSymbolObject(name_space, symbol, node);
|
||||
return parse::ResolveCellWithAttr(optimizer->manager(), symbol_obj, object_node, attr_node);
|
||||
}
|
||||
return parse::ResolveSymbolWithAttr(optimizer->manager(), object_node, attr_node, node);
|
||||
}
|
||||
|
||||
// {prim::kPrimGetAttr, {getitem, {prim::kPrimResolve, namespace, symbol}, index}, attr}
|
||||
if (parse::IsGetItemCNode(object_node)) {
|
||||
auto getitem_cnode = object_node->cast<CNodePtr>();
|
||||
constexpr auto resolve_index = 1;
|
||||
constexpr auto index_index = 2;
|
||||
auto resolve_node = getitem_cnode->input(resolve_index);
|
||||
auto index_node = getitem_cnode->input(index_index);
|
||||
if (IsPrimitiveCNode(resolve_node, prim::kPrimResolve)) {
|
||||
auto [name_space, symbol] = parse::GetNamespaceAndSymbol(resolve_node);
|
||||
auto obj = parse::GetObjectFromSequence(name_space, symbol, resolve_node, index_node);
|
||||
if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
|
||||
return parse::ResolveSequenceWithAttr(optimizer->manager(), obj, resolve_node, attr_node, getitem_cnode);
|
||||
}
|
||||
return parse::ResolveCellWithAttr(optimizer->manager(), obj, resolve_node, attr_node);
|
||||
}
|
||||
}
|
||||
|
||||
// {prim::kPrimGetAttr, namespace, attr}
|
||||
if (IsValueNode<parse::NameSpace>(object_node)) {
|
||||
auto name_space = GetValueNode<parse::NameSpacePtr>(object_node);
|
||||
|
@ -67,14 +47,12 @@ AnfNodePtr Resolver::operator()(const OptimizerPtr &optimizer, const AnfNodePtr
|
|||
parse::SymbolPtr symbol = std::make_shared<parse::Symbol>(attr_str);
|
||||
return parse::ResolveSymbol(optimizer->manager(), name_space, symbol, node);
|
||||
}
|
||||
|
||||
// {prim::kPrimGetAttr, MsClassObject, attr}
|
||||
if (IsValueNode<parse::MsClassObject>(object_node)) {
|
||||
auto ms_class = GetValueNode<parse::MsClassObjectPtr>(object_node);
|
||||
auto ms_class = GetValueNode<parse::MsClassObjectPtr>(object_node)->obj();
|
||||
auto attr_str = GetValue<std::string>(GetValueNode(attr_node));
|
||||
return parse::ResolveMsClassWithAttr(optimizer->manager(), ms_class, attr_str, node);
|
||||
}
|
||||
|
||||
// {prim::kPrimGetAttr, bool, attr}
|
||||
if (IsValueNode<BoolImm>(object_node)) {
|
||||
return object_node;
|
||||
|
|
|
@ -668,6 +668,9 @@ bool IsCellInstance(const py::object &obj) {
|
|||
return is_cell;
|
||||
}
|
||||
|
||||
// Check if the object is MsClass instance.
|
||||
bool IsMsClassInstance(const py::object &obj) { return py::hasattr(obj, PYTHON_MS_CLASS); }
|
||||
|
||||
// Check if the object is class type.
|
||||
bool IsClassType(const py::object &obj) {
|
||||
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
|
||||
|
|
|
@ -44,6 +44,7 @@ ResolveTypeDef GetObjType(const py::object &obj);
|
|||
ClassInstanceTypeDef GetClassInstanceType(const py::object &obj);
|
||||
|
||||
bool IsCellInstance(const py::object &obj);
|
||||
bool IsMsClassInstance(const py::object &obj);
|
||||
bool IsClassType(const py::object &obj);
|
||||
py::object CreatePythonObject(const py::object &type, const py::tuple &args_kwargs);
|
||||
py::object CallPythonScript(const py::object &script, const py::tuple &args_kwargs);
|
||||
|
|
|
@ -75,6 +75,8 @@ const char PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL[] = "get_class_member_namespac
|
|||
const char PYTHON_MOD_GET_PARSE_METHOD[] = "get_parse_method_of_class";
|
||||
const char PYTHON_MOD_GET_BPROP_METHOD[] = "get_bprop_method_of_class";
|
||||
const char PYTHON_MOD_GET_OBJECT_DESCRIPTION[] = "get_object_description";
|
||||
const char PYTHON_MOD_IS_CELL_LIST[] = "is_cell_list";
|
||||
const char PYTHON_MOD_CONVERT_CELL_LIST_TO_SEQUENCE[] = "convert_cell_list_to_sequence";
|
||||
const char PYTHON_MOD_GET_ITEM_FROM_SEQUENCE[] = "get_obj_from_sequence";
|
||||
const char PYTHON_MOD_CONVERT_TO_MS_TENSOR[] = "convert_to_ms_tensor";
|
||||
const char PYTHON_MOD_EVAL_PY_SCRIPT[] = "eval_script";
|
||||
|
|
|
@ -376,71 +376,21 @@ AnfNodePtr ResolveObjectAndAddToManager(const FuncGraphManagerPtr &manager, cons
|
|||
}
|
||||
} // namespace
|
||||
|
||||
// Get python object with index from a list or the whole list if the index is not fixed.
|
||||
py::object GetObjectFromSequence(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node,
|
||||
const AnfNodePtr &index_node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info()));
|
||||
py::object obj = GetSymbolObject(name_space, symbol, node);
|
||||
if (!py::isinstance<py::list>(obj) && !py::isinstance<py::tuple>(obj)) {
|
||||
MS_LOG(EXCEPTION) << "Should not get item from non-sequence type, obj: " << py::str(obj);
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "obj: " << py::str(obj) << ", index_node: " << index_node->ToString();
|
||||
auto imm_value = GetValueNode<Int64ImmPtr>(index_node);
|
||||
if (imm_value == nullptr) {
|
||||
MS_LOG(DEBUG) << "The index is not a value node, so we return the whole list, node: " << node->DebugString()
|
||||
<< ", index_node: " << index_node->DebugString();
|
||||
// Index is not fixed, return the whole list.
|
||||
return obj;
|
||||
}
|
||||
// It index is a value node, get the item of index directly.
|
||||
const std::string fn = PYTHON_MOD_GET_ITEM_FROM_SEQUENCE;
|
||||
const std::string module = "mindspore._extends.parse.parser";
|
||||
auto index = imm_value->value();
|
||||
py::object item_obj = python_adapter::GetPyFn(module, fn)(obj, py::int_(index));
|
||||
return item_obj;
|
||||
}
|
||||
|
||||
AnfNodePtr ResolveSequenceWithAttr(const FuncGraphManagerPtr &manager, const py::object &obj,
|
||||
const AnfNodePtr &resolve_node, const AnfNodePtr &attr,
|
||||
const CNodePtr &operand_cnode) {
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
auto sequence = obj.cast<py::sequence>();
|
||||
// Incorporate if all elements of the sequence are Cell instances.
|
||||
for (size_t i = 0; i < sequence.size(); ++i) {
|
||||
if (!parse::data_converter::IsCellInstance(sequence[i])) {
|
||||
return nullptr;
|
||||
}
|
||||
// Resolve Cell instance.
|
||||
auto res = parse::ResolveCellWithAttr(manager, sequence[i], resolve_node, attr);
|
||||
inputs.emplace_back(res);
|
||||
}
|
||||
|
||||
constexpr auto prim_index = 0;
|
||||
constexpr auto index_index = 2;
|
||||
auto fg = operand_cnode->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
auto make_tuple_node = fg->NewCNodeInOrder(inputs);
|
||||
return fg->NewCNodeInOrder({operand_cnode->input(prim_index), make_tuple_node, operand_cnode->input(index_index)});
|
||||
}
|
||||
|
||||
std::pair<parse::NameSpacePtr, parse::SymbolPtr> GetNamespaceAndSymbol(const AnfNodePtr &node) {
|
||||
std::pair<NameSpacePtr, SymbolPtr> GetNamespaceAndSymbol(const AnfNodePtr &node) {
|
||||
if (IsPrimitiveCNode(node, prim::kPrimResolve)) {
|
||||
auto resolve_cnode = node->cast<CNodePtr>();
|
||||
constexpr size_t namespace_index = 1;
|
||||
auto namespace_node = resolve_cnode->input(namespace_index);
|
||||
constexpr size_t symbol_index = 2;
|
||||
auto symbol_node = resolve_cnode->input(symbol_index);
|
||||
if (!IsValueNode<parse::NameSpace>(namespace_node) || !IsValueNode<parse::Symbol>(symbol_node)) {
|
||||
if (!IsValueNode<NameSpace>(namespace_node) || !IsValueNode<Symbol>(symbol_node)) {
|
||||
MS_LOG(EXCEPTION) << "Unexpected type, namespace: " << namespace_node->ToString()
|
||||
<< ", symbol: " << symbol_node->ToString();
|
||||
}
|
||||
// Deal with the case of GetAttr from a class member,
|
||||
// and avoid the case of GetAttr from self (the result of ParseSuper).
|
||||
auto name_space = GetValueNode<parse::NameSpacePtr>(namespace_node);
|
||||
auto symbol = GetValueNode<parse::SymbolPtr>(symbol_node);
|
||||
auto name_space = GetValueNode<NameSpacePtr>(namespace_node);
|
||||
auto symbol = GetValueNode<SymbolPtr>(symbol_node);
|
||||
return {name_space, symbol};
|
||||
}
|
||||
constexpr auto recursive_level = 2;
|
||||
|
@ -491,9 +441,8 @@ AnfNodePtr ResolveCellWithAttr(const FuncGraphManagerPtr &manager, const py::obj
|
|||
return res_node;
|
||||
}
|
||||
|
||||
const std::string fn = PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL;
|
||||
const std::string module = "mindspore._extends.parse.parser";
|
||||
py::object namespace_obj = python_adapter::GetPyFn(module, fn)(obj);
|
||||
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
|
||||
py::object namespace_obj = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, obj);
|
||||
auto new_namespace = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_obj);
|
||||
std::string attr_as_string = GetValueNode<StringImmPtr>(attr)->value();
|
||||
auto new_symbol = std::make_shared<Symbol>(attr_as_string);
|
||||
|
@ -506,13 +455,99 @@ AnfNodePtr ResolveCellWithAttr(const FuncGraphManagerPtr &manager, const py::obj
|
|||
return resolved_node;
|
||||
}
|
||||
|
||||
AnfNodePtr ResolveSequenceWithAttr(const FuncGraphManagerPtr &manager, const py::object &obj,
|
||||
const AnfNodePtr &resolve_node, const AnfNodePtr &attr,
|
||||
const CNodePtr &operand_cnode) {
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
auto sequence = obj.cast<py::sequence>();
|
||||
// Incorporate if all elements of the sequence are Cell instances or MsClass instances.
|
||||
size_t count_cell = 0;
|
||||
size_t count_msclass = 0;
|
||||
size_t sequence_size = sequence.size();
|
||||
for (size_t i = 0; i < sequence_size; ++i) {
|
||||
if (data_converter::IsCellInstance(sequence[i])) {
|
||||
++count_cell;
|
||||
} else if (data_converter::IsMsClassInstance(sequence[i])) {
|
||||
++count_msclass;
|
||||
}
|
||||
}
|
||||
if (count_cell == sequence_size) {
|
||||
// Resolve Cell instances.
|
||||
for (size_t i = 0; i < sequence_size; ++i) {
|
||||
auto res = ResolveCellWithAttr(manager, sequence[i], resolve_node, attr);
|
||||
inputs.emplace_back(res);
|
||||
}
|
||||
} else if (count_msclass == sequence_size) {
|
||||
// Resolve MsClass instances.
|
||||
for (size_t i = 0; i < sequence_size; ++i) {
|
||||
auto attr_str = GetValue<std::string>(GetValueNode(attr));
|
||||
auto res = ResolveMsClassWithAttr(manager, sequence[i], attr_str, operand_cnode);
|
||||
inputs.emplace_back(res);
|
||||
}
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
constexpr auto prim_index = 0;
|
||||
constexpr auto index_index = 2;
|
||||
auto fg = operand_cnode->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
auto make_tuple_node = fg->NewCNodeInOrder(inputs);
|
||||
return fg->NewCNodeInOrder({operand_cnode->input(prim_index), make_tuple_node, operand_cnode->input(index_index)});
|
||||
}
|
||||
|
||||
AnfNodePtr ResolveSymbolWithAttr(const FuncGraphManagerPtr &manager, const AnfNodePtr &object_node,
|
||||
const AnfNodePtr &attr_node, const AnfNodePtr &node) {
|
||||
// {prim::kPrimGetAttr, {prim::kPrimResolve, namespace, symbol}, attr}
|
||||
auto [name_space, symbol] = GetNamespaceAndSymbol(object_node);
|
||||
auto module_name = name_space->module();
|
||||
constexpr std::string_view parse_super_name = "namespace";
|
||||
if (module_name.find(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER) != std::string::npos &&
|
||||
symbol->symbol() != parse_super_name) {
|
||||
auto symbol_obj = GetSymbolObject(name_space, symbol, node);
|
||||
return ResolveCellWithAttr(manager, symbol_obj, object_node, attr_node);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Get python object with index from a list or the whole list if the index is not fixed.
|
||||
py::object GetObjectFromSequence(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node,
|
||||
const AnfNodePtr &index_node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info()));
|
||||
py::object obj = GetSymbolObject(name_space, symbol, node);
|
||||
// If obj is nn.CellList, convert it to sequence.
|
||||
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
|
||||
bool is_celllist = py::cast<bool>(python_adapter::CallPyModFn(mod, PYTHON_MOD_IS_CELL_LIST, obj));
|
||||
if (is_celllist) {
|
||||
obj = python_adapter::CallPyModFn(mod, PYTHON_MOD_CONVERT_CELL_LIST_TO_SEQUENCE, obj);
|
||||
}
|
||||
if (!py::isinstance<py::list>(obj) && !py::isinstance<py::tuple>(obj)) {
|
||||
MS_LOG(EXCEPTION) << "Should not get item from non-sequence type, obj: " << py::str(obj);
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "obj: " << py::str(obj) << ", index_node: " << index_node->ToString();
|
||||
auto imm_value = GetValueNode<Int64ImmPtr>(index_node);
|
||||
if (imm_value == nullptr) {
|
||||
MS_LOG(DEBUG) << "The index is not a value node, so we return the whole list, node: " << node->DebugString()
|
||||
<< ", index_node: " << index_node->DebugString();
|
||||
// Index is not fixed, return the whole list.
|
||||
return obj;
|
||||
}
|
||||
// It index is a value node, get the item of index directly.
|
||||
py::object item_obj =
|
||||
python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_ITEM_FROM_SEQUENCE, obj, py::int_(imm_value->value()));
|
||||
return item_obj;
|
||||
}
|
||||
|
||||
bool IsResolveNodeWithGetItem(const AnfNodePtr &node) {
|
||||
// Check if the node matches: {prim::kPrim::Resolve, ..., 'getitem'}.
|
||||
if (IsPrimitiveCNode(node, prim::kPrimResolve)) {
|
||||
constexpr size_t symbol_index = 2;
|
||||
constexpr auto getitem_symbol = "getitem";
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto symbol = GetValueNode<parse::SymbolPtr>(cnode->input(symbol_index));
|
||||
auto symbol = GetValueNode<SymbolPtr>(cnode->input(symbol_index));
|
||||
return symbol->symbol() == getitem_symbol;
|
||||
}
|
||||
return false;
|
||||
|
@ -531,21 +566,58 @@ bool IsGetItemCNode(const AnfNodePtr &node) {
|
|||
return IsResolveNodeWithGetItem(cnode->input(prim_index));
|
||||
}
|
||||
|
||||
AnfNodePtr ResolveMsClassWithAttr(const FuncGraphManagerPtr &manager, const MsClassObjectPtr &ms_class,
|
||||
AnfNodePtr ResolveGetItemInner(const FuncGraphManagerPtr &manager, const AnfNodePtr &data_node,
|
||||
const AnfNodePtr &index_node, const CNodePtr &getitem_cnode,
|
||||
const AnfNodePtr &attr_node) {
|
||||
auto [name_space, symbol] = GetNamespaceAndSymbol(data_node);
|
||||
auto obj = GetObjectFromSequence(name_space, symbol, data_node, index_node);
|
||||
if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
|
||||
return ResolveSequenceWithAttr(manager, obj, data_node, attr_node, getitem_cnode);
|
||||
}
|
||||
return ResolveCellWithAttr(manager, obj, data_node, attr_node);
|
||||
}
|
||||
|
||||
AnfNodePtr ResolveGetItemWithAttr(const FuncGraphManagerPtr &manager, const AnfNodePtr &getitem_node,
|
||||
const AnfNodePtr &attr_node, const AnfNodePtr &node) {
|
||||
// {prim::kPrimGetAttr, {getitem, {prim::kPrimResolve, namespace, symbol}, index}, attr}
|
||||
// {prim::kPrimGetAttr, {getitem, {prim::kPrimGetAttr, ResolveNode, member}, index}, attr}
|
||||
constexpr auto data_index = 1;
|
||||
constexpr auto index_index = 2;
|
||||
auto getitem_cnode = getitem_node->cast<CNodePtr>();
|
||||
auto data_node = getitem_cnode->input(data_index);
|
||||
auto index_node = getitem_cnode->input(index_index);
|
||||
if (IsPrimitiveCNode(data_node, prim::kPrimResolve)) {
|
||||
return ResolveGetItemInner(manager, data_node, index_node, getitem_cnode, attr_node);
|
||||
}
|
||||
if (IsPrimitiveCNode(data_node, prim::kPrimGetAttr)) {
|
||||
auto getattr_cnode = data_node->cast<CNodePtr>();
|
||||
auto resolve_node = getattr_cnode->input(data_index);
|
||||
auto member_node = getattr_cnode->input(index_index);
|
||||
if (IsPrimitiveCNode(resolve_node, prim::kPrimResolve)) {
|
||||
// Check if the result is a new resolve node.
|
||||
auto item_node = ResolveSymbolWithAttr(manager, resolve_node, member_node, node);
|
||||
if (IsPrimitiveCNode(item_node, prim::kPrimResolve)) {
|
||||
return ResolveGetItemInner(manager, item_node, index_node, getitem_cnode, attr_node);
|
||||
}
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
AnfNodePtr ResolveMsClassWithAttr(const FuncGraphManagerPtr &manager, const py::object &cls_obj,
|
||||
const std::string &attr, const AnfNodePtr &node) {
|
||||
// Get attribute or method from ms_class obj.
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_LOG(DEBUG) << "Resolve ms_class obj (" << ms_class->name() << ") with attr " << attr << ".";
|
||||
MS_LOG(DEBUG) << "Resolve ms_class obj (" << py::str(cls_obj) << ") with attr " << attr << ".";
|
||||
TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info()));
|
||||
|
||||
constexpr size_t prefix_index = 0;
|
||||
if (attr.size() > 0 && attr[prefix_index] == '_') {
|
||||
MS_LOG(EXCEPTION) << attr << " is a private variable or magic method, which is not supported.";
|
||||
}
|
||||
py::object cls_obj = ms_class->obj();
|
||||
if (!py::hasattr(cls_obj, common::SafeCStr(attr))) {
|
||||
MS_LOG(EXCEPTION) << ms_class->name() << " has not attribute: " << attr << ".";
|
||||
MS_LOG(EXCEPTION) << py::str(cls_obj) << " has not attribute: " << attr << ".";
|
||||
}
|
||||
py::object attr_obj = py::getattr(cls_obj, common::SafeCStr(attr));
|
||||
AnfNodePtr res_node = ResolveObjectAndAddToManager(manager, attr_obj, node);
|
||||
|
|
|
@ -163,23 +163,14 @@ class ClassType final : public PyObjectWrapper {
|
|||
};
|
||||
using ClassTypePtr = std::shared_ptr<ClassType>;
|
||||
|
||||
// Get python object with index from a list or the whole list if the index is not fixed.
|
||||
py::object GetObjectFromSequence(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node,
|
||||
const AnfNodePtr &index_node);
|
||||
std::pair<parse::NameSpacePtr, parse::SymbolPtr> GetNamespaceAndSymbol(const AnfNodePtr &node);
|
||||
|
||||
// Get resolved python object by namespace and symbol.
|
||||
py::object GetSymbolObject(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node);
|
||||
// Resolve symbol in namespace.
|
||||
AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol,
|
||||
const AnfNodePtr &node);
|
||||
// Resolve Cell with attr name.
|
||||
AnfNodePtr ResolveCellWithAttr(const FuncGraphManagerPtr &manager, const py::object &obj, const AnfNodePtr &node,
|
||||
const AnfNodePtr &attr);
|
||||
AnfNodePtr ResolveSequenceWithAttr(const FuncGraphManagerPtr &manager, const py::object &obj,
|
||||
const AnfNodePtr &resolve_node, const AnfNodePtr &attr,
|
||||
const CNodePtr &operand_cnode);
|
||||
AnfNodePtr ResolveMsClassWithAttr(const FuncGraphManagerPtr &manager, const MsClassObjectPtr &ms_class,
|
||||
AnfNodePtr ResolveSymbolWithAttr(const FuncGraphManagerPtr &manager, const AnfNodePtr &object_node,
|
||||
const AnfNodePtr &attr_node, const AnfNodePtr &node);
|
||||
AnfNodePtr ResolveGetItemWithAttr(const FuncGraphManagerPtr &manager, const AnfNodePtr &getitem_node,
|
||||
const AnfNodePtr &attr_node, const AnfNodePtr &node);
|
||||
AnfNodePtr ResolveMsClassWithAttr(const FuncGraphManagerPtr &manager, const py::object &ms_class,
|
||||
const std::string &attr, const AnfNodePtr &node);
|
||||
|
||||
// Check if node is cnode with getitem.
|
||||
|
|
|
@ -1355,7 +1355,7 @@ EvalResultPtr GetEvaluatedValueForMsClassAttrOrMethod(const AnalysisEnginePtr &e
|
|||
// Get the attr/method of ms_class object.
|
||||
auto out_node = out_conf->node();
|
||||
FuncGraphPtr func_graph = out_node->func_graph();
|
||||
auto new_node = ResolveMsClassWithAttr(func_graph->manager(), ms_class, item_name, out_node);
|
||||
auto new_node = parse::ResolveMsClassWithAttr(func_graph->manager(), ms_class->obj(), item_name, out_node);
|
||||
// Replace old node with the resolved new node in order list.
|
||||
func_graph->ReplaceInOrder(out_node, new_node);
|
||||
AnalysisEnginePtr eng = out_conf->engine();
|
||||
|
|
|
@ -23,8 +23,8 @@ from .parser import (Parser, create_instance, is_supported_create_instance_type,
|
|||
get_operation_symbol, get_operation_namespace_symbol, get_parse_method_of_class, get_scope_name,
|
||||
eval_script, expand_expr_statement, is_class_member, parse_cb, resolve_symbol,
|
||||
convert_to_ms_tensor, get_object_description, get_class_attr_namespace_symbol, get_ms_class_name,
|
||||
is_class_type, get_dataclass_attributes, get_dataclass_methods, check_obj_bool,
|
||||
python_isinstance, ms_isinstance)
|
||||
is_class_type, get_dataclass_attributes, get_dataclass_methods, check_obj_bool, is_cell_list,
|
||||
python_isinstance, ms_isinstance, convert_cell_list_to_sequence, get_obj_from_sequence)
|
||||
|
||||
__all__ = ['Parser', 'create_instance', 'is_supported_create_instance_type', 'generate_scope',
|
||||
'get_bprop_method_of_class', 'get_class_instance_type', 'get_class_member_namespace_symbol',
|
||||
|
@ -33,5 +33,5 @@ __all__ = ['Parser', 'create_instance', 'is_supported_create_instance_type', 'ge
|
|||
'get_operation_symbol', 'get_operation_namespace_symbol', 'get_parse_method_of_class', 'get_scope_name',
|
||||
'eval_script', 'expand_expr_statement', 'is_class_member', 'parse_cb', 'resolve_symbol',
|
||||
'convert_to_ms_tensor', 'get_object_description', 'get_class_attr_namespace_symbol', 'get_ms_class_name',
|
||||
'is_class_type', 'get_dataclass_attributes', 'get_dataclass_methods', 'check_obj_bool', 'python_isinstance',
|
||||
'ms_isinstance']
|
||||
'is_class_type', 'get_dataclass_attributes', 'get_dataclass_methods', 'check_obj_bool', 'is_cell_list',
|
||||
'python_isinstance', 'ms_isinstance', 'convert_cell_list_to_sequence', 'get_obj_from_sequence']
|
||||
|
|
|
@ -426,6 +426,21 @@ def ms_isinstance(x, cmp_type):
|
|||
return isinstance(x, pytype_to_mstype.get(cmp_type))
|
||||
|
||||
|
||||
def is_cell_list(obj):
|
||||
"""Check if obj is nn.CellList"""
|
||||
return isinstance(obj, nn.CellList)
|
||||
|
||||
|
||||
def convert_cell_list_to_sequence(obj):
|
||||
"""Convert nn.CellList to sequence."""
|
||||
if not isinstance(obj, nn.CellList):
|
||||
raise TypeError(f"Obj should be nn.CellList, but got {obj}")
|
||||
if not hasattr(obj, "_cells"):
|
||||
raise AttributeError(f"nn.CellList is missing _cells property.")
|
||||
cells = getattr(obj, "_cells")
|
||||
return list(cells.values())
|
||||
|
||||
|
||||
def get_obj_from_sequence(obj, index):
|
||||
"""Implement `tuple_getitem`."""
|
||||
if not isinstance(obj, (tuple, list)):
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
""" test a list of cell, and getattr by its item """
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import context, nn, dtype, Tensor, ms_function
|
||||
from mindspore import context, nn, dtype, Tensor, ms_function, ms_class
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
|
@ -24,6 +24,12 @@ class Actor(nn.Cell):
|
|||
return x + y
|
||||
|
||||
|
||||
@ms_class
|
||||
class Actor2:
|
||||
def act(self, x, y):
|
||||
return x + y
|
||||
|
||||
|
||||
class Trainer(nn.Cell):
|
||||
def __init__(self, net_list):
|
||||
super(Trainer, self).__init__()
|
||||
|
@ -33,6 +39,20 @@ class Trainer(nn.Cell):
|
|||
return self.net_list[0].act(x, y)
|
||||
|
||||
|
||||
def verify_list_item_getattr(trainer, expect_res):
|
||||
x = Tensor([3], dtype=dtype.float32)
|
||||
y = Tensor([6], dtype=dtype.float32)
|
||||
res = trainer(x, y)
|
||||
print(f'res: {res}')
|
||||
assert np.array_equal(res.asnumpy(), expect_res.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_list_item_getattr():
|
||||
"""
|
||||
Feature: getattr by the item from list of cell.
|
||||
|
@ -42,15 +62,16 @@ def test_list_item_getattr():
|
|||
context.set_context(mode=context.GRAPH_MODE)
|
||||
actor_list = [Actor()]
|
||||
trainer = Trainer(actor_list)
|
||||
x = Tensor([3], dtype=dtype.float32)
|
||||
y = Tensor([6], dtype=dtype.float32)
|
||||
res = trainer(x, y)
|
||||
print(f'res: {res}')
|
||||
expect_res = Tensor([9], dtype=dtype.float32)
|
||||
assert np.array_equal(res.asnumpy(), expect_res.asnumpy())
|
||||
verify_list_item_getattr(trainer, expect_res)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support in graph mode yet')
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_cell_list_getattr():
|
||||
"""
|
||||
Feature: getattr by the item from nn.CellList.
|
||||
|
@ -62,12 +83,27 @@ def test_cell_list_getattr():
|
|||
for _ in range(3):
|
||||
actor_list.append(Actor())
|
||||
trainer = Trainer(actor_list)
|
||||
x = Tensor([3], dtype=dtype.float32)
|
||||
y = Tensor([6], dtype=dtype.float32)
|
||||
res = trainer(x, y)
|
||||
print(f'res: {res}')
|
||||
expect_res = Tensor([9], dtype=dtype.float32)
|
||||
assert np.array_equal(res.asnumpy(), expect_res.asnumpy())
|
||||
verify_list_item_getattr(trainer, expect_res)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_msclass_list_getattr():
|
||||
"""
|
||||
Feature: getattr by the item from list of ms_class.
|
||||
Description: Support RL use method in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
actor_list = [Actor2()]
|
||||
trainer = Trainer(actor_list)
|
||||
expect_res = Tensor([9], dtype=dtype.float32)
|
||||
verify_list_item_getattr(trainer, expect_res)
|
||||
|
||||
|
||||
class Trainer2(nn.Cell):
|
||||
|
@ -86,6 +122,12 @@ class Trainer2(nn.Cell):
|
|||
return sum_value
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_list_item_getattr2():
|
||||
"""
|
||||
Feature: getattr by the item from list of cell with a Tensor variable.
|
||||
|
@ -95,15 +137,16 @@ def test_list_item_getattr2():
|
|||
context.set_context(mode=context.GRAPH_MODE)
|
||||
actor_list = [Actor(), Actor(), Actor()]
|
||||
trainer = Trainer2(actor_list)
|
||||
x = Tensor([3], dtype=dtype.float32)
|
||||
y = Tensor([6], dtype=dtype.float32)
|
||||
res = trainer(x, y)
|
||||
print(f'res: {res}')
|
||||
expect_res = Tensor([27], dtype=dtype.float32)
|
||||
assert np.array_equal(res.asnumpy(), expect_res.asnumpy())
|
||||
verify_list_item_getattr(trainer, expect_res)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support in graph mode yet')
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_cell_list_getattr2():
|
||||
"""
|
||||
Feature: getattr by the item from nn.CellList.
|
||||
|
@ -115,12 +158,27 @@ def test_cell_list_getattr2():
|
|||
for _ in range(3):
|
||||
actor_list.append(Actor())
|
||||
trainer = Trainer2(actor_list)
|
||||
x = Tensor([3], dtype=dtype.float32)
|
||||
y = Tensor([6], dtype=dtype.float32)
|
||||
res = trainer(x, y)
|
||||
print(f'res: {res}')
|
||||
expect_res = Tensor([27], dtype=dtype.float32)
|
||||
assert np.array_equal(res.asnumpy(), expect_res.asnumpy())
|
||||
verify_list_item_getattr(trainer, expect_res)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_msclass_list_getattr2():
|
||||
"""
|
||||
Feature: getattr by the item from list of ms_class with a Tensor variable.
|
||||
Description: Support RL use method in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
actor_list = [Actor2(), Actor2(), Actor2()]
|
||||
trainer = Trainer2(actor_list)
|
||||
expect_res = Tensor([27], dtype=dtype.float32)
|
||||
verify_list_item_getattr(trainer, expect_res)
|
||||
|
||||
|
||||
class MSRL(nn.Cell):
|
||||
|
@ -154,7 +212,20 @@ class Trainer3(nn.Cell):
|
|||
return output
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support in graph mode yet')
|
||||
def verify_list_item_getattr2(trainer, expect_res):
|
||||
x = Tensor([2], dtype=dtype.int32)
|
||||
y = Tensor([3], dtype=dtype.int32)
|
||||
res = trainer.test(x, y)
|
||||
print(f'res: {res}')
|
||||
assert np.array_equal(res.asnumpy(), expect_res.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_list_item_getattr3():
|
||||
"""
|
||||
Feature: getattr by the item from list of cell.
|
||||
|
@ -168,15 +239,16 @@ def test_list_item_getattr3():
|
|||
agent_list.append(Agent(actor))
|
||||
msrl = MSRL(agent_list)
|
||||
trainer = Trainer3(msrl)
|
||||
x = Tensor([2], dtype=dtype.int32)
|
||||
y = Tensor([3], dtype=dtype.int32)
|
||||
res = trainer.test(x, y)
|
||||
print(f'res: {res}')
|
||||
expect_res = Tensor([15], dtype=dtype.int32)
|
||||
assert np.array_equal(res.asnumpy(), expect_res.asnumpy())
|
||||
verify_list_item_getattr2(trainer, expect_res)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support in graph mode yet')
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_cell_list_getattr3():
|
||||
"""
|
||||
Feature: getattr by the item from list of cell.
|
||||
|
@ -190,9 +262,28 @@ def test_cell_list_getattr3():
|
|||
agent_list.append(Agent(actor))
|
||||
msrl = MSRL(agent_list)
|
||||
trainer = Trainer3(msrl)
|
||||
x = Tensor([2], dtype=dtype.float32)
|
||||
y = Tensor([3], dtype=dtype.float32)
|
||||
res = trainer.test(x, y)
|
||||
print(f'res: {res}')
|
||||
expect_res = Tensor([15], dtype=dtype.float32)
|
||||
assert np.array_equal(res.asnumpy(), expect_res.asnumpy())
|
||||
expect_res = Tensor([15], dtype=dtype.int32)
|
||||
verify_list_item_getattr2(trainer, expect_res)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_msclass_list_getattr3():
|
||||
"""
|
||||
Feature: getattr by the item from list of ms_class.
|
||||
Description: Support RL use method in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
agent_list = []
|
||||
for _ in range(3):
|
||||
actor = Actor2()
|
||||
agent_list.append(Agent(actor))
|
||||
msrl = MSRL(agent_list)
|
||||
trainer = Trainer3(msrl)
|
||||
expect_res = Tensor([15], dtype=dtype.int32)
|
||||
verify_list_item_getattr2(trainer, expect_res)
|
||||
|
|
Loading…
Reference in New Issue