forked from mindspore-Ecosystem/mindspore
Support getattr by the item of cell list: Handle the DoSignaturePrimitive('getitem') between getattr and resolve.
This commit is contained in:
parent
c2a47674ad
commit
3e7b73e6c7
|
@ -174,15 +174,15 @@ def resolve_symbol(namespace, symbol):
|
|||
# If need trope the obj
|
||||
if resolve_ in convert_object_map:
|
||||
resolve_ = convert_object_map.get(resolve_)
|
||||
logger.debug("Convert resolve = %r", resolve_)
|
||||
logger.debug("Convert resolve: %r", resolve_)
|
||||
if resolve_ == NO_IMPLEMENT:
|
||||
raise NotImplementedError(f"Not support for '{symbol}'.")
|
||||
except Exception as e:
|
||||
if isinstance(e, NotImplementedError):
|
||||
raise e
|
||||
resolve_ = None
|
||||
logger.debug("Resolve exception occurred, value = %r", e)
|
||||
logger.debug("Resolve type is invalid, namespace = %s, symbol = %s",
|
||||
logger.debug("Resolve exception occurred, value: %r", e)
|
||||
logger.debug("Resolve type is invalid, namespace: %s, symbol: %s",
|
||||
namespace.__str__(), symbol)
|
||||
|
||||
if isinstance(resolve_, _MindsporeFunctionExecutor):
|
||||
|
@ -219,7 +219,7 @@ def get_object_key(obj):
|
|||
if hasattr(obj, "cell_init_args"):
|
||||
obj_key = "%s_ID" % (tag + obj.cell_init_args)
|
||||
obj_id = "%s_ID%d" % (tag, id(obj))
|
||||
logger.debug("obj_key %s obj_id = %s", obj_key, obj_id)
|
||||
logger.debug("obj_key: %s, obj_id: %s", obj_key, obj_id)
|
||||
|
||||
# method has same id of different instance
|
||||
if isinstance(obj, types.MethodType):
|
||||
|
@ -339,9 +339,17 @@ def create_instance(cls_type, params=None):
|
|||
return obj
|
||||
|
||||
|
||||
def get_obj_from_sequence(obj, index):
|
||||
"""Implement `tuple_getitem`."""
|
||||
if not isinstance(obj, (tuple, list)):
|
||||
raise TypeError(f"Should not get item from a object that not sequence type, obj: {obj}")
|
||||
# Not check index out of range by self.
|
||||
return obj[index]
|
||||
|
||||
|
||||
def get_module_namespace(obj):
|
||||
"""Get the module's namespace."""
|
||||
logger.debug("get module namespace, module = %r", obj)
|
||||
logger.debug("get module namespace, module: %r", obj)
|
||||
mod_namespace = None
|
||||
if isinstance(obj, types.ModuleType):
|
||||
mod_namespace = CellNamespace(obj.__name__)
|
||||
|
@ -352,9 +360,9 @@ def get_module_namespace(obj):
|
|||
|
||||
def get_class_member_namespace_symbol(obj):
|
||||
"""Get obj class member type."""
|
||||
logger.debug("get class instance namespace, object = %r", obj)
|
||||
logger.debug("get class instance namespace, object: %r", obj)
|
||||
class_namespace = ClassMemberNamespace(obj)
|
||||
logger.debug("class namesapce = %r", class_namespace)
|
||||
logger.debug("class namespace: %r", class_namespace)
|
||||
return class_namespace
|
||||
|
||||
|
||||
|
@ -425,14 +433,14 @@ def get_ast_namespace_symbol(obj):
|
|||
"""Get obj type and namespace and symbol."""
|
||||
# step 1:get symbol from object map
|
||||
ops_info = parse_object_map.get(type(obj), SYMBOL_UNDEFINE)
|
||||
logger.debug("ops info = %r", ops_info)
|
||||
logger.debug("ops info: %r", ops_info)
|
||||
return ops_info
|
||||
|
||||
|
||||
def get_operation_namespace_symbol(var: str):
|
||||
"""Get operation namespace and symbol."""
|
||||
ops_info = (trope_ns, var)
|
||||
logger.debug("get operation ops info = %r", ops_info)
|
||||
logger.debug("get operation ops info: %r", ops_info)
|
||||
return ops_info
|
||||
|
||||
|
||||
|
@ -566,7 +574,7 @@ class Parser:
|
|||
|
||||
def parse(self):
|
||||
"""Parse the function or method."""
|
||||
logger.debug("fn = %r", self.fn)
|
||||
logger.debug("fn: %r", self.fn)
|
||||
if isinstance(self.fn, (types.FunctionType, types.MethodType)):
|
||||
try:
|
||||
lines, self.line_offset = inspect.getsourcelines(self.fn)
|
||||
|
@ -582,7 +590,7 @@ class Parser:
|
|||
src = dedent(original_src)
|
||||
self.col_offset = \
|
||||
len(original_src.split('\n')[0]) - len(src.split('\n')[0])
|
||||
logger.debug("Get source = %s", src)
|
||||
logger.debug("Get source: %s", src)
|
||||
try:
|
||||
ast_tokens = asttokens.ASTTokens(src, parse=True)
|
||||
except IndentationError as idt_err:
|
||||
|
|
|
@ -176,7 +176,7 @@ void DumpKernelInfo(const CNodePtr &node, const std::shared_ptr<SubGraphIRInfo>
|
|||
|
||||
int32_t DumpParams(const FuncGraphPtr &graph, std::ostringstream &buffer, OrderedMap<AnfNodePtr, int32_t> *para_map) {
|
||||
if (graph == nullptr) {
|
||||
MS_LOG(INFO) << "Param graph is nullptr.";
|
||||
MS_LOG(INFO) << "Parameter \'graph\' should not be null.";
|
||||
return 0;
|
||||
}
|
||||
std::vector<AnfNodePtr> parameters = graph->parameters();
|
||||
|
@ -217,13 +217,12 @@ int32_t DumpParams(const FuncGraphPtr &graph, std::ostringstream &buffer, Ordere
|
|||
|
||||
void DumpOperator(const AnfNodePtr &node, const std::shared_ptr<SubGraphIRInfo> &gsub) {
|
||||
if (gsub == nullptr) {
|
||||
MS_LOG(INFO) << "Param gsub is nullptr";
|
||||
MS_LOG(INFO) << "Parameter \'gsub\' should not be null.";
|
||||
return;
|
||||
}
|
||||
auto cnode = dyn_cast<CNode>(node);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (cnode == nullptr) {
|
||||
MS_LOG(INFO) << "Param node should be a CNode";
|
||||
MS_LOG(EXCEPTION) << "Parameter \'node\' should be a CNode";
|
||||
return;
|
||||
}
|
||||
AnfNodePtr op = cnode->input(0);
|
||||
|
@ -272,7 +271,11 @@ void DumpOperands(const AnfNodePtr &node, OrderedMap<AnfNodePtr, int32_t> *para_
|
|||
gsub->buffer << ", ";
|
||||
}
|
||||
if (in->isa<Parameter>()) {
|
||||
if (in->func_graph() != node->func_graph()) {
|
||||
MS_EXCEPTION_IF_NULL(node->func_graph());
|
||||
if (in->func_graph() == nullptr) {
|
||||
MS_LOG(ERROR) << "Parameter should belong to a func graph. Check func graph: " << node->func_graph();
|
||||
}
|
||||
if (in->func_graph() != nullptr && in->func_graph() != node->func_graph()) {
|
||||
gsub->buffer << "$(@" << in->func_graph()->ToString() << ":";
|
||||
} else {
|
||||
gsub->buffer << "%";
|
||||
|
@ -283,7 +286,7 @@ void DumpOperands(const AnfNodePtr &node, OrderedMap<AnfNodePtr, int32_t> *para_
|
|||
} else {
|
||||
gsub->buffer << "para" << iter->second << "_" << in->ToString();
|
||||
}
|
||||
if (in->func_graph() != node->func_graph()) {
|
||||
if (in->func_graph() != nullptr && in->func_graph() != node->func_graph()) {
|
||||
gsub->buffer << ")";
|
||||
}
|
||||
} else if (in->isa<CNode>()) {
|
||||
|
@ -325,7 +328,7 @@ void DumpParallelInfo(const CNodePtr &node, const std::shared_ptr<SubGraphIRInfo
|
|||
}
|
||||
|
||||
ValuePtr in_tmp = MakeValue(in_strategy->GetInputDim());
|
||||
gsub->buffer << " { in_strategy: ";
|
||||
gsub->buffer << " {in_strategy: ";
|
||||
gsub->buffer << in_tmp->ToString();
|
||||
|
||||
auto out_strategy = operator_info->out_strategy();
|
||||
|
@ -335,7 +338,7 @@ void DumpParallelInfo(const CNodePtr &node, const std::shared_ptr<SubGraphIRInfo
|
|||
gsub->buffer << out_tmp->ToString();
|
||||
}
|
||||
|
||||
gsub->buffer << " }";
|
||||
gsub->buffer << "}";
|
||||
}
|
||||
|
||||
void DumpAttrs(const mindspore::HashMap<std::string, ValuePtr> &attrs, const std::shared_ptr<SubGraphIRInfo> &gsub,
|
||||
|
|
|
@ -267,9 +267,9 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
|||
}
|
||||
|
||||
ResolveIRPassLib::ResolveIRPassLib() {
|
||||
// In resolver_getattr_resolve_, some patterns have priority over others.
|
||||
resolver_getattr_resolve_ = MakeSubstitution(std::make_shared<ResolverGetAttrResolve>(), "getattr_resolve",
|
||||
{prim::kPrimGetAttr, prim::kPrimResolve}, opt::CHECK_RENORM, true);
|
||||
// In resolver_, some patterns have priority over others.
|
||||
resolver_ = MakeSubstitution(std::make_shared<Resolver>(), "getattr_resolve",
|
||||
{prim::kPrimGetAttr, prim::kPrimResolve}, opt::CHECK_RENORM, true);
|
||||
}
|
||||
|
||||
InferenceOptPrepareLib::InferenceOptPrepareLib() {
|
||||
|
|
|
@ -169,7 +169,7 @@ class ResolveIRPassLib {
|
|||
public:
|
||||
ResolveIRPassLib();
|
||||
~ResolveIRPassLib() = default;
|
||||
SubstitutionPtr resolver_getattr_resolve_;
|
||||
SubstitutionPtr resolver_;
|
||||
};
|
||||
|
||||
class InferenceOptPrepareLib {
|
||||
|
|
|
@ -434,7 +434,8 @@ class PynativeEliminater : public OptimizerCaller {
|
|||
if (value_node == nullptr) {
|
||||
return false;
|
||||
}
|
||||
return GetValueNode<parse::NameSpacePtr>(value_node)->module() == str_value;
|
||||
auto module_name = GetValueNode<parse::NameSpacePtr>(value_node)->module();
|
||||
return module_name.find(str_value) != std::string::npos;
|
||||
}
|
||||
|
||||
bool CheckSymbolVNode(const AnfNodePtr &node, const std::string &str_value) {
|
||||
|
|
|
@ -23,53 +23,92 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace irpass {
|
||||
// {prim::kPrimGetAttr, {prim::kPrimTupleGetItem, {prim::kPrimResolve, namespace, symbol}, index}, attr}
|
||||
// {prim::kPrimGetAttr, {prim::kPrimResolve, namespace, symbol}, attr}
|
||||
// {prim::kPrimGetAttr, namespace, attr}
|
||||
// {prim::kPrimGetAttr, bool, attr}
|
||||
// {prim::kPrimResolve, namespace, symbol}
|
||||
AnfNodePtr ResolverGetAttrResolve::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
|
||||
constexpr std::string_view PARSE_SUPER_NAME = "namespace";
|
||||
constexpr size_t namespace_index = 1;
|
||||
constexpr size_t symbol_index = 2;
|
||||
|
||||
PatternNode<AnfNodePtr> resolve_node, ns_node, sym_node, attr_node, bool_node;
|
||||
auto GetAttrResolveLambda = [&node, &resolve_node, &attr_node, &optimizer, &PARSE_SUPER_NAME]() -> AnfNodePtr {
|
||||
auto inner = resolve_node.GetNode(node);
|
||||
AnfNodePtr Resolver::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
|
||||
PatternNode<AnfNodePtr> getattr_operand, ns_node, sym_node, attr_node, bool_node;
|
||||
auto GetAttrResolveLambda = [&node, &getattr_operand, &attr_node, &optimizer]() -> AnfNodePtr {
|
||||
auto getattr_operand_node = getattr_operand.GetNode(node);
|
||||
auto attr = attr_node.GetNode(node);
|
||||
if (IsPrimitiveCNode(inner, prim::kPrimResolve)) {
|
||||
auto resolve_cnode = inner->cast<CNodePtr>();
|
||||
auto namespace_node = resolve_cnode->input(namespace_index);
|
||||
auto symbol_node = resolve_cnode->input(symbol_index);
|
||||
if (!IsValueNode<parse::NameSpace>(namespace_node) || !IsValueNode<parse::Symbol>(symbol_node)) {
|
||||
return nullptr;
|
||||
constexpr auto recursive_level = 3;
|
||||
MS_LOG(DEBUG) << "getattr_operand_node: " << getattr_operand_node->DebugString(recursive_level);
|
||||
|
||||
// {prim::GetAttr, {{prim::Resolve, ..., 'getitem'}, {prim::Resolve, ...}, ...}}
|
||||
auto getitem_cnode = getattr_operand_node->cast<CNodePtr>();
|
||||
if (getitem_cnode != nullptr) {
|
||||
constexpr size_t prim_index = 0;
|
||||
auto primitive_node = getitem_cnode->input(prim_index);
|
||||
auto resolved_getitem_node = primitive_node;
|
||||
if (IsPrimitiveCNode(primitive_node, prim::kPrimResolve)) {
|
||||
auto resolve_getitem_cnode = primitive_node->cast<CNodePtr>();
|
||||
auto resolve_getitem_symbol = GetValueNode<parse::SymbolPtr>(resolve_getitem_cnode->input(2));
|
||||
constexpr auto getitem_symbol = "getitem";
|
||||
if (resolve_getitem_symbol->symbol() == getitem_symbol) {
|
||||
auto resolve_getitem_name_space = GetValueNode<parse::NameSpacePtr>(resolve_getitem_cnode->input(1));
|
||||
resolved_getitem_node =
|
||||
ResolveSymbol(optimizer->manager(), resolve_getitem_name_space, resolve_getitem_symbol, node);
|
||||
}
|
||||
}
|
||||
// deal with the case of getting attr from a class member
|
||||
// and avoid the case of getting attr from self (the result of ParseSuper)
|
||||
auto ns = GetValueNode<parse::NameSpacePtr>(namespace_node);
|
||||
auto sym = GetValueNode<parse::SymbolPtr>(symbol_node);
|
||||
if (ns->module() == parse::RESOLVE_NAMESPACE_NAME_CLASS_MEMBER && sym->symbol() != PARSE_SUPER_NAME) {
|
||||
return parse::ResolveCellwithAttr(optimizer->manager(), ns, sym, inner, attr);
|
||||
bool is_getattr_getitem = false;
|
||||
auto do_signature = dyn_cast<prim::DoSignaturePrimitive>(GetValueNode(resolved_getitem_node));
|
||||
if (do_signature != nullptr) {
|
||||
auto &func_value = do_signature->function();
|
||||
// The function 'func_value' must be the MultitypeFuncGraph of 'getitem'.
|
||||
auto multitype_fg_value = dyn_cast<prim::MultitypeFuncGraph>(func_value);
|
||||
constexpr auto getitem_symbol = "getitem";
|
||||
if (multitype_fg_value != nullptr && multitype_fg_value->name() == getitem_symbol) {
|
||||
is_getattr_getitem = true;
|
||||
}
|
||||
}
|
||||
if (IsPrimitiveCNode(getattr_operand_node, prim::kPrimTupleGetItem)) {
|
||||
is_getattr_getitem = true;
|
||||
}
|
||||
if (is_getattr_getitem) {
|
||||
constexpr size_t resolve_index = 1;
|
||||
auto resolve_node = getitem_cnode->input(resolve_index);
|
||||
if (IsPrimitiveCNode(resolve_node, prim::kPrimResolve)) {
|
||||
constexpr size_t position_index = 2;
|
||||
auto index_node = getitem_cnode->input(position_index);
|
||||
auto [name_space, symbol] = parse::GetNamespaceAndSymbol(resolve_node);
|
||||
auto py_item = parse::GetItemObjectFromSequence(name_space, symbol, resolve_node, index_node);
|
||||
return parse::ResolveCellWithAttr(optimizer->manager(), py_item, resolve_node, attr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// {prim::GetAttr, {prim::Resolve, ...}}
|
||||
if (IsPrimitiveCNode(getattr_operand_node, prim::kPrimResolve)) {
|
||||
auto [name_space, symbol] = parse::GetNamespaceAndSymbol(getattr_operand_node);
|
||||
auto module_name = name_space->module();
|
||||
constexpr std::string_view parse_super_name = "namespace";
|
||||
if (module_name.find(parse::RESOLVE_NAMESPACE_NAME_CLASS_MEMBER) != std::string::npos &&
|
||||
symbol->symbol() != parse_super_name) {
|
||||
auto obj = parse::GetSymbolObject(name_space, symbol, node);
|
||||
return parse::ResolveCellWithAttr(optimizer->manager(), obj, getattr_operand_node, attr);
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
};
|
||||
|
||||
auto GetAttrLambda = [&node, &ns_node, &attr_node, &optimizer]() -> AnfNodePtr {
|
||||
auto ns = GetValueNode<parse::NameSpacePtr>(ns_node.GetNode(node));
|
||||
auto name_space = GetValueNode<parse::NameSpacePtr>(ns_node.GetNode(node));
|
||||
auto str = GetValue<std::string>(GetValueNode(attr_node.GetNode(node)));
|
||||
parse::SymbolPtr sym = std::make_shared<parse::Symbol>(str);
|
||||
return parse::ResolveSymbol(optimizer->manager(), ns, sym, node);
|
||||
parse::SymbolPtr symbol = std::make_shared<parse::Symbol>(str);
|
||||
return parse::ResolveSymbol(optimizer->manager(), name_space, symbol, node);
|
||||
};
|
||||
|
||||
auto ResolveLambda = [&node, &ns_node, &sym_node, &optimizer]() -> AnfNodePtr {
|
||||
auto ns = GetValueNode<parse::NameSpacePtr>(ns_node.GetNode(node));
|
||||
auto sym = GetValueNode<parse::SymbolPtr>(sym_node.GetNode(node));
|
||||
auto name_space = GetValueNode<parse::NameSpacePtr>(ns_node.GetNode(node));
|
||||
auto symbol = GetValueNode<parse::SymbolPtr>(sym_node.GetNode(node));
|
||||
auto manager = optimizer->manager();
|
||||
return parse::ResolveSymbol(manager, ns, sym, node);
|
||||
return parse::ResolveSymbol(manager, name_space, symbol, node);
|
||||
};
|
||||
|
||||
// {prim::kPrimGetAttr, {prim::kPrimResolve, namespace, symbol}, attr}
|
||||
MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimGetAttr, resolve_node, attr_node), GetAttrResolveLambda,
|
||||
MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimGetAttr, getattr_operand, attr_node), GetAttrResolveLambda,
|
||||
attr_node.CheckFunc(IsValueNode<StringImm>, node));
|
||||
// {prim::kPrimGetAttr, namespace, attr}
|
||||
MATCH_REPLACE_LAMBDA_IF(
|
||||
|
|
|
@ -38,11 +38,12 @@ namespace irpass {
|
|||
// pattern. After matching GetAttr pattern, there may be new nodes that can match GetAttr pattern and Resolve pattern.
|
||||
// The same is true for matching Resolve pattern.
|
||||
//
|
||||
// {prim::kPrimGetAttr, {prim::kPrimTupleGetItem, {prim::kPrimResolve, namespace, symbol}, index}, attr}
|
||||
// {prim::kPrimGetAttr, {prim::kPrimResolve, namespace, symbol}, attr}
|
||||
// {prim::kPrimGetAttr, namespace, attr}
|
||||
// {prim::kPrimGetAttr, bool, attr}
|
||||
// {prim::kPrimResolve, namespace, symbol}
|
||||
class ResolverGetAttrResolve : public OptimizerCaller {
|
||||
class Resolver : public OptimizerCaller {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override;
|
||||
};
|
||||
|
|
|
@ -525,7 +525,7 @@ std::vector<DataConverterPtr> GetDataConverters() {
|
|||
bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature, const TypePtr &dtype) {
|
||||
// Check parameter valid
|
||||
if (data == nullptr) {
|
||||
MS_LOG(ERROR) << "Data is null pointer";
|
||||
MS_LOG(ERROR) << "The value pointer should not be null.";
|
||||
return false;
|
||||
}
|
||||
ValuePtr converted = nullptr;
|
||||
|
@ -554,7 +554,7 @@ FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python
|
|||
ValuePtr value = nullptr;
|
||||
bool is_cache = data_converter::GetObjectValue(obj_id, &value);
|
||||
if (is_cache && value != nullptr && value->isa<FuncGraph>()) {
|
||||
MS_LOG(DEBUG) << "Get the cache data, obj = " << obj_id;
|
||||
MS_LOG(DEBUG) << "Get the cache data, obj: " << obj_id;
|
||||
func_graph = value->cast<FuncGraphPtr>();
|
||||
if (!func_graph->dropped()) {
|
||||
return func_graph;
|
||||
|
@ -570,7 +570,7 @@ FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python
|
|||
data_converter::MakeProperNameToFuncGraph(func_graph, obj_id);
|
||||
data_converter::CacheObjectValue(obj_id, func_graph);
|
||||
if (!obj_key.empty()) {
|
||||
MS_LOG(DEBUG) << "Add graph:" << obj_key << ", func_graph:" << func_graph->ToString();
|
||||
MS_LOG(DEBUG) << "Add graph: " << obj_key << ", func_graph: " << func_graph->ToString();
|
||||
data_converter::SetObjGraphValue(obj_key, func_graph);
|
||||
}
|
||||
|
||||
|
@ -584,11 +584,11 @@ static mindspore::HashMap<std::string, std::vector<FuncGraphPtr>> object_graphs_
|
|||
|
||||
void SetObjGraphValue(const std::string &obj_key, const FuncGraphPtr &data) {
|
||||
object_graphs_map_[obj_key].push_back(data);
|
||||
MS_LOG(DEBUG) << "Set func graph size:" << object_graphs_map_.size();
|
||||
MS_LOG(DEBUG) << "Set func graph size: " << object_graphs_map_.size();
|
||||
}
|
||||
|
||||
const mindspore::HashMap<std::string, std::vector<FuncGraphPtr>> &GetObjGraphs() {
|
||||
MS_LOG(DEBUG) << "Obj size:" << object_graphs_map_.size();
|
||||
MS_LOG(DEBUG) << "Obj graphs size: " << object_graphs_map_.size();
|
||||
return object_graphs_map_;
|
||||
}
|
||||
|
||||
|
@ -606,7 +606,7 @@ std::vector<std::string> GetObjKey(const py::object &obj) {
|
|||
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
|
||||
py::tuple obj_tuple = python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_KEY, obj);
|
||||
if (obj_tuple.size() != 2) {
|
||||
MS_LOG(EXCEPTION) << "Get_obj_key must return 2 elements";
|
||||
MS_LOG(EXCEPTION) << "The function of \'get_obj_key()\' must return 2 elements";
|
||||
}
|
||||
return {py::cast<std::string>(obj_tuple[0]), py::cast<std::string>(obj_tuple[1])};
|
||||
}
|
||||
|
@ -619,10 +619,10 @@ ResolveTypeDef GetObjType(const py::object &obj) {
|
|||
ResolveTypeDef(python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_TYPE, obj).cast<int32_t>());
|
||||
return obj_type;
|
||||
} catch (const py::error_already_set &ex) {
|
||||
MS_LOG(ERROR) << "Meet a exception from Python when get the type of `" << py::str(obj) << "`.\n" << ex.what();
|
||||
MS_LOG(ERROR) << "Meet a exception from Python when get the type of \'" << py::str(obj) << "\'.\n" << ex.what();
|
||||
std::rethrow_exception(std::current_exception());
|
||||
} catch (const py::type_error &ex) {
|
||||
MS_LOG(ERROR) << "Meet a exception when get the type of `" << py::str(obj) << "`.\n" << ex.what();
|
||||
MS_LOG(ERROR) << "Meet a exception when get the type of \'" << py::str(obj) << "\'.\n" << ex.what();
|
||||
std::rethrow_exception(std::current_exception());
|
||||
}
|
||||
}
|
||||
|
@ -638,8 +638,8 @@ ClassInstanceTypeDef GetClassInstanceType(const py::object &obj) {
|
|||
// Check the object is Cell Instance.
|
||||
bool IsCellInstance(const py::object &obj) {
|
||||
auto class_type = GetClassInstanceType(obj);
|
||||
bool isCell = (class_type == CLASS_INSTANCE_TYPE_CELL);
|
||||
return isCell;
|
||||
bool is_cell = (class_type == CLASS_INSTANCE_TYPE_CELL);
|
||||
return is_cell;
|
||||
}
|
||||
|
||||
// Create the python class instance.
|
||||
|
|
|
@ -27,12 +27,12 @@
|
|||
namespace py = pybind11;
|
||||
namespace mindspore {
|
||||
namespace parse {
|
||||
// define the node type
|
||||
// Define the node type.
|
||||
enum AstMainType : int64_t {
|
||||
AST_MAIN_TYPE_STMT = 0, // ast.Stmt
|
||||
AST_MAIN_TYPE_EXPR = 1, // ast.Expr
|
||||
AST_MAIN_TYPE_SLICE = 2, // ast.Slice
|
||||
AST_MAIN_TYPE_UNKNOWN = 0xFF // Error
|
||||
AST_MAIN_TYPE_UNKNOWN = 0xFF // Unknown type
|
||||
};
|
||||
|
||||
enum AstSubType : int64_t {
|
||||
|
@ -43,18 +43,18 @@ enum AstSubType : int64_t {
|
|||
AST_SUB_TYPE_SUBSCRIPT = 7, // ast.Subscript
|
||||
AST_SUB_TYPE_STARRED = 8, // ast.Starred
|
||||
AST_SUB_TYPE_ATTRIBUTE = 9, // ast.Attribute
|
||||
AST_SUB_TYPE_UNKNOWN = 0xFF // Error
|
||||
AST_SUB_TYPE_UNKNOWN = 0xFF // Unknown type
|
||||
};
|
||||
|
||||
// define the parse target type
|
||||
// Define the parse target type.
|
||||
enum ParseTargetTypeDef {
|
||||
PARSE_TARGET_FUNCTION = 0, // function
|
||||
PARSE_TARGET_METHOD = 1, // method
|
||||
PARSE_TARGET_OBJECT_INSTANCE = 2, // object instance
|
||||
PARSE_TARGET_UNKNOW = 0xFF // ERROR TYPE
|
||||
PARSE_TARGET_FUNCTION = 0, // Function
|
||||
PARSE_TARGET_METHOD = 1, // Method
|
||||
PARSE_TARGET_OBJECT_INSTANCE = 2, // Object instance
|
||||
PARSE_TARGET_UNKNOW = 0xFF // Unknown type
|
||||
};
|
||||
|
||||
// define python module name
|
||||
// Define python module name.
|
||||
const char PYTHON_MOD_PARSE_MODULE[] = "mindspore._extends.parse";
|
||||
const char PYTHON_MOD_PARSE_OBJECT_FUNCTION[] = "parse_cb";
|
||||
const char PYTHON_MOD_RESOLVE_FUNCTION[] = "resolve_symbol";
|
||||
|
@ -72,6 +72,7 @@ const char PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL[] = "get_class_member_namespac
|
|||
const char PYTHON_MOD_GET_PARSE_METHOD[] = "get_parse_method_of_class";
|
||||
const char PYTHON_MOD_GET_BPROP_METHOD[] = "get_bprop_method_of_class";
|
||||
const char PYTHON_MOD_GET_OBJECT_DESCRIPTION[] = "get_object_description";
|
||||
const char PYTHON_MOD_GET_ITEM_FROM_SEQUENCE[] = "get_obj_from_sequence";
|
||||
const char PYTHON_MOD_CONVERT_TO_MS_TENSOR[] = "convert_to_ms_tensor";
|
||||
const char PYTHON_MOD_EVAL_PY_SCRIPT[] = "eval_script";
|
||||
|
||||
|
@ -92,7 +93,7 @@ const char PYTHON_PARSE_ANALYZE_SUPER[] = "analyze_super";
|
|||
const char PYTHON_PARSE_CLASS_SLICE[] = "create_slice_obj";
|
||||
const char PYTHON_PARSE_CLASS_ELLIPSIS[] = "create_ellipsis_obj";
|
||||
|
||||
// define the common name
|
||||
// Define the common name.
|
||||
const char NAMED_PRIMITIVE_LEN[] = "len";
|
||||
const char NAMED_PRIMITIVE_BODY[] = "body";
|
||||
const char NAMED_PRIMITIVE_ASSIGN[] = "Assign";
|
||||
|
@ -127,8 +128,8 @@ const char NAMED_PRIMITIVE_MAKESLICE[] = "make_slice";
|
|||
const char NAMED_PRIMITIVE_MAKEDICT[] = "make_dict";
|
||||
const char NAMED_METAGRAPH_UNPACKCALL[] = "unpack_call";
|
||||
|
||||
// define NAMED_PRIMITIVE_GETATTR "getattr"
|
||||
// define python inline attr
|
||||
// Define NAMED_PRIMITIVE_GETATTR "getattr".
|
||||
// Define python inline attr.
|
||||
const char PYTHON_GET_METHOD_LEN[] = "__len__";
|
||||
const char PYTHON_GET_METHOD_SELF_CLASS[] = "__self__";
|
||||
const char PYTHON_GET_OBJ_DESC[] = "__str__";
|
||||
|
@ -136,46 +137,46 @@ const char PYTHON_GET_OBJ_DESC[] = "__str__";
|
|||
const char PYTHON_EXTERN_PARSE_METHOD[] = "__parse_method__";
|
||||
const char PYTHON_EXTERN_MINDSPORE_FLAG[] = "_mindspore_flags";
|
||||
|
||||
// define the parse constant
|
||||
// Define the parse constant.
|
||||
const int64_t MAX_COMPARISON_OPS_SUPPORTED = 1;
|
||||
const char CUSTOM_BPROP_NAME[] = "bprop";
|
||||
const char STAGE_NAME[] = "_pipeline_stage";
|
||||
|
||||
// define the Namespace name
|
||||
const char RESOLVE_NAMESPACE_NAME_AST[] = "Ast"; // for ast type namespace
|
||||
const char RESOLVE_NAMESPACE_NAME_CLASS_MEMBER[] = "ClassMember"; // for class member namespace
|
||||
const char RESOLVE_NAMESPACE_NAME_SYMBOL_STR[] = "SymbolStr"; // for symbol str namespace
|
||||
const char RESOLVE_NAMESPACE_NAME_COMMON_OPS[] = "CommonOPS"; // for common ops, eg: hasnext, next
|
||||
const char RESOLVE_NAMESPACE_NAME_MODULE[] = "Module"; // fro Module namespace
|
||||
// Define the Namespace name.
|
||||
const char RESOLVE_NAMESPACE_NAME_AST[] = "Ast"; // For ast type namespace.
|
||||
const char RESOLVE_NAMESPACE_NAME_CLASS_MEMBER[] = "ClassMember"; // For class member namespace.
|
||||
const char RESOLVE_NAMESPACE_NAME_SYMBOL_STR[] = "SymbolStr"; // For symbol str namespace.
|
||||
const char RESOLVE_NAMESPACE_NAME_COMMON_OPS[] = "CommonOPS"; // For common ops, eg: hasnext, next.
|
||||
const char RESOLVE_NAMESPACE_NAME_MODULE[] = "Module"; // For Module namespace.
|
||||
|
||||
// define Resolve type
|
||||
// Define Resolve type.
|
||||
enum ResolveTypeDef : int64_t {
|
||||
RESOLVE_TYPE_NONE = 0, // resolve None
|
||||
RESOLVE_TYPE_FUNCTION = 1, // resolve function
|
||||
RESOLVE_TYPE_METHOD = 2, // resolve class method
|
||||
RESOLVE_TYPE_CLASS_TYPE = 3, // resolve class type
|
||||
RESOLVE_TYPE_CLASS_INSTANCE = 4, // resolve the class instance of common class
|
||||
RESOLVE_TYPE_INVALID = 0xFF // resolve invalid
|
||||
RESOLVE_TYPE_NONE = 0, // Resolve None
|
||||
RESOLVE_TYPE_FUNCTION = 1, // Resolve function
|
||||
RESOLVE_TYPE_METHOD = 2, // Resolve class method
|
||||
RESOLVE_TYPE_CLASS_TYPE = 3, // Resolve class type
|
||||
RESOLVE_TYPE_CLASS_INSTANCE = 4, // Resolve the class instance of common class
|
||||
RESOLVE_TYPE_INVALID = 0xFF // Resolve invalid
|
||||
};
|
||||
|
||||
// define the class instance detail type When the type is RESOLVE_TYPE_CLASS_INSTANCE
|
||||
// Define the class instance detail type When the type is RESOLVE_TYPE_CLASS_INSTANCE.
|
||||
enum ClassInstanceTypeDef {
|
||||
CLASS_INSTANCE_TYPE_CELL = 0, // class instance type is Cell
|
||||
CLASS_INSTANCE_TYPE_PRIMITIVE = 1, // class instance type is Primitive
|
||||
CLASS_INSTANCE_TYPE_CELL = 0, // Class instance type is Cell.
|
||||
CLASS_INSTANCE_TYPE_PRIMITIVE = 1, // Class instance type is Primitive.
|
||||
CLASS_INSTANCE_TYPE_INVALID = 0xFF
|
||||
};
|
||||
|
||||
// Convert python object to ValuePtr
|
||||
// Convert python object to ValuePtr.
|
||||
bool ConvertData(const py::object &obj, ValuePtr *data, bool use_signature = false, const TypePtr &dtype = nullptr);
|
||||
|
||||
// Convert python obj to graph
|
||||
// Convert python obj to graph.
|
||||
FuncGraphPtr ConvertToFuncGraph(const py::object &obj,
|
||||
const std::string &python_mod_get_parse_method = PYTHON_MOD_GET_PARSE_METHOD);
|
||||
|
||||
// Parse the python object to graph
|
||||
// Parse the python object to graph.
|
||||
FuncGraphPtr ParsePythonCode(const py::object &obj,
|
||||
const std::string &python_mod_get_parse_method = PYTHON_MOD_GET_PARSE_METHOD);
|
||||
// add wrap for cell top graph.
|
||||
// Add wrap for cell top graph.
|
||||
FuncGraphPtr MakeTopGraph(const py::object &cell, const ValuePtr &cell_ptr);
|
||||
} // namespace parse
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,8 +16,9 @@
|
|||
|
||||
#include "pipeline/jit/parse/resolve.h"
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "ir/param_info.h"
|
||||
|
@ -310,25 +311,12 @@ AnfNodePtr ResolveObjectAndAddToManager(const FuncGraphManagerPtr &manager, cons
|
|||
}
|
||||
} // namespace
|
||||
|
||||
AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol,
|
||||
const AnfNodePtr &node) {
|
||||
// Get python object with index from a list.
|
||||
py::object GetItemObjectFromSequence(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node,
|
||||
const AnfNodePtr &index_node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info()));
|
||||
if (node->func_graph() == nullptr || manager == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph or manager is nullptr";
|
||||
}
|
||||
SymbolResolver symbol_resolver(name_space, symbol, node);
|
||||
symbol_resolver.Resolve();
|
||||
py::object obj = symbol_resolver.result();
|
||||
AnfNodePtr resolved_node = ResolveObjectAndAddToManager(manager, obj, node);
|
||||
return resolved_node;
|
||||
}
|
||||
|
||||
AnfNodePtr ResolveCellwithAttr(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space,
|
||||
const SymbolPtr &symbol, const AnfNodePtr &node, const AnfNodePtr &attr) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info()));
|
||||
if (node->func_graph() == nullptr || manager == nullptr) {
|
||||
if (node->func_graph() == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph or manager is nullptr";
|
||||
}
|
||||
SymbolResolver symbol_resolver(name_space, symbol, node);
|
||||
|
@ -337,6 +325,75 @@ AnfNodePtr ResolveCellwithAttr(const FuncGraphManagerPtr &manager, const NameSpa
|
|||
}
|
||||
|
||||
py::object obj = symbol_resolver.result();
|
||||
if (!py::isinstance<py::list>(obj) && !py::isinstance<py::tuple>(obj)) {
|
||||
MS_LOG(EXCEPTION) << "Should not get item from non-sequence type, obj: " << py::str(obj);
|
||||
}
|
||||
|
||||
const std::string fn = PYTHON_MOD_GET_ITEM_FROM_SEQUENCE;
|
||||
const std::string module = "mindspore._extends.parse.parser";
|
||||
int index = GetValueNode<Int64ImmPtr>(index_node)->value();
|
||||
MS_LOG(DEBUG) << "obj: " << py::str(obj) << ", index: " << index;
|
||||
py::object item_obj = parse::python_adapter::GetPyFn(module, fn)(obj, py::int_(index));
|
||||
return item_obj;
|
||||
}
|
||||
|
||||
std::pair<parse::NameSpacePtr, parse::SymbolPtr> GetNamespaceAndSymbol(const AnfNodePtr &node) {
|
||||
if (IsPrimitiveCNode(node, prim::kPrimResolve)) {
|
||||
auto resolve_cnode = node->cast<CNodePtr>();
|
||||
constexpr size_t namespace_index = 1;
|
||||
auto namespace_node = resolve_cnode->input(namespace_index);
|
||||
constexpr size_t symbol_index = 2;
|
||||
auto symbol_node = resolve_cnode->input(symbol_index);
|
||||
if (!IsValueNode<parse::NameSpace>(namespace_node) || !IsValueNode<parse::Symbol>(symbol_node)) {
|
||||
MS_LOG(EXCEPTION) << "Unexpected type, namespace: " << namespace_node->ToString()
|
||||
<< ", symbol: " << symbol_node->ToString();
|
||||
}
|
||||
// Deal with the case of GetAttr from a class member,
|
||||
// and avoid the case of GetAttr from self (the result of ParseSuper).
|
||||
auto name_space = GetValueNode<parse::NameSpacePtr>(namespace_node);
|
||||
auto symbol = GetValueNode<parse::SymbolPtr>(symbol_node);
|
||||
return {name_space, symbol};
|
||||
}
|
||||
constexpr auto recursive_level = 2;
|
||||
MS_LOG(EXCEPTION) << "It's not prim::Resolve CNode, node: " << node->DebugString(recursive_level);
|
||||
}
|
||||
|
||||
py::object GetSymbolObject(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node->func_graph() == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph is nullptr.";
|
||||
}
|
||||
SymbolResolver symbol_resolver(name_space, symbol, node);
|
||||
symbol_resolver.Resolve();
|
||||
if (!symbol_resolver.Resolve()) {
|
||||
MS_LOG(EXCEPTION) << "Fail to resolve node, NodeInfo.";
|
||||
}
|
||||
py::object obj = symbol_resolver.result();
|
||||
return obj;
|
||||
}
|
||||
|
||||
AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol,
|
||||
const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (manager == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Manager is nullptr.";
|
||||
}
|
||||
TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info()));
|
||||
auto obj = GetSymbolObject(name_space, symbol, node);
|
||||
AnfNodePtr resolved_node = ResolveObjectAndAddToManager(manager, obj, node);
|
||||
return resolved_node;
|
||||
}
|
||||
|
||||
// Resolve Cell GetAttr operation.
|
||||
AnfNodePtr ResolveCellWithAttr(const FuncGraphManagerPtr &manager, const py::object &obj, const AnfNodePtr &node,
|
||||
const AnfNodePtr &attr) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(attr);
|
||||
if (manager == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Manager is nullptr.";
|
||||
}
|
||||
MS_LOG(DEBUG) << "obj: " << py::str(obj) << ", attr: " << attr->ToString();
|
||||
TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info()));
|
||||
if (!data_converter::IsCellInstance(obj)) {
|
||||
AnfNodePtr resolved_node = ResolveObjectAndAddToManager(manager, obj, node);
|
||||
AnfNodePtrList inputs = {NewValueNode(prim::kPrimGetAttr), resolved_node, attr};
|
||||
|
@ -365,7 +422,7 @@ opt::OptPassGroupMap GetOptResolvePasses(const opt::irpass::ResolveIRPassLib &ir
|
|||
opt::OptPassGroupMap map({
|
||||
{"resolve",
|
||||
{
|
||||
irpass.resolver_getattr_resolve_,
|
||||
irpass.resolver_,
|
||||
}},
|
||||
});
|
||||
return map;
|
||||
|
|
|
@ -17,8 +17,10 @@
|
|||
#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_RESOLVE_H_
|
||||
#define MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_RESOLVE_H_
|
||||
|
||||
#include <utility>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "ir/anf.h"
|
||||
#include "ir/manager.h"
|
||||
#include "pipeline/jit/parse/python_adapter.h"
|
||||
|
@ -40,7 +42,10 @@ namespace parse {
|
|||
class NameSpace final : public Named {
|
||||
public:
|
||||
NameSpace(const std::string &module, const py::object &obj, const py::object &module_obj = py::object())
|
||||
: Named(module), module_(module), obj_(obj), module_obj_(module_obj) {}
|
||||
: Named(module + ": \'" + std::string(py::str(obj)) + "\'"),
|
||||
module_(module),
|
||||
obj_(obj),
|
||||
module_obj_(module_obj) {}
|
||||
~NameSpace() override = default;
|
||||
MS_DECLARE_PARENT(NameSpace, Named);
|
||||
|
||||
|
@ -92,7 +97,7 @@ class Script final : public Named {
|
|||
abstract::AbstractBasePtr ToAbstract() override {
|
||||
return std::make_shared<abstract::AbstractScript>(shared_from_base<Script>());
|
||||
}
|
||||
std::string ToString() const override { return "`" + name() + "`"; }
|
||||
std::string ToString() const override { return "\'" + name() + "\'"; }
|
||||
|
||||
private:
|
||||
std::string script_;
|
||||
|
@ -116,7 +121,7 @@ class PyObjectWrapper : public Named {
|
|||
class InterpretedObject final : public PyObjectWrapper {
|
||||
public:
|
||||
explicit InterpretedObject(const py::object &obj, const std::string &name = "null")
|
||||
: PyObjectWrapper(obj, "InterpretedObject: '" + name + "'") {}
|
||||
: PyObjectWrapper(obj, "InterpretedObject: \'" + name + "\'") {}
|
||||
~InterpretedObject() override = default;
|
||||
MS_DECLARE_PARENT(InterpretedObject, PyObjectWrapper);
|
||||
abstract::AbstractBasePtr ToAbstract() override {
|
||||
|
@ -158,14 +163,10 @@ class SymbolResolver {
|
|||
// resolve symbol in namespace and save it in result_;
|
||||
bool Resolve();
|
||||
|
||||
NameSpacePtr get_namespace() { return namespace_; }
|
||||
|
||||
SymbolPtr symbol() { return symbol_; }
|
||||
|
||||
const py::object &result() { return result_; }
|
||||
|
||||
AnfNodePtr resolved_node() { return resolved_node_; }
|
||||
|
||||
private:
|
||||
// namespace where the symbol locates
|
||||
NameSpacePtr namespace_;
|
||||
|
@ -177,13 +178,20 @@ class SymbolResolver {
|
|||
py::object result_;
|
||||
};
|
||||
using SymbolResolverPtr = std::shared_ptr<SymbolResolver>;
|
||||
|
||||
// Get python object with index from a list.
|
||||
py::object GetItemObjectFromSequence(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node,
|
||||
const AnfNodePtr &index_node);
|
||||
std::pair<parse::NameSpacePtr, parse::SymbolPtr> GetNamespaceAndSymbol(const AnfNodePtr &node);
|
||||
|
||||
// Get resolved python object by namespace and symbol.
|
||||
py::object GetSymbolObject(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node);
|
||||
// Resolve symbol in namespace.
|
||||
AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol,
|
||||
const AnfNodePtr &node);
|
||||
|
||||
// Resolve Cell with attr name.
|
||||
AnfNodePtr ResolveCellwithAttr(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space,
|
||||
const SymbolPtr &symbol, const AnfNodePtr &node, const AnfNodePtr &attr);
|
||||
AnfNodePtr ResolveCellWithAttr(const FuncGraphManagerPtr &manager, const py::object &obj, const AnfNodePtr &node,
|
||||
const AnfNodePtr &attr);
|
||||
|
||||
// Resolve one graph which normally is the root graph. FuncGraph shall be managed by res->manager().
|
||||
bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &res, bool use_profile = true);
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test a list of cell, and getattr by its item """
|
||||
from mindspore import context, nn, dtype, Tensor
|
||||
|
||||
|
||||
class Actor(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Actor, self).__init__()
|
||||
|
||||
def act(self, x, y):
|
||||
return x + y
|
||||
|
||||
|
||||
class Trainer(nn.Cell):
|
||||
def __init__(self, net_list):
|
||||
super(Trainer, self).__init__()
|
||||
self.net_list = net_list
|
||||
|
||||
def construct(self, x, y):
|
||||
return self.net_list[0].act(x, y)
|
||||
|
||||
|
||||
def test_list_item_getattr():
|
||||
"""
|
||||
Feature: getattr by the item from list of cell.
|
||||
Description: Support RL use method in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
actor_list = [Actor()]
|
||||
trainer = Trainer(actor_list)
|
||||
x = Tensor([3], dtype=dtype.float32)
|
||||
y = Tensor([6], dtype=dtype.float32)
|
||||
print(trainer(x, y))
|
Loading…
Reference in New Issue