Support getattr by the item of cell list: Handle the DoSignaturePrimitive('getitem') between getattr and resolve.

This commit is contained in:
Zhang Qinghua 2021-12-10 14:25:18 +08:00
parent c2a47674ad
commit 3e7b73e6c7
12 changed files with 289 additions and 124 deletions

View File

@ -174,15 +174,15 @@ def resolve_symbol(namespace, symbol):
# If need trope the obj # If need trope the obj
if resolve_ in convert_object_map: if resolve_ in convert_object_map:
resolve_ = convert_object_map.get(resolve_) resolve_ = convert_object_map.get(resolve_)
logger.debug("Convert resolve = %r", resolve_) logger.debug("Convert resolve: %r", resolve_)
if resolve_ == NO_IMPLEMENT: if resolve_ == NO_IMPLEMENT:
raise NotImplementedError(f"Not support for '{symbol}'.") raise NotImplementedError(f"Not support for '{symbol}'.")
except Exception as e: except Exception as e:
if isinstance(e, NotImplementedError): if isinstance(e, NotImplementedError):
raise e raise e
resolve_ = None resolve_ = None
logger.debug("Resolve exception occurred, value = %r", e) logger.debug("Resolve exception occurred, value: %r", e)
logger.debug("Resolve type is invalid, namespace = %s, symbol = %s", logger.debug("Resolve type is invalid, namespace: %s, symbol: %s",
namespace.__str__(), symbol) namespace.__str__(), symbol)
if isinstance(resolve_, _MindsporeFunctionExecutor): if isinstance(resolve_, _MindsporeFunctionExecutor):
@ -219,7 +219,7 @@ def get_object_key(obj):
if hasattr(obj, "cell_init_args"): if hasattr(obj, "cell_init_args"):
obj_key = "%s_ID" % (tag + obj.cell_init_args) obj_key = "%s_ID" % (tag + obj.cell_init_args)
obj_id = "%s_ID%d" % (tag, id(obj)) obj_id = "%s_ID%d" % (tag, id(obj))
logger.debug("obj_key %s obj_id = %s", obj_key, obj_id) logger.debug("obj_key: %s, obj_id: %s", obj_key, obj_id)
# method has same id of different instance # method has same id of different instance
if isinstance(obj, types.MethodType): if isinstance(obj, types.MethodType):
@ -339,9 +339,17 @@ def create_instance(cls_type, params=None):
return obj return obj
def get_obj_from_sequence(obj, index):
"""Implement `tuple_getitem`."""
if not isinstance(obj, (tuple, list)):
raise TypeError(f"Should not get item from a object that not sequence type, obj: {obj}")
# Not check index out of range by self.
return obj[index]
def get_module_namespace(obj): def get_module_namespace(obj):
"""Get the module's namespace.""" """Get the module's namespace."""
logger.debug("get module namespace, module = %r", obj) logger.debug("get module namespace, module: %r", obj)
mod_namespace = None mod_namespace = None
if isinstance(obj, types.ModuleType): if isinstance(obj, types.ModuleType):
mod_namespace = CellNamespace(obj.__name__) mod_namespace = CellNamespace(obj.__name__)
@ -352,9 +360,9 @@ def get_module_namespace(obj):
def get_class_member_namespace_symbol(obj): def get_class_member_namespace_symbol(obj):
"""Get obj class member type.""" """Get obj class member type."""
logger.debug("get class instance namespace, object = %r", obj) logger.debug("get class instance namespace, object: %r", obj)
class_namespace = ClassMemberNamespace(obj) class_namespace = ClassMemberNamespace(obj)
logger.debug("class namesapce = %r", class_namespace) logger.debug("class namespace: %r", class_namespace)
return class_namespace return class_namespace
@ -425,14 +433,14 @@ def get_ast_namespace_symbol(obj):
"""Get obj type and namespace and symbol.""" """Get obj type and namespace and symbol."""
# step 1:get symbol from object map # step 1:get symbol from object map
ops_info = parse_object_map.get(type(obj), SYMBOL_UNDEFINE) ops_info = parse_object_map.get(type(obj), SYMBOL_UNDEFINE)
logger.debug("ops info = %r", ops_info) logger.debug("ops info: %r", ops_info)
return ops_info return ops_info
def get_operation_namespace_symbol(var: str): def get_operation_namespace_symbol(var: str):
"""Get operation namespace and symbol.""" """Get operation namespace and symbol."""
ops_info = (trope_ns, var) ops_info = (trope_ns, var)
logger.debug("get operation ops info = %r", ops_info) logger.debug("get operation ops info: %r", ops_info)
return ops_info return ops_info
@ -566,7 +574,7 @@ class Parser:
def parse(self): def parse(self):
"""Parse the function or method.""" """Parse the function or method."""
logger.debug("fn = %r", self.fn) logger.debug("fn: %r", self.fn)
if isinstance(self.fn, (types.FunctionType, types.MethodType)): if isinstance(self.fn, (types.FunctionType, types.MethodType)):
try: try:
lines, self.line_offset = inspect.getsourcelines(self.fn) lines, self.line_offset = inspect.getsourcelines(self.fn)
@ -582,7 +590,7 @@ class Parser:
src = dedent(original_src) src = dedent(original_src)
self.col_offset = \ self.col_offset = \
len(original_src.split('\n')[0]) - len(src.split('\n')[0]) len(original_src.split('\n')[0]) - len(src.split('\n')[0])
logger.debug("Get source = %s", src) logger.debug("Get source: %s", src)
try: try:
ast_tokens = asttokens.ASTTokens(src, parse=True) ast_tokens = asttokens.ASTTokens(src, parse=True)
except IndentationError as idt_err: except IndentationError as idt_err:

View File

@ -176,7 +176,7 @@ void DumpKernelInfo(const CNodePtr &node, const std::shared_ptr<SubGraphIRInfo>
int32_t DumpParams(const FuncGraphPtr &graph, std::ostringstream &buffer, OrderedMap<AnfNodePtr, int32_t> *para_map) { int32_t DumpParams(const FuncGraphPtr &graph, std::ostringstream &buffer, OrderedMap<AnfNodePtr, int32_t> *para_map) {
if (graph == nullptr) { if (graph == nullptr) {
MS_LOG(INFO) << "Param graph is nullptr."; MS_LOG(INFO) << "Parameter \'graph\' should not be null.";
return 0; return 0;
} }
std::vector<AnfNodePtr> parameters = graph->parameters(); std::vector<AnfNodePtr> parameters = graph->parameters();
@ -217,13 +217,12 @@ int32_t DumpParams(const FuncGraphPtr &graph, std::ostringstream &buffer, Ordere
void DumpOperator(const AnfNodePtr &node, const std::shared_ptr<SubGraphIRInfo> &gsub) { void DumpOperator(const AnfNodePtr &node, const std::shared_ptr<SubGraphIRInfo> &gsub) {
if (gsub == nullptr) { if (gsub == nullptr) {
MS_LOG(INFO) << "Param gsub is nullptr"; MS_LOG(INFO) << "Parameter \'gsub\' should not be null.";
return; return;
} }
auto cnode = dyn_cast<CNode>(node); auto cnode = dyn_cast<CNode>(node);
MS_EXCEPTION_IF_NULL(cnode);
if (cnode == nullptr) { if (cnode == nullptr) {
MS_LOG(INFO) << "Param node should be a CNode"; MS_LOG(EXCEPTION) << "Parameter \'node\' should be a CNode";
return; return;
} }
AnfNodePtr op = cnode->input(0); AnfNodePtr op = cnode->input(0);
@ -272,7 +271,11 @@ void DumpOperands(const AnfNodePtr &node, OrderedMap<AnfNodePtr, int32_t> *para_
gsub->buffer << ", "; gsub->buffer << ", ";
} }
if (in->isa<Parameter>()) { if (in->isa<Parameter>()) {
if (in->func_graph() != node->func_graph()) { MS_EXCEPTION_IF_NULL(node->func_graph());
if (in->func_graph() == nullptr) {
MS_LOG(ERROR) << "Parameter should belong to a func graph. Check func graph: " << node->func_graph();
}
if (in->func_graph() != nullptr && in->func_graph() != node->func_graph()) {
gsub->buffer << "$(@" << in->func_graph()->ToString() << ":"; gsub->buffer << "$(@" << in->func_graph()->ToString() << ":";
} else { } else {
gsub->buffer << "%"; gsub->buffer << "%";
@ -283,7 +286,7 @@ void DumpOperands(const AnfNodePtr &node, OrderedMap<AnfNodePtr, int32_t> *para_
} else { } else {
gsub->buffer << "para" << iter->second << "_" << in->ToString(); gsub->buffer << "para" << iter->second << "_" << in->ToString();
} }
if (in->func_graph() != node->func_graph()) { if (in->func_graph() != nullptr && in->func_graph() != node->func_graph()) {
gsub->buffer << ")"; gsub->buffer << ")";
} }
} else if (in->isa<CNode>()) { } else if (in->isa<CNode>()) {
@ -325,7 +328,7 @@ void DumpParallelInfo(const CNodePtr &node, const std::shared_ptr<SubGraphIRInfo
} }
ValuePtr in_tmp = MakeValue(in_strategy->GetInputDim()); ValuePtr in_tmp = MakeValue(in_strategy->GetInputDim());
gsub->buffer << " { in_strategy: "; gsub->buffer << " {in_strategy: ";
gsub->buffer << in_tmp->ToString(); gsub->buffer << in_tmp->ToString();
auto out_strategy = operator_info->out_strategy(); auto out_strategy = operator_info->out_strategy();
@ -335,7 +338,7 @@ void DumpParallelInfo(const CNodePtr &node, const std::shared_ptr<SubGraphIRInfo
gsub->buffer << out_tmp->ToString(); gsub->buffer << out_tmp->ToString();
} }
gsub->buffer << " }"; gsub->buffer << "}";
} }
void DumpAttrs(const mindspore::HashMap<std::string, ValuePtr> &attrs, const std::shared_ptr<SubGraphIRInfo> &gsub, void DumpAttrs(const mindspore::HashMap<std::string, ValuePtr> &attrs, const std::shared_ptr<SubGraphIRInfo> &gsub,

View File

@ -267,9 +267,9 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
} }
ResolveIRPassLib::ResolveIRPassLib() { ResolveIRPassLib::ResolveIRPassLib() {
// In resolver_getattr_resolve_, some patterns have priority over others. // In resolver_, some patterns have priority over others.
resolver_getattr_resolve_ = MakeSubstitution(std::make_shared<ResolverGetAttrResolve>(), "getattr_resolve", resolver_ = MakeSubstitution(std::make_shared<Resolver>(), "getattr_resolve",
{prim::kPrimGetAttr, prim::kPrimResolve}, opt::CHECK_RENORM, true); {prim::kPrimGetAttr, prim::kPrimResolve}, opt::CHECK_RENORM, true);
} }
InferenceOptPrepareLib::InferenceOptPrepareLib() { InferenceOptPrepareLib::InferenceOptPrepareLib() {

View File

@ -169,7 +169,7 @@ class ResolveIRPassLib {
public: public:
ResolveIRPassLib(); ResolveIRPassLib();
~ResolveIRPassLib() = default; ~ResolveIRPassLib() = default;
SubstitutionPtr resolver_getattr_resolve_; SubstitutionPtr resolver_;
}; };
class InferenceOptPrepareLib { class InferenceOptPrepareLib {

View File

@ -434,7 +434,8 @@ class PynativeEliminater : public OptimizerCaller {
if (value_node == nullptr) { if (value_node == nullptr) {
return false; return false;
} }
return GetValueNode<parse::NameSpacePtr>(value_node)->module() == str_value; auto module_name = GetValueNode<parse::NameSpacePtr>(value_node)->module();
return module_name.find(str_value) != std::string::npos;
} }
bool CheckSymbolVNode(const AnfNodePtr &node, const std::string &str_value) { bool CheckSymbolVNode(const AnfNodePtr &node, const std::string &str_value) {

View File

@ -23,53 +23,92 @@
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
namespace irpass { namespace irpass {
// {prim::kPrimGetAttr, {prim::kPrimTupleGetItem, {prim::kPrimResolve, namespace, symbol}, index}, attr}
// {prim::kPrimGetAttr, {prim::kPrimResolve, namespace, symbol}, attr} // {prim::kPrimGetAttr, {prim::kPrimResolve, namespace, symbol}, attr}
// {prim::kPrimGetAttr, namespace, attr} // {prim::kPrimGetAttr, namespace, attr}
// {prim::kPrimGetAttr, bool, attr} // {prim::kPrimGetAttr, bool, attr}
// {prim::kPrimResolve, namespace, symbol} // {prim::kPrimResolve, namespace, symbol}
AnfNodePtr ResolverGetAttrResolve::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { AnfNodePtr Resolver::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
constexpr std::string_view PARSE_SUPER_NAME = "namespace"; PatternNode<AnfNodePtr> getattr_operand, ns_node, sym_node, attr_node, bool_node;
constexpr size_t namespace_index = 1; auto GetAttrResolveLambda = [&node, &getattr_operand, &attr_node, &optimizer]() -> AnfNodePtr {
constexpr size_t symbol_index = 2; auto getattr_operand_node = getattr_operand.GetNode(node);
PatternNode<AnfNodePtr> resolve_node, ns_node, sym_node, attr_node, bool_node;
auto GetAttrResolveLambda = [&node, &resolve_node, &attr_node, &optimizer, &PARSE_SUPER_NAME]() -> AnfNodePtr {
auto inner = resolve_node.GetNode(node);
auto attr = attr_node.GetNode(node); auto attr = attr_node.GetNode(node);
if (IsPrimitiveCNode(inner, prim::kPrimResolve)) { constexpr auto recursive_level = 3;
auto resolve_cnode = inner->cast<CNodePtr>(); MS_LOG(DEBUG) << "getattr_operand_node: " << getattr_operand_node->DebugString(recursive_level);
auto namespace_node = resolve_cnode->input(namespace_index);
auto symbol_node = resolve_cnode->input(symbol_index); // {prim::GetAttr, {{prim::Resolve, ..., 'getitem'}, {prim::Resolve, ...}, ...}}
if (!IsValueNode<parse::NameSpace>(namespace_node) || !IsValueNode<parse::Symbol>(symbol_node)) { auto getitem_cnode = getattr_operand_node->cast<CNodePtr>();
return nullptr; if (getitem_cnode != nullptr) {
constexpr size_t prim_index = 0;
auto primitive_node = getitem_cnode->input(prim_index);
auto resolved_getitem_node = primitive_node;
if (IsPrimitiveCNode(primitive_node, prim::kPrimResolve)) {
auto resolve_getitem_cnode = primitive_node->cast<CNodePtr>();
auto resolve_getitem_symbol = GetValueNode<parse::SymbolPtr>(resolve_getitem_cnode->input(2));
constexpr auto getitem_symbol = "getitem";
if (resolve_getitem_symbol->symbol() == getitem_symbol) {
auto resolve_getitem_name_space = GetValueNode<parse::NameSpacePtr>(resolve_getitem_cnode->input(1));
resolved_getitem_node =
ResolveSymbol(optimizer->manager(), resolve_getitem_name_space, resolve_getitem_symbol, node);
}
} }
// deal with the case of getting attr from a class member bool is_getattr_getitem = false;
// and avoid the case of getting attr from self (the result of ParseSuper) auto do_signature = dyn_cast<prim::DoSignaturePrimitive>(GetValueNode(resolved_getitem_node));
auto ns = GetValueNode<parse::NameSpacePtr>(namespace_node); if (do_signature != nullptr) {
auto sym = GetValueNode<parse::SymbolPtr>(symbol_node); auto &func_value = do_signature->function();
if (ns->module() == parse::RESOLVE_NAMESPACE_NAME_CLASS_MEMBER && sym->symbol() != PARSE_SUPER_NAME) { // The function 'func_value' must be the MultitypeFuncGraph of 'getitem'.
return parse::ResolveCellwithAttr(optimizer->manager(), ns, sym, inner, attr); auto multitype_fg_value = dyn_cast<prim::MultitypeFuncGraph>(func_value);
constexpr auto getitem_symbol = "getitem";
if (multitype_fg_value != nullptr && multitype_fg_value->name() == getitem_symbol) {
is_getattr_getitem = true;
}
}
if (IsPrimitiveCNode(getattr_operand_node, prim::kPrimTupleGetItem)) {
is_getattr_getitem = true;
}
if (is_getattr_getitem) {
constexpr size_t resolve_index = 1;
auto resolve_node = getitem_cnode->input(resolve_index);
if (IsPrimitiveCNode(resolve_node, prim::kPrimResolve)) {
constexpr size_t position_index = 2;
auto index_node = getitem_cnode->input(position_index);
auto [name_space, symbol] = parse::GetNamespaceAndSymbol(resolve_node);
auto py_item = parse::GetItemObjectFromSequence(name_space, symbol, resolve_node, index_node);
return parse::ResolveCellWithAttr(optimizer->manager(), py_item, resolve_node, attr);
}
}
}
// {prim::GetAttr, {prim::Resolve, ...}}
if (IsPrimitiveCNode(getattr_operand_node, prim::kPrimResolve)) {
auto [name_space, symbol] = parse::GetNamespaceAndSymbol(getattr_operand_node);
auto module_name = name_space->module();
constexpr std::string_view parse_super_name = "namespace";
if (module_name.find(parse::RESOLVE_NAMESPACE_NAME_CLASS_MEMBER) != std::string::npos &&
symbol->symbol() != parse_super_name) {
auto obj = parse::GetSymbolObject(name_space, symbol, node);
return parse::ResolveCellWithAttr(optimizer->manager(), obj, getattr_operand_node, attr);
} }
} }
return nullptr; return nullptr;
}; };
auto GetAttrLambda = [&node, &ns_node, &attr_node, &optimizer]() -> AnfNodePtr { auto GetAttrLambda = [&node, &ns_node, &attr_node, &optimizer]() -> AnfNodePtr {
auto ns = GetValueNode<parse::NameSpacePtr>(ns_node.GetNode(node)); auto name_space = GetValueNode<parse::NameSpacePtr>(ns_node.GetNode(node));
auto str = GetValue<std::string>(GetValueNode(attr_node.GetNode(node))); auto str = GetValue<std::string>(GetValueNode(attr_node.GetNode(node)));
parse::SymbolPtr sym = std::make_shared<parse::Symbol>(str); parse::SymbolPtr symbol = std::make_shared<parse::Symbol>(str);
return parse::ResolveSymbol(optimizer->manager(), ns, sym, node); return parse::ResolveSymbol(optimizer->manager(), name_space, symbol, node);
}; };
auto ResolveLambda = [&node, &ns_node, &sym_node, &optimizer]() -> AnfNodePtr { auto ResolveLambda = [&node, &ns_node, &sym_node, &optimizer]() -> AnfNodePtr {
auto ns = GetValueNode<parse::NameSpacePtr>(ns_node.GetNode(node)); auto name_space = GetValueNode<parse::NameSpacePtr>(ns_node.GetNode(node));
auto sym = GetValueNode<parse::SymbolPtr>(sym_node.GetNode(node)); auto symbol = GetValueNode<parse::SymbolPtr>(sym_node.GetNode(node));
auto manager = optimizer->manager(); auto manager = optimizer->manager();
return parse::ResolveSymbol(manager, ns, sym, node); return parse::ResolveSymbol(manager, name_space, symbol, node);
}; };
// {prim::kPrimGetAttr, {prim::kPrimResolve, namespace, symbol}, attr} // {prim::kPrimGetAttr, {prim::kPrimResolve, namespace, symbol}, attr}
MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimGetAttr, resolve_node, attr_node), GetAttrResolveLambda, MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimGetAttr, getattr_operand, attr_node), GetAttrResolveLambda,
attr_node.CheckFunc(IsValueNode<StringImm>, node)); attr_node.CheckFunc(IsValueNode<StringImm>, node));
// {prim::kPrimGetAttr, namespace, attr} // {prim::kPrimGetAttr, namespace, attr}
MATCH_REPLACE_LAMBDA_IF( MATCH_REPLACE_LAMBDA_IF(

View File

@ -38,11 +38,12 @@ namespace irpass {
// pattern. After matching GetAttr pattern, there may be new nodes that can match GetAttr pattern and Resolve pattern. // pattern. After matching GetAttr pattern, there may be new nodes that can match GetAttr pattern and Resolve pattern.
// The same is true for matching Resolve pattern. // The same is true for matching Resolve pattern.
// //
// {prim::kPrimGetAttr, {prim::kPrimTupleGetItem, {prim::kPrimResolve, namespace, symbol}, index}, attr}
// {prim::kPrimGetAttr, {prim::kPrimResolve, namespace, symbol}, attr} // {prim::kPrimGetAttr, {prim::kPrimResolve, namespace, symbol}, attr}
// {prim::kPrimGetAttr, namespace, attr} // {prim::kPrimGetAttr, namespace, attr}
// {prim::kPrimGetAttr, bool, attr} // {prim::kPrimGetAttr, bool, attr}
// {prim::kPrimResolve, namespace, symbol} // {prim::kPrimResolve, namespace, symbol}
class ResolverGetAttrResolve : public OptimizerCaller { class Resolver : public OptimizerCaller {
public: public:
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override; AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override;
}; };

View File

@ -525,7 +525,7 @@ std::vector<DataConverterPtr> GetDataConverters() {
bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature, const TypePtr &dtype) { bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature, const TypePtr &dtype) {
// Check parameter valid // Check parameter valid
if (data == nullptr) { if (data == nullptr) {
MS_LOG(ERROR) << "Data is null pointer"; MS_LOG(ERROR) << "The value pointer should not be null.";
return false; return false;
} }
ValuePtr converted = nullptr; ValuePtr converted = nullptr;
@ -554,7 +554,7 @@ FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python
ValuePtr value = nullptr; ValuePtr value = nullptr;
bool is_cache = data_converter::GetObjectValue(obj_id, &value); bool is_cache = data_converter::GetObjectValue(obj_id, &value);
if (is_cache && value != nullptr && value->isa<FuncGraph>()) { if (is_cache && value != nullptr && value->isa<FuncGraph>()) {
MS_LOG(DEBUG) << "Get the cache data, obj = " << obj_id; MS_LOG(DEBUG) << "Get the cache data, obj: " << obj_id;
func_graph = value->cast<FuncGraphPtr>(); func_graph = value->cast<FuncGraphPtr>();
if (!func_graph->dropped()) { if (!func_graph->dropped()) {
return func_graph; return func_graph;
@ -570,7 +570,7 @@ FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python
data_converter::MakeProperNameToFuncGraph(func_graph, obj_id); data_converter::MakeProperNameToFuncGraph(func_graph, obj_id);
data_converter::CacheObjectValue(obj_id, func_graph); data_converter::CacheObjectValue(obj_id, func_graph);
if (!obj_key.empty()) { if (!obj_key.empty()) {
MS_LOG(DEBUG) << "Add graph:" << obj_key << ", func_graph:" << func_graph->ToString(); MS_LOG(DEBUG) << "Add graph: " << obj_key << ", func_graph: " << func_graph->ToString();
data_converter::SetObjGraphValue(obj_key, func_graph); data_converter::SetObjGraphValue(obj_key, func_graph);
} }
@ -584,11 +584,11 @@ static mindspore::HashMap<std::string, std::vector<FuncGraphPtr>> object_graphs_
void SetObjGraphValue(const std::string &obj_key, const FuncGraphPtr &data) { void SetObjGraphValue(const std::string &obj_key, const FuncGraphPtr &data) {
object_graphs_map_[obj_key].push_back(data); object_graphs_map_[obj_key].push_back(data);
MS_LOG(DEBUG) << "Set func graph size:" << object_graphs_map_.size(); MS_LOG(DEBUG) << "Set func graph size: " << object_graphs_map_.size();
} }
const mindspore::HashMap<std::string, std::vector<FuncGraphPtr>> &GetObjGraphs() { const mindspore::HashMap<std::string, std::vector<FuncGraphPtr>> &GetObjGraphs() {
MS_LOG(DEBUG) << "Obj size:" << object_graphs_map_.size(); MS_LOG(DEBUG) << "Obj graphs size: " << object_graphs_map_.size();
return object_graphs_map_; return object_graphs_map_;
} }
@ -606,7 +606,7 @@ std::vector<std::string> GetObjKey(const py::object &obj) {
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
py::tuple obj_tuple = python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_KEY, obj); py::tuple obj_tuple = python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_KEY, obj);
if (obj_tuple.size() != 2) { if (obj_tuple.size() != 2) {
MS_LOG(EXCEPTION) << "Get_obj_key must return 2 elements"; MS_LOG(EXCEPTION) << "The function of \'get_obj_key()\' must return 2 elements";
} }
return {py::cast<std::string>(obj_tuple[0]), py::cast<std::string>(obj_tuple[1])}; return {py::cast<std::string>(obj_tuple[0]), py::cast<std::string>(obj_tuple[1])};
} }
@ -619,10 +619,10 @@ ResolveTypeDef GetObjType(const py::object &obj) {
ResolveTypeDef(python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_TYPE, obj).cast<int32_t>()); ResolveTypeDef(python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_TYPE, obj).cast<int32_t>());
return obj_type; return obj_type;
} catch (const py::error_already_set &ex) { } catch (const py::error_already_set &ex) {
MS_LOG(ERROR) << "Meet a exception from Python when get the type of `" << py::str(obj) << "`.\n" << ex.what(); MS_LOG(ERROR) << "Meet a exception from Python when get the type of \'" << py::str(obj) << "\'.\n" << ex.what();
std::rethrow_exception(std::current_exception()); std::rethrow_exception(std::current_exception());
} catch (const py::type_error &ex) { } catch (const py::type_error &ex) {
MS_LOG(ERROR) << "Meet a exception when get the type of `" << py::str(obj) << "`.\n" << ex.what(); MS_LOG(ERROR) << "Meet a exception when get the type of \'" << py::str(obj) << "\'.\n" << ex.what();
std::rethrow_exception(std::current_exception()); std::rethrow_exception(std::current_exception());
} }
} }
@ -638,8 +638,8 @@ ClassInstanceTypeDef GetClassInstanceType(const py::object &obj) {
// Check the object is Cell Instance. // Check the object is Cell Instance.
bool IsCellInstance(const py::object &obj) { bool IsCellInstance(const py::object &obj) {
auto class_type = GetClassInstanceType(obj); auto class_type = GetClassInstanceType(obj);
bool isCell = (class_type == CLASS_INSTANCE_TYPE_CELL); bool is_cell = (class_type == CLASS_INSTANCE_TYPE_CELL);
return isCell; return is_cell;
} }
// Create the python class instance. // Create the python class instance.

View File

@ -27,12 +27,12 @@
namespace py = pybind11; namespace py = pybind11;
namespace mindspore { namespace mindspore {
namespace parse { namespace parse {
// define the node type // Define the node type.
enum AstMainType : int64_t { enum AstMainType : int64_t {
AST_MAIN_TYPE_STMT = 0, // ast.Stmt AST_MAIN_TYPE_STMT = 0, // ast.Stmt
AST_MAIN_TYPE_EXPR = 1, // ast.Expr AST_MAIN_TYPE_EXPR = 1, // ast.Expr
AST_MAIN_TYPE_SLICE = 2, // ast.Slice AST_MAIN_TYPE_SLICE = 2, // ast.Slice
AST_MAIN_TYPE_UNKNOWN = 0xFF // Error AST_MAIN_TYPE_UNKNOWN = 0xFF // Unknown type
}; };
enum AstSubType : int64_t { enum AstSubType : int64_t {
@ -43,18 +43,18 @@ enum AstSubType : int64_t {
AST_SUB_TYPE_SUBSCRIPT = 7, // ast.Subscript AST_SUB_TYPE_SUBSCRIPT = 7, // ast.Subscript
AST_SUB_TYPE_STARRED = 8, // ast.Starred AST_SUB_TYPE_STARRED = 8, // ast.Starred
AST_SUB_TYPE_ATTRIBUTE = 9, // ast.Attribute AST_SUB_TYPE_ATTRIBUTE = 9, // ast.Attribute
AST_SUB_TYPE_UNKNOWN = 0xFF // Error AST_SUB_TYPE_UNKNOWN = 0xFF // Unknown type
}; };
// define the parse target type // Define the parse target type.
enum ParseTargetTypeDef { enum ParseTargetTypeDef {
PARSE_TARGET_FUNCTION = 0, // function PARSE_TARGET_FUNCTION = 0, // Function
PARSE_TARGET_METHOD = 1, // method PARSE_TARGET_METHOD = 1, // Method
PARSE_TARGET_OBJECT_INSTANCE = 2, // object instance PARSE_TARGET_OBJECT_INSTANCE = 2, // Object instance
PARSE_TARGET_UNKNOW = 0xFF // ERROR TYPE PARSE_TARGET_UNKNOW = 0xFF // Unknown type
}; };
// define python module name // Define python module name.
const char PYTHON_MOD_PARSE_MODULE[] = "mindspore._extends.parse"; const char PYTHON_MOD_PARSE_MODULE[] = "mindspore._extends.parse";
const char PYTHON_MOD_PARSE_OBJECT_FUNCTION[] = "parse_cb"; const char PYTHON_MOD_PARSE_OBJECT_FUNCTION[] = "parse_cb";
const char PYTHON_MOD_RESOLVE_FUNCTION[] = "resolve_symbol"; const char PYTHON_MOD_RESOLVE_FUNCTION[] = "resolve_symbol";
@ -72,6 +72,7 @@ const char PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL[] = "get_class_member_namespac
const char PYTHON_MOD_GET_PARSE_METHOD[] = "get_parse_method_of_class"; const char PYTHON_MOD_GET_PARSE_METHOD[] = "get_parse_method_of_class";
const char PYTHON_MOD_GET_BPROP_METHOD[] = "get_bprop_method_of_class"; const char PYTHON_MOD_GET_BPROP_METHOD[] = "get_bprop_method_of_class";
const char PYTHON_MOD_GET_OBJECT_DESCRIPTION[] = "get_object_description"; const char PYTHON_MOD_GET_OBJECT_DESCRIPTION[] = "get_object_description";
const char PYTHON_MOD_GET_ITEM_FROM_SEQUENCE[] = "get_obj_from_sequence";
const char PYTHON_MOD_CONVERT_TO_MS_TENSOR[] = "convert_to_ms_tensor"; const char PYTHON_MOD_CONVERT_TO_MS_TENSOR[] = "convert_to_ms_tensor";
const char PYTHON_MOD_EVAL_PY_SCRIPT[] = "eval_script"; const char PYTHON_MOD_EVAL_PY_SCRIPT[] = "eval_script";
@ -92,7 +93,7 @@ const char PYTHON_PARSE_ANALYZE_SUPER[] = "analyze_super";
const char PYTHON_PARSE_CLASS_SLICE[] = "create_slice_obj"; const char PYTHON_PARSE_CLASS_SLICE[] = "create_slice_obj";
const char PYTHON_PARSE_CLASS_ELLIPSIS[] = "create_ellipsis_obj"; const char PYTHON_PARSE_CLASS_ELLIPSIS[] = "create_ellipsis_obj";
// define the common name // Define the common name.
const char NAMED_PRIMITIVE_LEN[] = "len"; const char NAMED_PRIMITIVE_LEN[] = "len";
const char NAMED_PRIMITIVE_BODY[] = "body"; const char NAMED_PRIMITIVE_BODY[] = "body";
const char NAMED_PRIMITIVE_ASSIGN[] = "Assign"; const char NAMED_PRIMITIVE_ASSIGN[] = "Assign";
@ -127,8 +128,8 @@ const char NAMED_PRIMITIVE_MAKESLICE[] = "make_slice";
const char NAMED_PRIMITIVE_MAKEDICT[] = "make_dict"; const char NAMED_PRIMITIVE_MAKEDICT[] = "make_dict";
const char NAMED_METAGRAPH_UNPACKCALL[] = "unpack_call"; const char NAMED_METAGRAPH_UNPACKCALL[] = "unpack_call";
// define NAMED_PRIMITIVE_GETATTR "getattr" // Define NAMED_PRIMITIVE_GETATTR "getattr".
// define python inline attr // Define python inline attr.
const char PYTHON_GET_METHOD_LEN[] = "__len__"; const char PYTHON_GET_METHOD_LEN[] = "__len__";
const char PYTHON_GET_METHOD_SELF_CLASS[] = "__self__"; const char PYTHON_GET_METHOD_SELF_CLASS[] = "__self__";
const char PYTHON_GET_OBJ_DESC[] = "__str__"; const char PYTHON_GET_OBJ_DESC[] = "__str__";
@ -136,46 +137,46 @@ const char PYTHON_GET_OBJ_DESC[] = "__str__";
const char PYTHON_EXTERN_PARSE_METHOD[] = "__parse_method__"; const char PYTHON_EXTERN_PARSE_METHOD[] = "__parse_method__";
const char PYTHON_EXTERN_MINDSPORE_FLAG[] = "_mindspore_flags"; const char PYTHON_EXTERN_MINDSPORE_FLAG[] = "_mindspore_flags";
// define the parse constant // Define the parse constant.
const int64_t MAX_COMPARISON_OPS_SUPPORTED = 1; const int64_t MAX_COMPARISON_OPS_SUPPORTED = 1;
const char CUSTOM_BPROP_NAME[] = "bprop"; const char CUSTOM_BPROP_NAME[] = "bprop";
const char STAGE_NAME[] = "_pipeline_stage"; const char STAGE_NAME[] = "_pipeline_stage";
// define the Namespace name // Define the Namespace name.
const char RESOLVE_NAMESPACE_NAME_AST[] = "Ast"; // for ast type namespace const char RESOLVE_NAMESPACE_NAME_AST[] = "Ast"; // For ast type namespace.
const char RESOLVE_NAMESPACE_NAME_CLASS_MEMBER[] = "ClassMember"; // for class member namespace const char RESOLVE_NAMESPACE_NAME_CLASS_MEMBER[] = "ClassMember"; // For class member namespace.
const char RESOLVE_NAMESPACE_NAME_SYMBOL_STR[] = "SymbolStr"; // for symbol str namespace const char RESOLVE_NAMESPACE_NAME_SYMBOL_STR[] = "SymbolStr"; // For symbol str namespace.
const char RESOLVE_NAMESPACE_NAME_COMMON_OPS[] = "CommonOPS"; // for common ops, eg: hasnext, next const char RESOLVE_NAMESPACE_NAME_COMMON_OPS[] = "CommonOPS"; // For common ops, eg: hasnext, next.
const char RESOLVE_NAMESPACE_NAME_MODULE[] = "Module"; // fro Module namespace const char RESOLVE_NAMESPACE_NAME_MODULE[] = "Module"; // For Module namespace.
// define Resolve type // Define Resolve type.
enum ResolveTypeDef : int64_t { enum ResolveTypeDef : int64_t {
RESOLVE_TYPE_NONE = 0, // resolve None RESOLVE_TYPE_NONE = 0, // Resolve None
RESOLVE_TYPE_FUNCTION = 1, // resolve function RESOLVE_TYPE_FUNCTION = 1, // Resolve function
RESOLVE_TYPE_METHOD = 2, // resolve class method RESOLVE_TYPE_METHOD = 2, // Resolve class method
RESOLVE_TYPE_CLASS_TYPE = 3, // resolve class type RESOLVE_TYPE_CLASS_TYPE = 3, // Resolve class type
RESOLVE_TYPE_CLASS_INSTANCE = 4, // resolve the class instance of common class RESOLVE_TYPE_CLASS_INSTANCE = 4, // Resolve the class instance of common class
RESOLVE_TYPE_INVALID = 0xFF // resolve invalid RESOLVE_TYPE_INVALID = 0xFF // Resolve invalid
}; };
// define the class instance detail type When the type is RESOLVE_TYPE_CLASS_INSTANCE // Define the class instance detail type When the type is RESOLVE_TYPE_CLASS_INSTANCE.
enum ClassInstanceTypeDef { enum ClassInstanceTypeDef {
CLASS_INSTANCE_TYPE_CELL = 0, // class instance type is Cell CLASS_INSTANCE_TYPE_CELL = 0, // Class instance type is Cell.
CLASS_INSTANCE_TYPE_PRIMITIVE = 1, // class instance type is Primitive CLASS_INSTANCE_TYPE_PRIMITIVE = 1, // Class instance type is Primitive.
CLASS_INSTANCE_TYPE_INVALID = 0xFF CLASS_INSTANCE_TYPE_INVALID = 0xFF
}; };
// Convert python object to ValuePtr // Convert python object to ValuePtr.
bool ConvertData(const py::object &obj, ValuePtr *data, bool use_signature = false, const TypePtr &dtype = nullptr); bool ConvertData(const py::object &obj, ValuePtr *data, bool use_signature = false, const TypePtr &dtype = nullptr);
// Convert python obj to graph // Convert python obj to graph.
FuncGraphPtr ConvertToFuncGraph(const py::object &obj, FuncGraphPtr ConvertToFuncGraph(const py::object &obj,
const std::string &python_mod_get_parse_method = PYTHON_MOD_GET_PARSE_METHOD); const std::string &python_mod_get_parse_method = PYTHON_MOD_GET_PARSE_METHOD);
// Parse the python object to graph // Parse the python object to graph.
FuncGraphPtr ParsePythonCode(const py::object &obj, FuncGraphPtr ParsePythonCode(const py::object &obj,
const std::string &python_mod_get_parse_method = PYTHON_MOD_GET_PARSE_METHOD); const std::string &python_mod_get_parse_method = PYTHON_MOD_GET_PARSE_METHOD);
// add wrap for cell top graph. // Add wrap for cell top graph.
FuncGraphPtr MakeTopGraph(const py::object &cell, const ValuePtr &cell_ptr); FuncGraphPtr MakeTopGraph(const py::object &cell, const ValuePtr &cell_ptr);
} // namespace parse } // namespace parse
} // namespace mindspore } // namespace mindspore

View File

@ -16,8 +16,9 @@
#include "pipeline/jit/parse/resolve.h" #include "pipeline/jit/parse/resolve.h"
#include <string> #include <utility>
#include <memory> #include <memory>
#include <string>
#include <vector> #include <vector>
#include "ir/param_info.h" #include "ir/param_info.h"
@ -310,25 +311,12 @@ AnfNodePtr ResolveObjectAndAddToManager(const FuncGraphManagerPtr &manager, cons
} }
} // namespace } // namespace
AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol, // Get python object with index from a list.
const AnfNodePtr &node) { py::object GetItemObjectFromSequence(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node,
const AnfNodePtr &index_node) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info())); TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info()));
if (node->func_graph() == nullptr || manager == nullptr) { if (node->func_graph() == nullptr) {
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph or manager is nullptr";
}
SymbolResolver symbol_resolver(name_space, symbol, node);
symbol_resolver.Resolve();
py::object obj = symbol_resolver.result();
AnfNodePtr resolved_node = ResolveObjectAndAddToManager(manager, obj, node);
return resolved_node;
}
AnfNodePtr ResolveCellwithAttr(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space,
const SymbolPtr &symbol, const AnfNodePtr &node, const AnfNodePtr &attr) {
MS_EXCEPTION_IF_NULL(node);
TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info()));
if (node->func_graph() == nullptr || manager == nullptr) {
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph or manager is nullptr"; MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph or manager is nullptr";
} }
SymbolResolver symbol_resolver(name_space, symbol, node); SymbolResolver symbol_resolver(name_space, symbol, node);
@ -337,6 +325,75 @@ AnfNodePtr ResolveCellwithAttr(const FuncGraphManagerPtr &manager, const NameSpa
} }
py::object obj = symbol_resolver.result(); py::object obj = symbol_resolver.result();
if (!py::isinstance<py::list>(obj) && !py::isinstance<py::tuple>(obj)) {
MS_LOG(EXCEPTION) << "Should not get item from non-sequence type, obj: " << py::str(obj);
}
const std::string fn = PYTHON_MOD_GET_ITEM_FROM_SEQUENCE;
const std::string module = "mindspore._extends.parse.parser";
int index = GetValueNode<Int64ImmPtr>(index_node)->value();
MS_LOG(DEBUG) << "obj: " << py::str(obj) << ", index: " << index;
py::object item_obj = parse::python_adapter::GetPyFn(module, fn)(obj, py::int_(index));
return item_obj;
}
std::pair<parse::NameSpacePtr, parse::SymbolPtr> GetNamespaceAndSymbol(const AnfNodePtr &node) {
if (IsPrimitiveCNode(node, prim::kPrimResolve)) {
auto resolve_cnode = node->cast<CNodePtr>();
constexpr size_t namespace_index = 1;
auto namespace_node = resolve_cnode->input(namespace_index);
constexpr size_t symbol_index = 2;
auto symbol_node = resolve_cnode->input(symbol_index);
if (!IsValueNode<parse::NameSpace>(namespace_node) || !IsValueNode<parse::Symbol>(symbol_node)) {
MS_LOG(EXCEPTION) << "Unexpected type, namespace: " << namespace_node->ToString()
<< ", symbol: " << symbol_node->ToString();
}
// Deal with the case of GetAttr from a class member,
// and avoid the case of GetAttr from self (the result of ParseSuper).
auto name_space = GetValueNode<parse::NameSpacePtr>(namespace_node);
auto symbol = GetValueNode<parse::SymbolPtr>(symbol_node);
return {name_space, symbol};
}
constexpr auto recursive_level = 2;
MS_LOG(EXCEPTION) << "It's not prim::Resolve CNode, node: " << node->DebugString(recursive_level);
}
py::object GetSymbolObject(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (node->func_graph() == nullptr) {
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph is nullptr.";
}
SymbolResolver symbol_resolver(name_space, symbol, node);
symbol_resolver.Resolve();
if (!symbol_resolver.Resolve()) {
MS_LOG(EXCEPTION) << "Fail to resolve node, NodeInfo.";
}
py::object obj = symbol_resolver.result();
return obj;
}
AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol,
const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (manager == nullptr) {
MS_LOG(EXCEPTION) << "Manager is nullptr.";
}
TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info()));
auto obj = GetSymbolObject(name_space, symbol, node);
AnfNodePtr resolved_node = ResolveObjectAndAddToManager(manager, obj, node);
return resolved_node;
}
// Resolve Cell GetAttr operation.
AnfNodePtr ResolveCellWithAttr(const FuncGraphManagerPtr &manager, const py::object &obj, const AnfNodePtr &node,
const AnfNodePtr &attr) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(attr);
if (manager == nullptr) {
MS_LOG(EXCEPTION) << "Manager is nullptr.";
}
MS_LOG(DEBUG) << "obj: " << py::str(obj) << ", attr: " << attr->ToString();
TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info()));
if (!data_converter::IsCellInstance(obj)) { if (!data_converter::IsCellInstance(obj)) {
AnfNodePtr resolved_node = ResolveObjectAndAddToManager(manager, obj, node); AnfNodePtr resolved_node = ResolveObjectAndAddToManager(manager, obj, node);
AnfNodePtrList inputs = {NewValueNode(prim::kPrimGetAttr), resolved_node, attr}; AnfNodePtrList inputs = {NewValueNode(prim::kPrimGetAttr), resolved_node, attr};
@ -365,7 +422,7 @@ opt::OptPassGroupMap GetOptResolvePasses(const opt::irpass::ResolveIRPassLib &ir
opt::OptPassGroupMap map({ opt::OptPassGroupMap map({
{"resolve", {"resolve",
{ {
irpass.resolver_getattr_resolve_, irpass.resolver_,
}}, }},
}); });
return map; return map;

View File

@ -17,8 +17,10 @@
#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_RESOLVE_H_ #ifndef MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_RESOLVE_H_
#define MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_RESOLVE_H_ #define MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_RESOLVE_H_
#include <utility>
#include <memory> #include <memory>
#include <string> #include <string>
#include "ir/anf.h" #include "ir/anf.h"
#include "ir/manager.h" #include "ir/manager.h"
#include "pipeline/jit/parse/python_adapter.h" #include "pipeline/jit/parse/python_adapter.h"
@ -40,7 +42,10 @@ namespace parse {
class NameSpace final : public Named { class NameSpace final : public Named {
public: public:
NameSpace(const std::string &module, const py::object &obj, const py::object &module_obj = py::object()) NameSpace(const std::string &module, const py::object &obj, const py::object &module_obj = py::object())
: Named(module), module_(module), obj_(obj), module_obj_(module_obj) {} : Named(module + ": \'" + std::string(py::str(obj)) + "\'"),
module_(module),
obj_(obj),
module_obj_(module_obj) {}
~NameSpace() override = default; ~NameSpace() override = default;
MS_DECLARE_PARENT(NameSpace, Named); MS_DECLARE_PARENT(NameSpace, Named);
@ -92,7 +97,7 @@ class Script final : public Named {
abstract::AbstractBasePtr ToAbstract() override { abstract::AbstractBasePtr ToAbstract() override {
return std::make_shared<abstract::AbstractScript>(shared_from_base<Script>()); return std::make_shared<abstract::AbstractScript>(shared_from_base<Script>());
} }
std::string ToString() const override { return "`" + name() + "`"; } std::string ToString() const override { return "\'" + name() + "\'"; }
private: private:
std::string script_; std::string script_;
@ -116,7 +121,7 @@ class PyObjectWrapper : public Named {
class InterpretedObject final : public PyObjectWrapper { class InterpretedObject final : public PyObjectWrapper {
public: public:
explicit InterpretedObject(const py::object &obj, const std::string &name = "null") explicit InterpretedObject(const py::object &obj, const std::string &name = "null")
: PyObjectWrapper(obj, "InterpretedObject: '" + name + "'") {} : PyObjectWrapper(obj, "InterpretedObject: \'" + name + "\'") {}
~InterpretedObject() override = default; ~InterpretedObject() override = default;
MS_DECLARE_PARENT(InterpretedObject, PyObjectWrapper); MS_DECLARE_PARENT(InterpretedObject, PyObjectWrapper);
abstract::AbstractBasePtr ToAbstract() override { abstract::AbstractBasePtr ToAbstract() override {
@ -158,14 +163,10 @@ class SymbolResolver {
// resolve symbol in namespace and save it in result_; // resolve symbol in namespace and save it in result_;
bool Resolve(); bool Resolve();
NameSpacePtr get_namespace() { return namespace_; }
SymbolPtr symbol() { return symbol_; } SymbolPtr symbol() { return symbol_; }
const py::object &result() { return result_; } const py::object &result() { return result_; }
AnfNodePtr resolved_node() { return resolved_node_; }
private: private:
// namespace where the symbol locates // namespace where the symbol locates
NameSpacePtr namespace_; NameSpacePtr namespace_;
@ -177,13 +178,20 @@ class SymbolResolver {
py::object result_; py::object result_;
}; };
using SymbolResolverPtr = std::shared_ptr<SymbolResolver>; using SymbolResolverPtr = std::shared_ptr<SymbolResolver>;
// Get python object with index from a list.
py::object GetItemObjectFromSequence(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node,
const AnfNodePtr &index_node);
std::pair<parse::NameSpacePtr, parse::SymbolPtr> GetNamespaceAndSymbol(const AnfNodePtr &node);
// Get resolved python object by namespace and symbol.
py::object GetSymbolObject(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node);
// Resolve symbol in namespace. // Resolve symbol in namespace.
AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol, AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol,
const AnfNodePtr &node); const AnfNodePtr &node);
// Resolve Cell with attr name. // Resolve Cell with attr name.
AnfNodePtr ResolveCellwithAttr(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, AnfNodePtr ResolveCellWithAttr(const FuncGraphManagerPtr &manager, const py::object &obj, const AnfNodePtr &node,
const SymbolPtr &symbol, const AnfNodePtr &node, const AnfNodePtr &attr); const AnfNodePtr &attr);
// Resolve one graph which normally is the root graph. FuncGraph shall be managed by res->manager(). // Resolve one graph which normally is the root graph. FuncGraph shall be managed by res->manager().
bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &res, bool use_profile = true); bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &res, bool use_profile = true);

View File

@ -0,0 +1,47 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test a list of cell, and getattr by its item """
from mindspore import context, nn, dtype, Tensor
class Actor(nn.Cell):
def __init__(self):
super(Actor, self).__init__()
def act(self, x, y):
return x + y
class Trainer(nn.Cell):
def __init__(self, net_list):
super(Trainer, self).__init__()
self.net_list = net_list
def construct(self, x, y):
return self.net_list[0].act(x, y)
def test_list_item_getattr():
"""
Feature: getattr by the item from list of cell.
Description: Support RL use method in graph mode.
Expectation: No exception.
"""
context.set_context(mode=context.GRAPH_MODE)
actor_list = [Actor()]
trainer = Trainer(actor_list)
x = Tensor([3], dtype=dtype.float32)
y = Tensor([6], dtype=dtype.float32)
print(trainer(x, y))