Support getattr by the item of cell list: Handle the DoSignaturePrimitive('getitem') between getattr and resolve.

This commit is contained in:
Zhang Qinghua 2021-12-10 14:25:18 +08:00
parent c2a47674ad
commit 3e7b73e6c7
12 changed files with 289 additions and 124 deletions

View File

@ -174,15 +174,15 @@ def resolve_symbol(namespace, symbol):
# If need trope the obj
if resolve_ in convert_object_map:
resolve_ = convert_object_map.get(resolve_)
logger.debug("Convert resolve = %r", resolve_)
logger.debug("Convert resolve: %r", resolve_)
if resolve_ == NO_IMPLEMENT:
raise NotImplementedError(f"Not support for '{symbol}'.")
except Exception as e:
if isinstance(e, NotImplementedError):
raise e
resolve_ = None
logger.debug("Resolve exception occurred, value = %r", e)
logger.debug("Resolve type is invalid, namespace = %s, symbol = %s",
logger.debug("Resolve exception occurred, value: %r", e)
logger.debug("Resolve type is invalid, namespace: %s, symbol: %s",
namespace.__str__(), symbol)
if isinstance(resolve_, _MindsporeFunctionExecutor):
@ -219,7 +219,7 @@ def get_object_key(obj):
if hasattr(obj, "cell_init_args"):
obj_key = "%s_ID" % (tag + obj.cell_init_args)
obj_id = "%s_ID%d" % (tag, id(obj))
logger.debug("obj_key %s obj_id = %s", obj_key, obj_id)
logger.debug("obj_key: %s, obj_id: %s", obj_key, obj_id)
# method has same id of different instance
if isinstance(obj, types.MethodType):
@ -339,9 +339,17 @@ def create_instance(cls_type, params=None):
return obj
def get_obj_from_sequence(obj, index):
"""Implement `tuple_getitem`."""
if not isinstance(obj, (tuple, list)):
raise TypeError(f"Should not get item from a object that not sequence type, obj: {obj}")
# Not check index out of range by self.
return obj[index]
def get_module_namespace(obj):
"""Get the module's namespace."""
logger.debug("get module namespace, module = %r", obj)
logger.debug("get module namespace, module: %r", obj)
mod_namespace = None
if isinstance(obj, types.ModuleType):
mod_namespace = CellNamespace(obj.__name__)
@ -352,9 +360,9 @@ def get_module_namespace(obj):
def get_class_member_namespace_symbol(obj):
"""Get obj class member type."""
logger.debug("get class instance namespace, object = %r", obj)
logger.debug("get class instance namespace, object: %r", obj)
class_namespace = ClassMemberNamespace(obj)
logger.debug("class namesapce = %r", class_namespace)
logger.debug("class namespace: %r", class_namespace)
return class_namespace
@ -425,14 +433,14 @@ def get_ast_namespace_symbol(obj):
"""Get obj type and namespace and symbol."""
# step 1:get symbol from object map
ops_info = parse_object_map.get(type(obj), SYMBOL_UNDEFINE)
logger.debug("ops info = %r", ops_info)
logger.debug("ops info: %r", ops_info)
return ops_info
def get_operation_namespace_symbol(var: str):
"""Get operation namespace and symbol."""
ops_info = (trope_ns, var)
logger.debug("get operation ops info = %r", ops_info)
logger.debug("get operation ops info: %r", ops_info)
return ops_info
@ -566,7 +574,7 @@ class Parser:
def parse(self):
"""Parse the function or method."""
logger.debug("fn = %r", self.fn)
logger.debug("fn: %r", self.fn)
if isinstance(self.fn, (types.FunctionType, types.MethodType)):
try:
lines, self.line_offset = inspect.getsourcelines(self.fn)
@ -582,7 +590,7 @@ class Parser:
src = dedent(original_src)
self.col_offset = \
len(original_src.split('\n')[0]) - len(src.split('\n')[0])
logger.debug("Get source = %s", src)
logger.debug("Get source: %s", src)
try:
ast_tokens = asttokens.ASTTokens(src, parse=True)
except IndentationError as idt_err:

View File

@ -176,7 +176,7 @@ void DumpKernelInfo(const CNodePtr &node, const std::shared_ptr<SubGraphIRInfo>
int32_t DumpParams(const FuncGraphPtr &graph, std::ostringstream &buffer, OrderedMap<AnfNodePtr, int32_t> *para_map) {
if (graph == nullptr) {
MS_LOG(INFO) << "Param graph is nullptr.";
MS_LOG(INFO) << "Parameter \'graph\' should not be null.";
return 0;
}
std::vector<AnfNodePtr> parameters = graph->parameters();
@ -217,13 +217,12 @@ int32_t DumpParams(const FuncGraphPtr &graph, std::ostringstream &buffer, Ordere
void DumpOperator(const AnfNodePtr &node, const std::shared_ptr<SubGraphIRInfo> &gsub) {
if (gsub == nullptr) {
MS_LOG(INFO) << "Param gsub is nullptr";
MS_LOG(INFO) << "Parameter \'gsub\' should not be null.";
return;
}
auto cnode = dyn_cast<CNode>(node);
MS_EXCEPTION_IF_NULL(cnode);
if (cnode == nullptr) {
MS_LOG(INFO) << "Param node should be a CNode";
MS_LOG(EXCEPTION) << "Parameter \'node\' should be a CNode";
return;
}
AnfNodePtr op = cnode->input(0);
@ -272,7 +271,11 @@ void DumpOperands(const AnfNodePtr &node, OrderedMap<AnfNodePtr, int32_t> *para_
gsub->buffer << ", ";
}
if (in->isa<Parameter>()) {
if (in->func_graph() != node->func_graph()) {
MS_EXCEPTION_IF_NULL(node->func_graph());
if (in->func_graph() == nullptr) {
MS_LOG(ERROR) << "Parameter should belong to a func graph. Check func graph: " << node->func_graph();
}
if (in->func_graph() != nullptr && in->func_graph() != node->func_graph()) {
gsub->buffer << "$(@" << in->func_graph()->ToString() << ":";
} else {
gsub->buffer << "%";
@ -283,7 +286,7 @@ void DumpOperands(const AnfNodePtr &node, OrderedMap<AnfNodePtr, int32_t> *para_
} else {
gsub->buffer << "para" << iter->second << "_" << in->ToString();
}
if (in->func_graph() != node->func_graph()) {
if (in->func_graph() != nullptr && in->func_graph() != node->func_graph()) {
gsub->buffer << ")";
}
} else if (in->isa<CNode>()) {
@ -325,7 +328,7 @@ void DumpParallelInfo(const CNodePtr &node, const std::shared_ptr<SubGraphIRInfo
}
ValuePtr in_tmp = MakeValue(in_strategy->GetInputDim());
gsub->buffer << " { in_strategy: ";
gsub->buffer << " {in_strategy: ";
gsub->buffer << in_tmp->ToString();
auto out_strategy = operator_info->out_strategy();
@ -335,7 +338,7 @@ void DumpParallelInfo(const CNodePtr &node, const std::shared_ptr<SubGraphIRInfo
gsub->buffer << out_tmp->ToString();
}
gsub->buffer << " }";
gsub->buffer << "}";
}
void DumpAttrs(const mindspore::HashMap<std::string, ValuePtr> &attrs, const std::shared_ptr<SubGraphIRInfo> &gsub,

View File

@ -267,9 +267,9 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
}
ResolveIRPassLib::ResolveIRPassLib() {
// In resolver_getattr_resolve_, some patterns have priority over others.
resolver_getattr_resolve_ = MakeSubstitution(std::make_shared<ResolverGetAttrResolve>(), "getattr_resolve",
{prim::kPrimGetAttr, prim::kPrimResolve}, opt::CHECK_RENORM, true);
// In resolver_, some patterns have priority over others.
resolver_ = MakeSubstitution(std::make_shared<Resolver>(), "getattr_resolve",
{prim::kPrimGetAttr, prim::kPrimResolve}, opt::CHECK_RENORM, true);
}
InferenceOptPrepareLib::InferenceOptPrepareLib() {

View File

@ -169,7 +169,7 @@ class ResolveIRPassLib {
public:
ResolveIRPassLib();
~ResolveIRPassLib() = default;
SubstitutionPtr resolver_getattr_resolve_;
SubstitutionPtr resolver_;
};
class InferenceOptPrepareLib {

View File

@ -434,7 +434,8 @@ class PynativeEliminater : public OptimizerCaller {
if (value_node == nullptr) {
return false;
}
return GetValueNode<parse::NameSpacePtr>(value_node)->module() == str_value;
auto module_name = GetValueNode<parse::NameSpacePtr>(value_node)->module();
return module_name.find(str_value) != std::string::npos;
}
bool CheckSymbolVNode(const AnfNodePtr &node, const std::string &str_value) {

View File

@ -23,53 +23,92 @@
namespace mindspore {
namespace opt {
namespace irpass {
// {prim::kPrimGetAttr, {prim::kPrimTupleGetItem, {prim::kPrimResolve, namespace, symbol}, index}, attr}
// {prim::kPrimGetAttr, {prim::kPrimResolve, namespace, symbol}, attr}
// {prim::kPrimGetAttr, namespace, attr}
// {prim::kPrimGetAttr, bool, attr}
// {prim::kPrimResolve, namespace, symbol}
AnfNodePtr ResolverGetAttrResolve::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
constexpr std::string_view PARSE_SUPER_NAME = "namespace";
constexpr size_t namespace_index = 1;
constexpr size_t symbol_index = 2;
PatternNode<AnfNodePtr> resolve_node, ns_node, sym_node, attr_node, bool_node;
auto GetAttrResolveLambda = [&node, &resolve_node, &attr_node, &optimizer, &PARSE_SUPER_NAME]() -> AnfNodePtr {
auto inner = resolve_node.GetNode(node);
AnfNodePtr Resolver::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
PatternNode<AnfNodePtr> getattr_operand, ns_node, sym_node, attr_node, bool_node;
auto GetAttrResolveLambda = [&node, &getattr_operand, &attr_node, &optimizer]() -> AnfNodePtr {
auto getattr_operand_node = getattr_operand.GetNode(node);
auto attr = attr_node.GetNode(node);
if (IsPrimitiveCNode(inner, prim::kPrimResolve)) {
auto resolve_cnode = inner->cast<CNodePtr>();
auto namespace_node = resolve_cnode->input(namespace_index);
auto symbol_node = resolve_cnode->input(symbol_index);
if (!IsValueNode<parse::NameSpace>(namespace_node) || !IsValueNode<parse::Symbol>(symbol_node)) {
return nullptr;
constexpr auto recursive_level = 3;
MS_LOG(DEBUG) << "getattr_operand_node: " << getattr_operand_node->DebugString(recursive_level);
// {prim::GetAttr, {{prim::Resolve, ..., 'getitem'}, {prim::Resolve, ...}, ...}}
auto getitem_cnode = getattr_operand_node->cast<CNodePtr>();
if (getitem_cnode != nullptr) {
constexpr size_t prim_index = 0;
auto primitive_node = getitem_cnode->input(prim_index);
auto resolved_getitem_node = primitive_node;
if (IsPrimitiveCNode(primitive_node, prim::kPrimResolve)) {
auto resolve_getitem_cnode = primitive_node->cast<CNodePtr>();
auto resolve_getitem_symbol = GetValueNode<parse::SymbolPtr>(resolve_getitem_cnode->input(2));
constexpr auto getitem_symbol = "getitem";
if (resolve_getitem_symbol->symbol() == getitem_symbol) {
auto resolve_getitem_name_space = GetValueNode<parse::NameSpacePtr>(resolve_getitem_cnode->input(1));
resolved_getitem_node =
ResolveSymbol(optimizer->manager(), resolve_getitem_name_space, resolve_getitem_symbol, node);
}
}
// deal with the case of getting attr from a class member
// and avoid the case of getting attr from self (the result of ParseSuper)
auto ns = GetValueNode<parse::NameSpacePtr>(namespace_node);
auto sym = GetValueNode<parse::SymbolPtr>(symbol_node);
if (ns->module() == parse::RESOLVE_NAMESPACE_NAME_CLASS_MEMBER && sym->symbol() != PARSE_SUPER_NAME) {
return parse::ResolveCellwithAttr(optimizer->manager(), ns, sym, inner, attr);
bool is_getattr_getitem = false;
auto do_signature = dyn_cast<prim::DoSignaturePrimitive>(GetValueNode(resolved_getitem_node));
if (do_signature != nullptr) {
auto &func_value = do_signature->function();
// The function 'func_value' must be the MultitypeFuncGraph of 'getitem'.
auto multitype_fg_value = dyn_cast<prim::MultitypeFuncGraph>(func_value);
constexpr auto getitem_symbol = "getitem";
if (multitype_fg_value != nullptr && multitype_fg_value->name() == getitem_symbol) {
is_getattr_getitem = true;
}
}
if (IsPrimitiveCNode(getattr_operand_node, prim::kPrimTupleGetItem)) {
is_getattr_getitem = true;
}
if (is_getattr_getitem) {
constexpr size_t resolve_index = 1;
auto resolve_node = getitem_cnode->input(resolve_index);
if (IsPrimitiveCNode(resolve_node, prim::kPrimResolve)) {
constexpr size_t position_index = 2;
auto index_node = getitem_cnode->input(position_index);
auto [name_space, symbol] = parse::GetNamespaceAndSymbol(resolve_node);
auto py_item = parse::GetItemObjectFromSequence(name_space, symbol, resolve_node, index_node);
return parse::ResolveCellWithAttr(optimizer->manager(), py_item, resolve_node, attr);
}
}
}
// {prim::GetAttr, {prim::Resolve, ...}}
if (IsPrimitiveCNode(getattr_operand_node, prim::kPrimResolve)) {
auto [name_space, symbol] = parse::GetNamespaceAndSymbol(getattr_operand_node);
auto module_name = name_space->module();
constexpr std::string_view parse_super_name = "namespace";
if (module_name.find(parse::RESOLVE_NAMESPACE_NAME_CLASS_MEMBER) != std::string::npos &&
symbol->symbol() != parse_super_name) {
auto obj = parse::GetSymbolObject(name_space, symbol, node);
return parse::ResolveCellWithAttr(optimizer->manager(), obj, getattr_operand_node, attr);
}
}
return nullptr;
};
auto GetAttrLambda = [&node, &ns_node, &attr_node, &optimizer]() -> AnfNodePtr {
auto ns = GetValueNode<parse::NameSpacePtr>(ns_node.GetNode(node));
auto name_space = GetValueNode<parse::NameSpacePtr>(ns_node.GetNode(node));
auto str = GetValue<std::string>(GetValueNode(attr_node.GetNode(node)));
parse::SymbolPtr sym = std::make_shared<parse::Symbol>(str);
return parse::ResolveSymbol(optimizer->manager(), ns, sym, node);
parse::SymbolPtr symbol = std::make_shared<parse::Symbol>(str);
return parse::ResolveSymbol(optimizer->manager(), name_space, symbol, node);
};
auto ResolveLambda = [&node, &ns_node, &sym_node, &optimizer]() -> AnfNodePtr {
auto ns = GetValueNode<parse::NameSpacePtr>(ns_node.GetNode(node));
auto sym = GetValueNode<parse::SymbolPtr>(sym_node.GetNode(node));
auto name_space = GetValueNode<parse::NameSpacePtr>(ns_node.GetNode(node));
auto symbol = GetValueNode<parse::SymbolPtr>(sym_node.GetNode(node));
auto manager = optimizer->manager();
return parse::ResolveSymbol(manager, ns, sym, node);
return parse::ResolveSymbol(manager, name_space, symbol, node);
};
// {prim::kPrimGetAttr, {prim::kPrimResolve, namespace, symbol}, attr}
MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimGetAttr, resolve_node, attr_node), GetAttrResolveLambda,
MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimGetAttr, getattr_operand, attr_node), GetAttrResolveLambda,
attr_node.CheckFunc(IsValueNode<StringImm>, node));
// {prim::kPrimGetAttr, namespace, attr}
MATCH_REPLACE_LAMBDA_IF(

View File

@ -38,11 +38,12 @@ namespace irpass {
// pattern. After matching GetAttr pattern, there may be new nodes that can match GetAttr pattern and Resolve pattern.
// The same is true for matching Resolve pattern.
//
// {prim::kPrimGetAttr, {prim::kPrimTupleGetItem, {prim::kPrimResolve, namespace, symbol}, index}, attr}
// {prim::kPrimGetAttr, {prim::kPrimResolve, namespace, symbol}, attr}
// {prim::kPrimGetAttr, namespace, attr}
// {prim::kPrimGetAttr, bool, attr}
// {prim::kPrimResolve, namespace, symbol}
class ResolverGetAttrResolve : public OptimizerCaller {
class Resolver : public OptimizerCaller {
public:
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override;
};

View File

@ -525,7 +525,7 @@ std::vector<DataConverterPtr> GetDataConverters() {
bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature, const TypePtr &dtype) {
// Check parameter valid
if (data == nullptr) {
MS_LOG(ERROR) << "Data is null pointer";
MS_LOG(ERROR) << "The value pointer should not be null.";
return false;
}
ValuePtr converted = nullptr;
@ -554,7 +554,7 @@ FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python
ValuePtr value = nullptr;
bool is_cache = data_converter::GetObjectValue(obj_id, &value);
if (is_cache && value != nullptr && value->isa<FuncGraph>()) {
MS_LOG(DEBUG) << "Get the cache data, obj = " << obj_id;
MS_LOG(DEBUG) << "Get the cache data, obj: " << obj_id;
func_graph = value->cast<FuncGraphPtr>();
if (!func_graph->dropped()) {
return func_graph;
@ -570,7 +570,7 @@ FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python
data_converter::MakeProperNameToFuncGraph(func_graph, obj_id);
data_converter::CacheObjectValue(obj_id, func_graph);
if (!obj_key.empty()) {
MS_LOG(DEBUG) << "Add graph:" << obj_key << ", func_graph:" << func_graph->ToString();
MS_LOG(DEBUG) << "Add graph: " << obj_key << ", func_graph: " << func_graph->ToString();
data_converter::SetObjGraphValue(obj_key, func_graph);
}
@ -584,11 +584,11 @@ static mindspore::HashMap<std::string, std::vector<FuncGraphPtr>> object_graphs_
void SetObjGraphValue(const std::string &obj_key, const FuncGraphPtr &data) {
object_graphs_map_[obj_key].push_back(data);
MS_LOG(DEBUG) << "Set func graph size:" << object_graphs_map_.size();
MS_LOG(DEBUG) << "Set func graph size: " << object_graphs_map_.size();
}
const mindspore::HashMap<std::string, std::vector<FuncGraphPtr>> &GetObjGraphs() {
MS_LOG(DEBUG) << "Obj size:" << object_graphs_map_.size();
MS_LOG(DEBUG) << "Obj graphs size: " << object_graphs_map_.size();
return object_graphs_map_;
}
@ -606,7 +606,7 @@ std::vector<std::string> GetObjKey(const py::object &obj) {
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
py::tuple obj_tuple = python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_KEY, obj);
if (obj_tuple.size() != 2) {
MS_LOG(EXCEPTION) << "Get_obj_key must return 2 elements";
MS_LOG(EXCEPTION) << "The function of \'get_obj_key()\' must return 2 elements";
}
return {py::cast<std::string>(obj_tuple[0]), py::cast<std::string>(obj_tuple[1])};
}
@ -619,10 +619,10 @@ ResolveTypeDef GetObjType(const py::object &obj) {
ResolveTypeDef(python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_TYPE, obj).cast<int32_t>());
return obj_type;
} catch (const py::error_already_set &ex) {
MS_LOG(ERROR) << "Meet a exception from Python when get the type of `" << py::str(obj) << "`.\n" << ex.what();
MS_LOG(ERROR) << "Meet a exception from Python when get the type of \'" << py::str(obj) << "\'.\n" << ex.what();
std::rethrow_exception(std::current_exception());
} catch (const py::type_error &ex) {
MS_LOG(ERROR) << "Meet a exception when get the type of `" << py::str(obj) << "`.\n" << ex.what();
MS_LOG(ERROR) << "Meet a exception when get the type of \'" << py::str(obj) << "\'.\n" << ex.what();
std::rethrow_exception(std::current_exception());
}
}
@ -638,8 +638,8 @@ ClassInstanceTypeDef GetClassInstanceType(const py::object &obj) {
// Check the object is Cell Instance.
bool IsCellInstance(const py::object &obj) {
auto class_type = GetClassInstanceType(obj);
bool isCell = (class_type == CLASS_INSTANCE_TYPE_CELL);
return isCell;
bool is_cell = (class_type == CLASS_INSTANCE_TYPE_CELL);
return is_cell;
}
// Create the python class instance.

View File

@ -27,12 +27,12 @@
namespace py = pybind11;
namespace mindspore {
namespace parse {
// define the node type
// Define the node type.
enum AstMainType : int64_t {
AST_MAIN_TYPE_STMT = 0, // ast.Stmt
AST_MAIN_TYPE_EXPR = 1, // ast.Expr
AST_MAIN_TYPE_SLICE = 2, // ast.Slice
AST_MAIN_TYPE_UNKNOWN = 0xFF // Error
AST_MAIN_TYPE_UNKNOWN = 0xFF // Unknown type
};
enum AstSubType : int64_t {
@ -43,18 +43,18 @@ enum AstSubType : int64_t {
AST_SUB_TYPE_SUBSCRIPT = 7, // ast.Subscript
AST_SUB_TYPE_STARRED = 8, // ast.Starred
AST_SUB_TYPE_ATTRIBUTE = 9, // ast.Attribute
AST_SUB_TYPE_UNKNOWN = 0xFF // Error
AST_SUB_TYPE_UNKNOWN = 0xFF // Unknown type
};
// define the parse target type
// Define the parse target type.
enum ParseTargetTypeDef {
PARSE_TARGET_FUNCTION = 0, // function
PARSE_TARGET_METHOD = 1, // method
PARSE_TARGET_OBJECT_INSTANCE = 2, // object instance
PARSE_TARGET_UNKNOW = 0xFF // ERROR TYPE
PARSE_TARGET_FUNCTION = 0, // Function
PARSE_TARGET_METHOD = 1, // Method
PARSE_TARGET_OBJECT_INSTANCE = 2, // Object instance
PARSE_TARGET_UNKNOW = 0xFF // Unknown type
};
// define python module name
// Define python module name.
const char PYTHON_MOD_PARSE_MODULE[] = "mindspore._extends.parse";
const char PYTHON_MOD_PARSE_OBJECT_FUNCTION[] = "parse_cb";
const char PYTHON_MOD_RESOLVE_FUNCTION[] = "resolve_symbol";
@ -72,6 +72,7 @@ const char PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL[] = "get_class_member_namespac
const char PYTHON_MOD_GET_PARSE_METHOD[] = "get_parse_method_of_class";
const char PYTHON_MOD_GET_BPROP_METHOD[] = "get_bprop_method_of_class";
const char PYTHON_MOD_GET_OBJECT_DESCRIPTION[] = "get_object_description";
const char PYTHON_MOD_GET_ITEM_FROM_SEQUENCE[] = "get_obj_from_sequence";
const char PYTHON_MOD_CONVERT_TO_MS_TENSOR[] = "convert_to_ms_tensor";
const char PYTHON_MOD_EVAL_PY_SCRIPT[] = "eval_script";
@ -92,7 +93,7 @@ const char PYTHON_PARSE_ANALYZE_SUPER[] = "analyze_super";
const char PYTHON_PARSE_CLASS_SLICE[] = "create_slice_obj";
const char PYTHON_PARSE_CLASS_ELLIPSIS[] = "create_ellipsis_obj";
// define the common name
// Define the common name.
const char NAMED_PRIMITIVE_LEN[] = "len";
const char NAMED_PRIMITIVE_BODY[] = "body";
const char NAMED_PRIMITIVE_ASSIGN[] = "Assign";
@ -127,8 +128,8 @@ const char NAMED_PRIMITIVE_MAKESLICE[] = "make_slice";
const char NAMED_PRIMITIVE_MAKEDICT[] = "make_dict";
const char NAMED_METAGRAPH_UNPACKCALL[] = "unpack_call";
// define NAMED_PRIMITIVE_GETATTR "getattr"
// define python inline attr
// Define NAMED_PRIMITIVE_GETATTR "getattr".
// Define python inline attr.
const char PYTHON_GET_METHOD_LEN[] = "__len__";
const char PYTHON_GET_METHOD_SELF_CLASS[] = "__self__";
const char PYTHON_GET_OBJ_DESC[] = "__str__";
@ -136,46 +137,46 @@ const char PYTHON_GET_OBJ_DESC[] = "__str__";
const char PYTHON_EXTERN_PARSE_METHOD[] = "__parse_method__";
const char PYTHON_EXTERN_MINDSPORE_FLAG[] = "_mindspore_flags";
// define the parse constant
// Define the parse constant.
const int64_t MAX_COMPARISON_OPS_SUPPORTED = 1;
const char CUSTOM_BPROP_NAME[] = "bprop";
const char STAGE_NAME[] = "_pipeline_stage";
// define the Namespace name
const char RESOLVE_NAMESPACE_NAME_AST[] = "Ast"; // for ast type namespace
const char RESOLVE_NAMESPACE_NAME_CLASS_MEMBER[] = "ClassMember"; // for class member namespace
const char RESOLVE_NAMESPACE_NAME_SYMBOL_STR[] = "SymbolStr"; // for symbol str namespace
const char RESOLVE_NAMESPACE_NAME_COMMON_OPS[] = "CommonOPS"; // for common ops, eg: hasnext, next
const char RESOLVE_NAMESPACE_NAME_MODULE[] = "Module"; // fro Module namespace
// Define the Namespace name.
const char RESOLVE_NAMESPACE_NAME_AST[] = "Ast"; // For ast type namespace.
const char RESOLVE_NAMESPACE_NAME_CLASS_MEMBER[] = "ClassMember"; // For class member namespace.
const char RESOLVE_NAMESPACE_NAME_SYMBOL_STR[] = "SymbolStr"; // For symbol str namespace.
const char RESOLVE_NAMESPACE_NAME_COMMON_OPS[] = "CommonOPS"; // For common ops, eg: hasnext, next.
const char RESOLVE_NAMESPACE_NAME_MODULE[] = "Module"; // For Module namespace.
// define Resolve type
// Define Resolve type.
enum ResolveTypeDef : int64_t {
RESOLVE_TYPE_NONE = 0, // resolve None
RESOLVE_TYPE_FUNCTION = 1, // resolve function
RESOLVE_TYPE_METHOD = 2, // resolve class method
RESOLVE_TYPE_CLASS_TYPE = 3, // resolve class type
RESOLVE_TYPE_CLASS_INSTANCE = 4, // resolve the class instance of common class
RESOLVE_TYPE_INVALID = 0xFF // resolve invalid
RESOLVE_TYPE_NONE = 0, // Resolve None
RESOLVE_TYPE_FUNCTION = 1, // Resolve function
RESOLVE_TYPE_METHOD = 2, // Resolve class method
RESOLVE_TYPE_CLASS_TYPE = 3, // Resolve class type
RESOLVE_TYPE_CLASS_INSTANCE = 4, // Resolve the class instance of common class
RESOLVE_TYPE_INVALID = 0xFF // Resolve invalid
};
// define the class instance detail type When the type is RESOLVE_TYPE_CLASS_INSTANCE
// Define the class instance detail type When the type is RESOLVE_TYPE_CLASS_INSTANCE.
enum ClassInstanceTypeDef {
CLASS_INSTANCE_TYPE_CELL = 0, // class instance type is Cell
CLASS_INSTANCE_TYPE_PRIMITIVE = 1, // class instance type is Primitive
CLASS_INSTANCE_TYPE_CELL = 0, // Class instance type is Cell.
CLASS_INSTANCE_TYPE_PRIMITIVE = 1, // Class instance type is Primitive.
CLASS_INSTANCE_TYPE_INVALID = 0xFF
};
// Convert python object to ValuePtr
// Convert python object to ValuePtr.
bool ConvertData(const py::object &obj, ValuePtr *data, bool use_signature = false, const TypePtr &dtype = nullptr);
// Convert python obj to graph
// Convert python obj to graph.
FuncGraphPtr ConvertToFuncGraph(const py::object &obj,
const std::string &python_mod_get_parse_method = PYTHON_MOD_GET_PARSE_METHOD);
// Parse the python object to graph
// Parse the python object to graph.
FuncGraphPtr ParsePythonCode(const py::object &obj,
const std::string &python_mod_get_parse_method = PYTHON_MOD_GET_PARSE_METHOD);
// add wrap for cell top graph.
// Add wrap for cell top graph.
FuncGraphPtr MakeTopGraph(const py::object &cell, const ValuePtr &cell_ptr);
} // namespace parse
} // namespace mindspore

View File

@ -16,8 +16,9 @@
#include "pipeline/jit/parse/resolve.h"
#include <string>
#include <utility>
#include <memory>
#include <string>
#include <vector>
#include "ir/param_info.h"
@ -310,25 +311,12 @@ AnfNodePtr ResolveObjectAndAddToManager(const FuncGraphManagerPtr &manager, cons
}
} // namespace
AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol,
const AnfNodePtr &node) {
// Get python object with index from a list.
py::object GetItemObjectFromSequence(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node,
const AnfNodePtr &index_node) {
MS_EXCEPTION_IF_NULL(node);
TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info()));
if (node->func_graph() == nullptr || manager == nullptr) {
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph or manager is nullptr";
}
SymbolResolver symbol_resolver(name_space, symbol, node);
symbol_resolver.Resolve();
py::object obj = symbol_resolver.result();
AnfNodePtr resolved_node = ResolveObjectAndAddToManager(manager, obj, node);
return resolved_node;
}
AnfNodePtr ResolveCellwithAttr(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space,
const SymbolPtr &symbol, const AnfNodePtr &node, const AnfNodePtr &attr) {
MS_EXCEPTION_IF_NULL(node);
TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info()));
if (node->func_graph() == nullptr || manager == nullptr) {
if (node->func_graph() == nullptr) {
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph or manager is nullptr";
}
SymbolResolver symbol_resolver(name_space, symbol, node);
@ -337,6 +325,75 @@ AnfNodePtr ResolveCellwithAttr(const FuncGraphManagerPtr &manager, const NameSpa
}
py::object obj = symbol_resolver.result();
if (!py::isinstance<py::list>(obj) && !py::isinstance<py::tuple>(obj)) {
MS_LOG(EXCEPTION) << "Should not get item from non-sequence type, obj: " << py::str(obj);
}
const std::string fn = PYTHON_MOD_GET_ITEM_FROM_SEQUENCE;
const std::string module = "mindspore._extends.parse.parser";
int index = GetValueNode<Int64ImmPtr>(index_node)->value();
MS_LOG(DEBUG) << "obj: " << py::str(obj) << ", index: " << index;
py::object item_obj = parse::python_adapter::GetPyFn(module, fn)(obj, py::int_(index));
return item_obj;
}
std::pair<parse::NameSpacePtr, parse::SymbolPtr> GetNamespaceAndSymbol(const AnfNodePtr &node) {
if (IsPrimitiveCNode(node, prim::kPrimResolve)) {
auto resolve_cnode = node->cast<CNodePtr>();
constexpr size_t namespace_index = 1;
auto namespace_node = resolve_cnode->input(namespace_index);
constexpr size_t symbol_index = 2;
auto symbol_node = resolve_cnode->input(symbol_index);
if (!IsValueNode<parse::NameSpace>(namespace_node) || !IsValueNode<parse::Symbol>(symbol_node)) {
MS_LOG(EXCEPTION) << "Unexpected type, namespace: " << namespace_node->ToString()
<< ", symbol: " << symbol_node->ToString();
}
// Deal with the case of GetAttr from a class member,
// and avoid the case of GetAttr from self (the result of ParseSuper).
auto name_space = GetValueNode<parse::NameSpacePtr>(namespace_node);
auto symbol = GetValueNode<parse::SymbolPtr>(symbol_node);
return {name_space, symbol};
}
constexpr auto recursive_level = 2;
MS_LOG(EXCEPTION) << "It's not prim::Resolve CNode, node: " << node->DebugString(recursive_level);
}
py::object GetSymbolObject(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (node->func_graph() == nullptr) {
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph is nullptr.";
}
SymbolResolver symbol_resolver(name_space, symbol, node);
symbol_resolver.Resolve();
if (!symbol_resolver.Resolve()) {
MS_LOG(EXCEPTION) << "Fail to resolve node, NodeInfo.";
}
py::object obj = symbol_resolver.result();
return obj;
}
AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol,
const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (manager == nullptr) {
MS_LOG(EXCEPTION) << "Manager is nullptr.";
}
TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info()));
auto obj = GetSymbolObject(name_space, symbol, node);
AnfNodePtr resolved_node = ResolveObjectAndAddToManager(manager, obj, node);
return resolved_node;
}
// Resolve Cell GetAttr operation.
AnfNodePtr ResolveCellWithAttr(const FuncGraphManagerPtr &manager, const py::object &obj, const AnfNodePtr &node,
const AnfNodePtr &attr) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(attr);
if (manager == nullptr) {
MS_LOG(EXCEPTION) << "Manager is nullptr.";
}
MS_LOG(DEBUG) << "obj: " << py::str(obj) << ", attr: " << attr->ToString();
TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info()));
if (!data_converter::IsCellInstance(obj)) {
AnfNodePtr resolved_node = ResolveObjectAndAddToManager(manager, obj, node);
AnfNodePtrList inputs = {NewValueNode(prim::kPrimGetAttr), resolved_node, attr};
@ -365,7 +422,7 @@ opt::OptPassGroupMap GetOptResolvePasses(const opt::irpass::ResolveIRPassLib &ir
opt::OptPassGroupMap map({
{"resolve",
{
irpass.resolver_getattr_resolve_,
irpass.resolver_,
}},
});
return map;

View File

@ -17,8 +17,10 @@
#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_RESOLVE_H_
#define MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_RESOLVE_H_
#include <utility>
#include <memory>
#include <string>
#include "ir/anf.h"
#include "ir/manager.h"
#include "pipeline/jit/parse/python_adapter.h"
@ -40,7 +42,10 @@ namespace parse {
class NameSpace final : public Named {
public:
NameSpace(const std::string &module, const py::object &obj, const py::object &module_obj = py::object())
: Named(module), module_(module), obj_(obj), module_obj_(module_obj) {}
: Named(module + ": \'" + std::string(py::str(obj)) + "\'"),
module_(module),
obj_(obj),
module_obj_(module_obj) {}
~NameSpace() override = default;
MS_DECLARE_PARENT(NameSpace, Named);
@ -92,7 +97,7 @@ class Script final : public Named {
abstract::AbstractBasePtr ToAbstract() override {
return std::make_shared<abstract::AbstractScript>(shared_from_base<Script>());
}
std::string ToString() const override { return "`" + name() + "`"; }
std::string ToString() const override { return "\'" + name() + "\'"; }
private:
std::string script_;
@ -116,7 +121,7 @@ class PyObjectWrapper : public Named {
class InterpretedObject final : public PyObjectWrapper {
public:
explicit InterpretedObject(const py::object &obj, const std::string &name = "null")
: PyObjectWrapper(obj, "InterpretedObject: '" + name + "'") {}
: PyObjectWrapper(obj, "InterpretedObject: \'" + name + "\'") {}
~InterpretedObject() override = default;
MS_DECLARE_PARENT(InterpretedObject, PyObjectWrapper);
abstract::AbstractBasePtr ToAbstract() override {
@ -158,14 +163,10 @@ class SymbolResolver {
// resolve symbol in namespace and save it in result_;
bool Resolve();
NameSpacePtr get_namespace() { return namespace_; }
SymbolPtr symbol() { return symbol_; }
const py::object &result() { return result_; }
AnfNodePtr resolved_node() { return resolved_node_; }
private:
// namespace where the symbol locates
NameSpacePtr namespace_;
@ -177,13 +178,20 @@ class SymbolResolver {
py::object result_;
};
using SymbolResolverPtr = std::shared_ptr<SymbolResolver>;
// Get python object with index from a list.
py::object GetItemObjectFromSequence(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node,
const AnfNodePtr &index_node);
std::pair<parse::NameSpacePtr, parse::SymbolPtr> GetNamespaceAndSymbol(const AnfNodePtr &node);
// Get resolved python object by namespace and symbol.
py::object GetSymbolObject(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node);
// Resolve symbol in namespace.
AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol,
const AnfNodePtr &node);
// Resolve Cell with attr name.
AnfNodePtr ResolveCellwithAttr(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space,
const SymbolPtr &symbol, const AnfNodePtr &node, const AnfNodePtr &attr);
AnfNodePtr ResolveCellWithAttr(const FuncGraphManagerPtr &manager, const py::object &obj, const AnfNodePtr &node,
const AnfNodePtr &attr);
// Resolve one graph which normally is the root graph. FuncGraph shall be managed by res->manager().
bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &res, bool use_profile = true);

View File

@ -0,0 +1,47 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test a list of cell, and getattr by its item """
from mindspore import context, nn, dtype, Tensor
class Actor(nn.Cell):
def __init__(self):
super(Actor, self).__init__()
def act(self, x, y):
return x + y
class Trainer(nn.Cell):
def __init__(self, net_list):
super(Trainer, self).__init__()
self.net_list = net_list
def construct(self, x, y):
return self.net_list[0].act(x, y)
def test_list_item_getattr():
"""
Feature: getattr by the item from list of cell.
Description: Support RL use method in graph mode.
Expectation: No exception.
"""
context.set_context(mode=context.GRAPH_MODE)
actor_list = [Actor()]
trainer = Trainer(actor_list)
x = Tensor([3], dtype=dtype.float32)
y = Tensor([6], dtype=dtype.float32)
print(trainer(x, y))