forked from mindspore-Ecosystem/mindspore
!3237 support call the parent class function
Merge pull request !3237 from zhangbuxue/support_call_the_parent_class_construct_function
This commit is contained in:
commit
8da91ca3cf
|
@ -99,12 +99,19 @@ class ClassMemberNamespace(Namespace):
|
|||
obj (Object): A python class object.
|
||||
"""
|
||||
def __init__(self, obj):
|
||||
self.__class_member_namespace__ = True
|
||||
label = f'{obj.__module__}..<{obj.__class__.__name__}::{id(obj)}>'
|
||||
super().__init__(label, obj)
|
||||
|
||||
def __getitem__(self, name):
|
||||
d, = self.dicts
|
||||
if name == "self":
|
||||
return d
|
||||
if name == "namespace":
|
||||
return self
|
||||
try:
|
||||
return getattr(d, name)
|
||||
if hasattr(d, name):
|
||||
return getattr(d, name)
|
||||
return d.__dict__[name]
|
||||
except ValueError:
|
||||
raise UnboundLocalError(name)
|
||||
|
|
|
@ -70,6 +70,7 @@ parse_expr_statement_white_list = (
|
|||
"append",
|
||||
)
|
||||
|
||||
|
||||
def create_slice_obj(start, end, step):
|
||||
"""Create slice object"""
|
||||
return slice(start, end, step)
|
||||
|
@ -201,9 +202,10 @@ def get_object_key(obj):
|
|||
if isinstance(obj, types.MethodType):
|
||||
method_instance = obj.__self__
|
||||
instance_id = "%s_ID%d" % (str(method_instance.__class__.__name__), id(method_instance))
|
||||
obj_id = instance_id + obj_id
|
||||
obj_id = instance_id + obj_id + str(obj.__hash__())
|
||||
return obj_id, obj_key
|
||||
|
||||
|
||||
def get_default_input(obj):
|
||||
if hasattr(obj, '__parameter__'):
|
||||
return obj.default_input
|
||||
|
@ -213,6 +215,7 @@ def get_default_input(obj):
|
|||
return args
|
||||
return obj
|
||||
|
||||
|
||||
def is_class_member(node):
|
||||
"""Check the attr is class member variable."""
|
||||
type_ = node.__class__.__name__
|
||||
|
@ -224,10 +227,12 @@ def is_class_member(node):
|
|||
return True
|
||||
return False
|
||||
|
||||
|
||||
def get_obj_id(obj):
|
||||
"""Get the obj id."""
|
||||
return str(id(obj))
|
||||
|
||||
|
||||
def get_obj_type(obj):
|
||||
"""Get the obj type."""
|
||||
obj_type = RESOLVE_TYPE_INVALID
|
||||
|
@ -320,6 +325,7 @@ def get_dataclass_methods(cls):
|
|||
if isinstance(getattr(cls, name), (types.FunctionType,))}
|
||||
return methods
|
||||
|
||||
|
||||
class Parser:
|
||||
"""
|
||||
Parser python code to ast tree.
|
||||
|
@ -453,6 +459,28 @@ class Parser:
|
|||
logger.debug("ops info = %r", ops_info)
|
||||
return ops_info
|
||||
|
||||
def analyze_super(self, father_class_node, subclass_instance):
|
||||
"""Analyze super and return a class instance."""
|
||||
father_class = None
|
||||
if father_class_node is None:
|
||||
father_class = type(subclass_instance)
|
||||
if isinstance(father_class_node, ast.Name):
|
||||
father_class_name = getattr(father_class_node, 'id')
|
||||
father_class = self.global_namespace[father_class_name]
|
||||
if isinstance(father_class_node, ast.Attribute):
|
||||
value = getattr(father_class_node, 'value')
|
||||
attr = getattr(father_class_node, 'attr')
|
||||
module_name = getattr(value, 'id')
|
||||
father_class_module = self.global_namespace[module_name]
|
||||
father_class = getattr(father_class_module, attr)
|
||||
if father_class is None:
|
||||
raise ValueError("When call 'super', the father class is None.")
|
||||
if not isinstance(subclass_instance, father_class):
|
||||
raise ValueError("When call 'super', the second arg should be an instance of first arg.")
|
||||
|
||||
target_class_instance = super(father_class, subclass_instance)
|
||||
return target_class_instance
|
||||
|
||||
def get_location(self, node):
|
||||
"""
|
||||
Get location of node start and end line no.
|
||||
|
|
|
@ -117,6 +117,7 @@ convert_object_map = {
|
|||
T.zip: C.zip_operation,
|
||||
T.print: F.print_,
|
||||
T.enumerate: M.enumerate_,
|
||||
T.isinstance: M.isinstance_,
|
||||
|
||||
# custom define operation
|
||||
T.iter: M.ms_iter,
|
||||
|
|
|
@ -114,6 +114,12 @@ def enumerate_(x, start=0):
|
|||
return ret
|
||||
|
||||
|
||||
def isinstance_(x, base_type):
|
||||
"""Determine whether x is an instance of base_type."""
|
||||
x_type = F.typeof(x)
|
||||
return check_type_same(x_type, base_type)
|
||||
|
||||
|
||||
def while_cond(x):
|
||||
"""For while condtion, if the condition is a tensor, the loop will not be unrolled"""
|
||||
if F.issubclass_(F.typeof(x), F.typeof(mstype.tensor)):
|
||||
|
@ -123,6 +129,12 @@ def while_cond(x):
|
|||
return x
|
||||
|
||||
|
||||
@constexpr
|
||||
def check_type_same(x_type, base_type):
|
||||
"""Check x_type is same as base_type."""
|
||||
return mstype.issubclass_(x_type, base_type)
|
||||
|
||||
|
||||
@constexpr
|
||||
def check_is_tuple_or_list(x, op_name, arg_name):
|
||||
"""check whether x is list or tuple."""
|
||||
|
@ -141,6 +153,7 @@ def check_is_const_int(x, op_name, arg_name):
|
|||
return True
|
||||
|
||||
|
||||
|
||||
@constexpr
|
||||
def check_is_tensor_bool_cond(shp):
|
||||
"""check if tensor is a bool condition"""
|
||||
|
@ -148,6 +161,7 @@ def check_is_tensor_bool_cond(shp):
|
|||
return True
|
||||
raise ValueError("tensor as bool condition, its shape should be () or (1,), but got ", shp)
|
||||
|
||||
|
||||
@constexpr
|
||||
def const_tensor_to_bool(x):
|
||||
"""convert bool tensor to bool condition"""
|
||||
|
@ -162,6 +176,7 @@ def const_tensor_to_bool(x):
|
|||
value = bool(x[0])
|
||||
return value
|
||||
|
||||
|
||||
def tensor_bool(x):
|
||||
"""tensor as conditon, if is constant, return immediate bool value"""
|
||||
is_cond = check_is_tensor_bool_cond(F.shape(x))
|
||||
|
|
|
@ -27,7 +27,7 @@ from operator import ( # noqa
|
|||
|
||||
# support system function call
|
||||
from builtins import ( # noqa
|
||||
bool, getattr, setattr, len, iter, next, pow, range, map, zip, print, enumerate
|
||||
bool, getattr, setattr, len, iter, next, pow, range, map, zip, print, enumerate, isinstance
|
||||
)
|
||||
|
||||
# support functools
|
||||
|
@ -44,7 +44,7 @@ __all__ = ['add', 'sub', 'mul', 'truediv', 'floordiv', 'mod', 'eq', 'ne', 'lt',
|
|||
'not_', 'and_', 'or_', 'xor', 'lshift', 'rshift', 'invert', 'is_', 'is_not', 'contains',
|
||||
'matmul', 'getitem', 'setitem',
|
||||
'bool', 'getattr', 'setattr', 'len', 'iter', 'next', 'pow', 'range', 'map', 'zip',
|
||||
'partial', 'print', 'enumerate',
|
||||
'partial', 'print', 'enumerate', 'isinstance',
|
||||
'exp', 'log', 'sin', 'cos', 'tan']
|
||||
|
||||
|
||||
|
|
|
@ -370,6 +370,8 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature
|
|||
} else if (py::hasattr(obj, PYTHON_ENVINSTANCE_FLAG)) {
|
||||
std::shared_ptr<EnvInstance> env = obj.cast<std::shared_ptr<EnvInstance>>();
|
||||
converted = env;
|
||||
} else if (py::hasattr(obj, PYTHON_CLASS_MEMBER_NAMESPACE)) {
|
||||
converted = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, obj);
|
||||
} else if (py::hasattr(obj, "__parameter__")) {
|
||||
auto to_convert = py::cast<py::object>(python_adapter::GetPyObjAttr(obj, "default_input"));
|
||||
ret = ConvertData(to_convert, &converted);
|
||||
|
|
|
@ -109,7 +109,7 @@ AnfNodePtr FunctionBlock::MakeResolveClassMember(std::string attr) {
|
|||
|
||||
// Make a resolve node for symbol string
|
||||
AnfNodePtr FunctionBlock::MakeResolveSymbol(const std::string &value) {
|
||||
if (value.compare(0, strlen("self."), "self.") == 0) {
|
||||
if (value.compare(0, strlen("self"), "self") == 0) {
|
||||
auto start = value.find_first_of('.') + 1;
|
||||
if (start >= value.size()) {
|
||||
MS_LOG(ERROR) << "Find invalid resolve symbol str: " << value;
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include <sstream>
|
||||
#include <unordered_map>
|
||||
#include <algorithm>
|
||||
#include "pipeline/jit/parse/resolve.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "pipeline/jit/parse/data_converter.h"
|
||||
#include "frontend/operator/composite/composite.h"
|
||||
|
@ -504,14 +505,45 @@ AnfNodePtr Parser::GenerateMakeTuple(const FunctionBlockPtr &block, const std::v
|
|||
[](AnfNodePtr arg) -> AnfNodePtr { return arg; });
|
||||
return block->func_graph()->NewCNode(make_tuple_nodes);
|
||||
}
|
||||
|
||||
AnfNodePtr Parser::ParseSuper(const FunctionBlockPtr &block, const py::list &args) {
|
||||
py::object father_class;
|
||||
if (args.empty()) {
|
||||
father_class = py::none();
|
||||
} else if (args.size() == 2) {
|
||||
father_class = args[0];
|
||||
auto arg_type = AstSubType(py::cast<int32_t>(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, args[1])));
|
||||
if (arg_type != AST_SUB_TYPE_NAME || py::cast<std::string>(python_adapter::GetPyObjAttr(args[1], "id")) != "self") {
|
||||
MS_EXCEPTION(ArgumentError) << "When call 'super', the second arg should be 'self'.";
|
||||
}
|
||||
} else {
|
||||
MS_EXCEPTION(ArgumentError) << "When call 'super', the args number should be 0 or 2, but got" << args.size() << ".";
|
||||
}
|
||||
py::object target_class_instance = ast()->CallParserObjMethod(PYTHON_PARSE_ANALYZE_SUPER, father_class, ast()->obj());
|
||||
py::object namespace_var = ast_->CallParseModFunction(PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, target_class_instance);
|
||||
NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
|
||||
SymbolPtr symbol = std::make_shared<Symbol>("namespace");
|
||||
return block->MakeResolve(name_space, symbol);
|
||||
}
|
||||
|
||||
// process function call, eg : f1(x, y) ...
|
||||
AnfNodePtr Parser::ParseCall(const FunctionBlockPtr &block, const py::object &node) {
|
||||
MS_LOG(DEBUG) << "Process ast Call";
|
||||
// process function call
|
||||
py::object function_ast_node = python_adapter::GetPyObjAttr(node, "func");
|
||||
py::list args = python_adapter::GetPyObjAttr(node, "args");
|
||||
|
||||
auto arg_type =
|
||||
AstSubType(py::cast<int32_t>(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, function_ast_node)));
|
||||
if (arg_type == AST_SUB_TYPE_NAME) {
|
||||
auto name_id = py::cast<std::string>(python_adapter::GetPyObjAttr(function_ast_node, "id"));
|
||||
if (name_id == "super") {
|
||||
return ParseSuper(block, args);
|
||||
}
|
||||
}
|
||||
|
||||
AnfNodePtr call_function_anf_node = ParseExprNode(block, function_ast_node);
|
||||
// function call arguments should be passed in as groups and unpacked later using unpack call
|
||||
py::list args = python_adapter::GetPyObjAttr(node, "args");
|
||||
std::vector<AnfNodePtr> packed_arguments;
|
||||
std::vector<AnfNodePtr> group_arguments;
|
||||
|
||||
|
|
|
@ -138,6 +138,8 @@ class Parser {
|
|||
AnfNodePtr ParseNameConstant(const FunctionBlockPtr &block, const py::object &node);
|
||||
// process a function call
|
||||
AnfNodePtr ParseCall(const FunctionBlockPtr &block, const py::object &node);
|
||||
// process function 'super'
|
||||
AnfNodePtr ParseSuper(const FunctionBlockPtr &block, const py::list &args);
|
||||
// process the if expression
|
||||
AnfNodePtr ParseIfExp(const FunctionBlockPtr &block, const py::object &node);
|
||||
// process class type define
|
||||
|
|
|
@ -81,6 +81,7 @@ const char PYTHON_PARSE_GET_LOCATION[] = "get_location";
|
|||
const char PYTHON_PARSE_EXPAND_EXPR_STATEMENT[] = "expand_expr_statement";
|
||||
const char PYTHON_PARSE_GENERATE_SCOPE[] = "generate_scope";
|
||||
const char PYTHON_PARSE_GET_SCOPE_NAME[] = "get_scope_name";
|
||||
const char PYTHON_PARSE_ANALYZE_SUPER[] = "analyze_super";
|
||||
|
||||
const char PYTHON_PARSE_CLASS_SLICE[] = "create_slice_obj";
|
||||
const char PYTHON_PARSE_CLASS_ELLIPSIS[] = "create_ellipsis_obj";
|
||||
|
|
|
@ -80,7 +80,7 @@ using SymbolPtr = std::shared_ptr<Symbol>;
|
|||
// PyObjectWrapper class wrappers resolved python object for further processing.
|
||||
class PyObjectWrapper : public Named {
|
||||
public:
|
||||
explicit PyObjectWrapper(const py::object &obj, const std::string name = "Python object") : Named(name), obj_(obj) {}
|
||||
explicit PyObjectWrapper(const py::object &obj, const std::string &name = "Python object") : Named(name), obj_(obj) {}
|
||||
~PyObjectWrapper() override = default;
|
||||
MS_DECLARE_PARENT(PyObjectWrapper, Named);
|
||||
py::object obj() { return obj_; }
|
||||
|
@ -93,7 +93,7 @@ class PyObjectWrapper : public Named {
|
|||
// ClassObject class wrappers dataclass
|
||||
class ClassObject : public PyObjectWrapper {
|
||||
public:
|
||||
explicit ClassObject(const py::object &obj, const std::string name = "Python dataclass")
|
||||
explicit ClassObject(const py::object &obj, const std::string &name = "Python dataclass")
|
||||
: PyObjectWrapper(obj, name) {}
|
||||
~ClassObject() override = default;
|
||||
MS_DECLARE_PARENT(ClassObject, PyObjectWrapper);
|
||||
|
@ -103,7 +103,7 @@ class ClassObject : public PyObjectWrapper {
|
|||
// ClassType class wrappers class name in python
|
||||
class ClassType : public PyObjectWrapper {
|
||||
public:
|
||||
explicit ClassType(const py::object &obj, const std::string name = "Python class type")
|
||||
explicit ClassType(const py::object &obj, const std::string &name = "Python class type")
|
||||
: PyObjectWrapper(obj, name) {}
|
||||
~ClassType() override = default;
|
||||
MS_DECLARE_PARENT(ClassType, PyObjectWrapper);
|
||||
|
|
|
@ -25,6 +25,7 @@ const char PYTHON_ENVINSTANCE_FLAG[] = "__envinstance_flag__";
|
|||
const char PYTHON_DTYPE_FLAG[] = "__dtype_flag__";
|
||||
const char PYTHON_CELL_AS_LIST[] = "__cell_as_list__";
|
||||
const char PYTHON_DATACLASS_FIELDS[] = "__dataclass_fields__";
|
||||
const char PYTHON_CLASS_MEMBER_NAMESPACE[] = "__class_member_namespace__";
|
||||
|
||||
// flag names
|
||||
const char GRAPH_FLAG_MIX_PRECISION_FP16[] = "fp16";
|
||||
|
|
|
@ -27,6 +27,7 @@ extern const char PYTHON_ENVINSTANCE_FLAG[];
|
|||
extern const char PYTHON_DTYPE_FLAG[];
|
||||
extern const char PYTHON_CELL_AS_LIST[];
|
||||
extern const char PYTHON_DATACLASS_FIELDS[];
|
||||
extern const char PYTHON_CLASS_MEMBER_NAMESPACE[];
|
||||
|
||||
extern const char GRAPH_FLAG_MIX_PRECISION_FP16[];
|
||||
extern const char GRAPH_FLAG_MIX_PRECISION_FP32[];
|
||||
|
|
|
@ -62,6 +62,7 @@ def _wrap_func(fn):
|
|||
Returns:
|
||||
Function, a new function with return suitable format data.
|
||||
"""
|
||||
|
||||
@wraps(fn)
|
||||
def wrapper(*arg, **kwargs):
|
||||
results = fn(*arg, **kwargs)
|
||||
|
@ -74,6 +75,7 @@ def _wrap_func(fn):
|
|||
if isinstance(data, list):
|
||||
return list(_convert_data(x) for x in data)
|
||||
return data
|
||||
|
||||
return _convert_data(results)
|
||||
|
||||
return wrapper
|
||||
|
@ -106,6 +108,7 @@ class _MindSporeFunction:
|
|||
obj (Object): If function is a method, obj is the owner of function,
|
||||
else, obj is none.
|
||||
"""
|
||||
|
||||
def __init__(self, fn, input_signature=None, obj=None):
|
||||
self.fn = fn
|
||||
self.save_graphs = context.get_context("save_graphs")
|
||||
|
@ -245,6 +248,7 @@ def ms_function(fn=None, obj=None, input_signature=None):
|
|||
>>> out = tensor_add_with_dec(x, y)
|
||||
>>> out = tensor_add_with_sig(x, y)
|
||||
"""
|
||||
|
||||
def wrap_mindspore(func):
|
||||
@wraps(func)
|
||||
def staging_specialize(*args):
|
||||
|
@ -275,6 +279,7 @@ def _generate_pip_args(obj, *args, method="construct"):
|
|||
obj.__parse_method__ = parse_method
|
||||
return args_names, args_list
|
||||
|
||||
|
||||
class _PynativeExecutor:
|
||||
"""
|
||||
An pynative executor used to compile/manage/run graph.
|
||||
|
@ -304,6 +309,7 @@ class _PynativeExecutor:
|
|||
def __call__(self, *args):
|
||||
return self._executor(args, "")
|
||||
|
||||
|
||||
class _Executor:
|
||||
"""
|
||||
An executor used to compile/manage/run graph.
|
||||
|
@ -532,6 +538,7 @@ class _Executor:
|
|||
return None
|
||||
return self._executor.fetch_info_for_quant_export(exec_id)
|
||||
|
||||
|
||||
_executor = _Executor()
|
||||
_pynative_exec = _PynativeExecutor()
|
||||
|
||||
|
|
|
@ -0,0 +1,116 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test super"""
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||
|
||||
|
||||
class FatherNet(nn.Cell):
|
||||
def __init__(self, x):
|
||||
super(FatherNet, self).__init__(x)
|
||||
self.x = x
|
||||
|
||||
def construct(self, x, y):
|
||||
return self.x * x
|
||||
|
||||
def test_father(self, x):
|
||||
return self.x + x
|
||||
|
||||
|
||||
class MatherNet(nn.Cell):
|
||||
def __init__(self, y):
|
||||
super(MatherNet, self).__init__()
|
||||
self.y = y
|
||||
|
||||
def construct(self, x, y):
|
||||
return self.y * y
|
||||
|
||||
def test_mather(self, y):
|
||||
return self.y + y
|
||||
|
||||
|
||||
class SingleSubNet(FatherNet):
|
||||
def __init__(self, x, z):
|
||||
super(SingleSubNet, self).__init__(x)
|
||||
self.z = z
|
||||
|
||||
def construct(self, x, y):
|
||||
ret_father_construct = super().construct(x, y)
|
||||
ret_father_test = super(SingleSubNet, self).test_father(x)
|
||||
ret_father_x = super(SingleSubNet, self).x
|
||||
ret_sub_z = self.z
|
||||
|
||||
return ret_father_construct, ret_father_test, ret_father_x, ret_sub_z
|
||||
|
||||
|
||||
class MulSubNet(FatherNet, MatherNet):
|
||||
def __init__(self, x, y, z):
|
||||
super(MulSubNet, self).__init__(x)
|
||||
super(FatherNet, self).__init__(y)
|
||||
self.z = z
|
||||
|
||||
def construct(self, x, y):
|
||||
ret_father_construct = super().construct(x, y)
|
||||
ret_father_test = super(MulSubNet, self).test_father(x)
|
||||
ret_father_x = super(MulSubNet, self).x
|
||||
ret_mather_construct = super(FatherNet, self).construct(x, y)
|
||||
ret_mather_test = super(FatherNet, self).test_mather(y)
|
||||
ret_mather_y = super(FatherNet, self).y
|
||||
ret_sub_z = self.z
|
||||
|
||||
return ret_father_construct, ret_father_test, ret_father_x, \
|
||||
ret_mather_construct, ret_mather_test, ret_mather_y, ret_sub_z
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, x):
|
||||
super(Net, self).__init__()
|
||||
self.x = x
|
||||
|
||||
def construct(self, x, y):
|
||||
ret = super(Net, self).construct(x, y)
|
||||
return ret
|
||||
|
||||
|
||||
def test_single_super():
|
||||
single_net = SingleSubNet(2, 3)
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||
x = Tensor(np.ones([1, 2, 3], np.int32))
|
||||
y = Tensor(np.ones([1, 2, 3], np.int32))
|
||||
single_net(x, y)
|
||||
|
||||
|
||||
def test_mul_super():
|
||||
mul_net = MulSubNet(2, 3, 4)
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||
x = Tensor(np.ones([1, 2, 3], np.int32))
|
||||
y = Tensor(np.ones([1, 2, 3], np.int32))
|
||||
mul_net(x, y)
|
||||
|
||||
|
||||
def test_super_cell():
|
||||
net = Net(2)
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||
x = Tensor(np.ones([1, 2, 3], np.int32))
|
||||
y = Tensor(np.ones([1, 2, 3], np.int32))
|
||||
with pytest.raises(RuntimeError) as er:
|
||||
net(x, y)
|
||||
assert "Unsupported syntax 'Raise'" in str(er.value)
|
Loading…
Reference in New Issue