!3237 support call the parent class function

Merge pull request !3237 from zhangbuxue/support_call_the_parent_class_construct_function
This commit is contained in:
mindspore-ci-bot 2020-07-21 22:13:04 +08:00 committed by Gitee
commit 8da91ca3cf
15 changed files with 222 additions and 9 deletions

View File

@ -99,12 +99,19 @@ class ClassMemberNamespace(Namespace):
obj (Object): A python class object.
"""
def __init__(self, obj):
self.__class_member_namespace__ = True
label = f'{obj.__module__}..<{obj.__class__.__name__}::{id(obj)}>'
super().__init__(label, obj)
def __getitem__(self, name):
d, = self.dicts
if name == "self":
return d
if name == "namespace":
return self
try:
return getattr(d, name)
if hasattr(d, name):
return getattr(d, name)
return d.__dict__[name]
except ValueError:
raise UnboundLocalError(name)

View File

@ -70,6 +70,7 @@ parse_expr_statement_white_list = (
"append",
)
def create_slice_obj(start, end, step):
"""Create slice object"""
return slice(start, end, step)
@ -201,9 +202,10 @@ def get_object_key(obj):
if isinstance(obj, types.MethodType):
method_instance = obj.__self__
instance_id = "%s_ID%d" % (str(method_instance.__class__.__name__), id(method_instance))
obj_id = instance_id + obj_id
obj_id = instance_id + obj_id + str(obj.__hash__())
return obj_id, obj_key
def get_default_input(obj):
if hasattr(obj, '__parameter__'):
return obj.default_input
@ -213,6 +215,7 @@ def get_default_input(obj):
return args
return obj
def is_class_member(node):
"""Check the attr is class member variable."""
type_ = node.__class__.__name__
@ -224,10 +227,12 @@ def is_class_member(node):
return True
return False
def get_obj_id(obj):
"""Get the obj id."""
return str(id(obj))
def get_obj_type(obj):
"""Get the obj type."""
obj_type = RESOLVE_TYPE_INVALID
@ -320,6 +325,7 @@ def get_dataclass_methods(cls):
if isinstance(getattr(cls, name), (types.FunctionType,))}
return methods
class Parser:
"""
Parser python code to ast tree.
@ -453,6 +459,28 @@ class Parser:
logger.debug("ops info = %r", ops_info)
return ops_info
def analyze_super(self, father_class_node, subclass_instance):
"""Analyze super and return a class instance."""
father_class = None
if father_class_node is None:
father_class = type(subclass_instance)
if isinstance(father_class_node, ast.Name):
father_class_name = getattr(father_class_node, 'id')
father_class = self.global_namespace[father_class_name]
if isinstance(father_class_node, ast.Attribute):
value = getattr(father_class_node, 'value')
attr = getattr(father_class_node, 'attr')
module_name = getattr(value, 'id')
father_class_module = self.global_namespace[module_name]
father_class = getattr(father_class_module, attr)
if father_class is None:
raise ValueError("When call 'super', the father class is None.")
if not isinstance(subclass_instance, father_class):
raise ValueError("When call 'super', the second arg should be an instance of first arg.")
target_class_instance = super(father_class, subclass_instance)
return target_class_instance
def get_location(self, node):
"""
Get location of node start and end line no.

View File

@ -117,6 +117,7 @@ convert_object_map = {
T.zip: C.zip_operation,
T.print: F.print_,
T.enumerate: M.enumerate_,
T.isinstance: M.isinstance_,
# custom define operation
T.iter: M.ms_iter,

View File

@ -114,6 +114,12 @@ def enumerate_(x, start=0):
return ret
def isinstance_(x, base_type):
"""Determine whether x is an instance of base_type."""
x_type = F.typeof(x)
return check_type_same(x_type, base_type)
def while_cond(x):
"""For while condtion, if the condition is a tensor, the loop will not be unrolled"""
if F.issubclass_(F.typeof(x), F.typeof(mstype.tensor)):
@ -123,6 +129,12 @@ def while_cond(x):
return x
@constexpr
def check_type_same(x_type, base_type):
"""Check x_type is same as base_type."""
return mstype.issubclass_(x_type, base_type)
@constexpr
def check_is_tuple_or_list(x, op_name, arg_name):
"""check whether x is list or tuple."""
@ -141,6 +153,7 @@ def check_is_const_int(x, op_name, arg_name):
return True
@constexpr
def check_is_tensor_bool_cond(shp):
"""check if tensor is a bool condition"""
@ -148,6 +161,7 @@ def check_is_tensor_bool_cond(shp):
return True
raise ValueError("tensor as bool condition, its shape should be () or (1,), but got ", shp)
@constexpr
def const_tensor_to_bool(x):
"""convert bool tensor to bool condition"""
@ -162,6 +176,7 @@ def const_tensor_to_bool(x):
value = bool(x[0])
return value
def tensor_bool(x):
"""tensor as conditon, if is constant, return immediate bool value"""
is_cond = check_is_tensor_bool_cond(F.shape(x))

View File

@ -27,7 +27,7 @@ from operator import ( # noqa
# support system function call
from builtins import ( # noqa
bool, getattr, setattr, len, iter, next, pow, range, map, zip, print, enumerate
bool, getattr, setattr, len, iter, next, pow, range, map, zip, print, enumerate, isinstance
)
# support functools
@ -44,7 +44,7 @@ __all__ = ['add', 'sub', 'mul', 'truediv', 'floordiv', 'mod', 'eq', 'ne', 'lt',
'not_', 'and_', 'or_', 'xor', 'lshift', 'rshift', 'invert', 'is_', 'is_not', 'contains',
'matmul', 'getitem', 'setitem',
'bool', 'getattr', 'setattr', 'len', 'iter', 'next', 'pow', 'range', 'map', 'zip',
'partial', 'print', 'enumerate',
'partial', 'print', 'enumerate', 'isinstance',
'exp', 'log', 'sin', 'cos', 'tan']

View File

@ -370,6 +370,8 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature
} else if (py::hasattr(obj, PYTHON_ENVINSTANCE_FLAG)) {
std::shared_ptr<EnvInstance> env = obj.cast<std::shared_ptr<EnvInstance>>();
converted = env;
} else if (py::hasattr(obj, PYTHON_CLASS_MEMBER_NAMESPACE)) {
converted = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, obj);
} else if (py::hasattr(obj, "__parameter__")) {
auto to_convert = py::cast<py::object>(python_adapter::GetPyObjAttr(obj, "default_input"));
ret = ConvertData(to_convert, &converted);

View File

@ -109,7 +109,7 @@ AnfNodePtr FunctionBlock::MakeResolveClassMember(std::string attr) {
// Make a resolve node for symbol string
AnfNodePtr FunctionBlock::MakeResolveSymbol(const std::string &value) {
if (value.compare(0, strlen("self."), "self.") == 0) {
if (value.compare(0, strlen("self"), "self") == 0) {
auto start = value.find_first_of('.') + 1;
if (start >= value.size()) {
MS_LOG(ERROR) << "Find invalid resolve symbol str: " << value;

View File

@ -22,6 +22,7 @@
#include <sstream>
#include <unordered_map>
#include <algorithm>
#include "pipeline/jit/parse/resolve.h"
#include "frontend/operator/ops.h"
#include "pipeline/jit/parse/data_converter.h"
#include "frontend/operator/composite/composite.h"
@ -504,14 +505,45 @@ AnfNodePtr Parser::GenerateMakeTuple(const FunctionBlockPtr &block, const std::v
[](AnfNodePtr arg) -> AnfNodePtr { return arg; });
return block->func_graph()->NewCNode(make_tuple_nodes);
}
AnfNodePtr Parser::ParseSuper(const FunctionBlockPtr &block, const py::list &args) {
py::object father_class;
if (args.empty()) {
father_class = py::none();
} else if (args.size() == 2) {
father_class = args[0];
auto arg_type = AstSubType(py::cast<int32_t>(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, args[1])));
if (arg_type != AST_SUB_TYPE_NAME || py::cast<std::string>(python_adapter::GetPyObjAttr(args[1], "id")) != "self") {
MS_EXCEPTION(ArgumentError) << "When call 'super', the second arg should be 'self'.";
}
} else {
MS_EXCEPTION(ArgumentError) << "When call 'super', the args number should be 0 or 2, but got" << args.size() << ".";
}
py::object target_class_instance = ast()->CallParserObjMethod(PYTHON_PARSE_ANALYZE_SUPER, father_class, ast()->obj());
py::object namespace_var = ast_->CallParseModFunction(PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, target_class_instance);
NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
SymbolPtr symbol = std::make_shared<Symbol>("namespace");
return block->MakeResolve(name_space, symbol);
}
// process function call, eg : f1(x, y) ...
AnfNodePtr Parser::ParseCall(const FunctionBlockPtr &block, const py::object &node) {
MS_LOG(DEBUG) << "Process ast Call";
// process function call
py::object function_ast_node = python_adapter::GetPyObjAttr(node, "func");
py::list args = python_adapter::GetPyObjAttr(node, "args");
auto arg_type =
AstSubType(py::cast<int32_t>(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, function_ast_node)));
if (arg_type == AST_SUB_TYPE_NAME) {
auto name_id = py::cast<std::string>(python_adapter::GetPyObjAttr(function_ast_node, "id"));
if (name_id == "super") {
return ParseSuper(block, args);
}
}
AnfNodePtr call_function_anf_node = ParseExprNode(block, function_ast_node);
// function call arguments should be passed in as groups and unpacked later using unpack call
py::list args = python_adapter::GetPyObjAttr(node, "args");
std::vector<AnfNodePtr> packed_arguments;
std::vector<AnfNodePtr> group_arguments;

View File

@ -138,6 +138,8 @@ class Parser {
AnfNodePtr ParseNameConstant(const FunctionBlockPtr &block, const py::object &node);
// process a function call
AnfNodePtr ParseCall(const FunctionBlockPtr &block, const py::object &node);
// process function 'super'
AnfNodePtr ParseSuper(const FunctionBlockPtr &block, const py::list &args);
// process the if expression
AnfNodePtr ParseIfExp(const FunctionBlockPtr &block, const py::object &node);
// process class type define

View File

@ -81,6 +81,7 @@ const char PYTHON_PARSE_GET_LOCATION[] = "get_location";
const char PYTHON_PARSE_EXPAND_EXPR_STATEMENT[] = "expand_expr_statement";
const char PYTHON_PARSE_GENERATE_SCOPE[] = "generate_scope";
const char PYTHON_PARSE_GET_SCOPE_NAME[] = "get_scope_name";
const char PYTHON_PARSE_ANALYZE_SUPER[] = "analyze_super";
const char PYTHON_PARSE_CLASS_SLICE[] = "create_slice_obj";
const char PYTHON_PARSE_CLASS_ELLIPSIS[] = "create_ellipsis_obj";

View File

@ -80,7 +80,7 @@ using SymbolPtr = std::shared_ptr<Symbol>;
// PyObjectWrapper class wrappers resolved python object for further processing.
class PyObjectWrapper : public Named {
public:
explicit PyObjectWrapper(const py::object &obj, const std::string name = "Python object") : Named(name), obj_(obj) {}
explicit PyObjectWrapper(const py::object &obj, const std::string &name = "Python object") : Named(name), obj_(obj) {}
~PyObjectWrapper() override = default;
MS_DECLARE_PARENT(PyObjectWrapper, Named);
py::object obj() { return obj_; }
@ -93,7 +93,7 @@ class PyObjectWrapper : public Named {
// ClassObject class wrappers dataclass
class ClassObject : public PyObjectWrapper {
public:
explicit ClassObject(const py::object &obj, const std::string name = "Python dataclass")
explicit ClassObject(const py::object &obj, const std::string &name = "Python dataclass")
: PyObjectWrapper(obj, name) {}
~ClassObject() override = default;
MS_DECLARE_PARENT(ClassObject, PyObjectWrapper);
@ -103,7 +103,7 @@ class ClassObject : public PyObjectWrapper {
// ClassType class wrappers class name in python
class ClassType : public PyObjectWrapper {
public:
explicit ClassType(const py::object &obj, const std::string name = "Python class type")
explicit ClassType(const py::object &obj, const std::string &name = "Python class type")
: PyObjectWrapper(obj, name) {}
~ClassType() override = default;
MS_DECLARE_PARENT(ClassType, PyObjectWrapper);

View File

@ -25,6 +25,7 @@ const char PYTHON_ENVINSTANCE_FLAG[] = "__envinstance_flag__";
const char PYTHON_DTYPE_FLAG[] = "__dtype_flag__";
const char PYTHON_CELL_AS_LIST[] = "__cell_as_list__";
const char PYTHON_DATACLASS_FIELDS[] = "__dataclass_fields__";
const char PYTHON_CLASS_MEMBER_NAMESPACE[] = "__class_member_namespace__";
// flag names
const char GRAPH_FLAG_MIX_PRECISION_FP16[] = "fp16";

View File

@ -27,6 +27,7 @@ extern const char PYTHON_ENVINSTANCE_FLAG[];
extern const char PYTHON_DTYPE_FLAG[];
extern const char PYTHON_CELL_AS_LIST[];
extern const char PYTHON_DATACLASS_FIELDS[];
extern const char PYTHON_CLASS_MEMBER_NAMESPACE[];
extern const char GRAPH_FLAG_MIX_PRECISION_FP16[];
extern const char GRAPH_FLAG_MIX_PRECISION_FP32[];

View File

@ -62,6 +62,7 @@ def _wrap_func(fn):
Returns:
Function, a new function with return suitable format data.
"""
@wraps(fn)
def wrapper(*arg, **kwargs):
results = fn(*arg, **kwargs)
@ -74,6 +75,7 @@ def _wrap_func(fn):
if isinstance(data, list):
return list(_convert_data(x) for x in data)
return data
return _convert_data(results)
return wrapper
@ -106,6 +108,7 @@ class _MindSporeFunction:
obj (Object): If function is a method, obj is the owner of function,
else, obj is none.
"""
def __init__(self, fn, input_signature=None, obj=None):
self.fn = fn
self.save_graphs = context.get_context("save_graphs")
@ -245,6 +248,7 @@ def ms_function(fn=None, obj=None, input_signature=None):
>>> out = tensor_add_with_dec(x, y)
>>> out = tensor_add_with_sig(x, y)
"""
def wrap_mindspore(func):
@wraps(func)
def staging_specialize(*args):
@ -275,6 +279,7 @@ def _generate_pip_args(obj, *args, method="construct"):
obj.__parse_method__ = parse_method
return args_names, args_list
class _PynativeExecutor:
"""
An pynative executor used to compile/manage/run graph.
@ -304,6 +309,7 @@ class _PynativeExecutor:
def __call__(self, *args):
return self._executor(args, "")
class _Executor:
"""
An executor used to compile/manage/run graph.
@ -532,6 +538,7 @@ class _Executor:
return None
return self._executor.fetch_info_for_quant_export(exec_id)
_executor = _Executor()
_pynative_exec = _PynativeExecutor()

View File

@ -0,0 +1,116 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test super"""
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
class FatherNet(nn.Cell):
def __init__(self, x):
super(FatherNet, self).__init__(x)
self.x = x
def construct(self, x, y):
return self.x * x
def test_father(self, x):
return self.x + x
class MatherNet(nn.Cell):
def __init__(self, y):
super(MatherNet, self).__init__()
self.y = y
def construct(self, x, y):
return self.y * y
def test_mather(self, y):
return self.y + y
class SingleSubNet(FatherNet):
def __init__(self, x, z):
super(SingleSubNet, self).__init__(x)
self.z = z
def construct(self, x, y):
ret_father_construct = super().construct(x, y)
ret_father_test = super(SingleSubNet, self).test_father(x)
ret_father_x = super(SingleSubNet, self).x
ret_sub_z = self.z
return ret_father_construct, ret_father_test, ret_father_x, ret_sub_z
class MulSubNet(FatherNet, MatherNet):
def __init__(self, x, y, z):
super(MulSubNet, self).__init__(x)
super(FatherNet, self).__init__(y)
self.z = z
def construct(self, x, y):
ret_father_construct = super().construct(x, y)
ret_father_test = super(MulSubNet, self).test_father(x)
ret_father_x = super(MulSubNet, self).x
ret_mather_construct = super(FatherNet, self).construct(x, y)
ret_mather_test = super(FatherNet, self).test_mather(y)
ret_mather_y = super(FatherNet, self).y
ret_sub_z = self.z
return ret_father_construct, ret_father_test, ret_father_x, \
ret_mather_construct, ret_mather_test, ret_mather_y, ret_sub_z
class Net(nn.Cell):
def __init__(self, x):
super(Net, self).__init__()
self.x = x
def construct(self, x, y):
ret = super(Net, self).construct(x, y)
return ret
def test_single_super():
single_net = SingleSubNet(2, 3)
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
x = Tensor(np.ones([1, 2, 3], np.int32))
y = Tensor(np.ones([1, 2, 3], np.int32))
single_net(x, y)
def test_mul_super():
mul_net = MulSubNet(2, 3, 4)
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
x = Tensor(np.ones([1, 2, 3], np.int32))
y = Tensor(np.ones([1, 2, 3], np.int32))
mul_net(x, y)
def test_super_cell():
net = Net(2)
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
x = Tensor(np.ones([1, 2, 3], np.int32))
y = Tensor(np.ones([1, 2, 3], np.int32))
with pytest.raises(RuntimeError) as er:
net(x, y)
assert "Unsupported syntax 'Raise'" in str(er.value)