!28384 [Fallback] Support the use of self

Merge pull request !28384 from huangbingjian/fallback_self
This commit is contained in:
i-robot 2021-12-31 09:26:03 +00:00 committed by Gitee
commit dba67422f9
6 changed files with 151 additions and 4 deletions

View File

@ -188,6 +188,14 @@ AnfNodePtr FunctionBlock::MakeResolveAstOp(const py::object &op) {
AnfNodePtr FunctionBlock::MakeResolveClassMember(const std::string &attr) {
auto ast = parser_.ast();
MS_EXCEPTION_IF_NULL(ast);
// The fallback feature is enabled in default.
// Not support change the flag during the process is alive.
static const auto use_fallback = (parser_.support_fallback() != "0");
if (use_fallback && !global_py_params().contains("self")) {
py::object self_namespace = ast->CallParseModFunction(PYTHON_MOD_GET_ATTR_NAMESPACE_SYMBOL, ast->obj());
AddGlobalPyParam("self", self_namespace);
}
py::object namespace_var = ast->CallParseModFunction(PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, ast->obj());
NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
SymbolPtr symbol = std::make_shared<Symbol>(attr);
@ -261,7 +269,8 @@ AnfNodePtr FunctionBlock::HandleBuiltinNamespaceInfo(const py::tuple &info) {
// Make a resolve node for symbol string
AnfNodePtr FunctionBlock::MakeResolveSymbol(const std::string &value) {
MS_LOG(DEBUG) << "value: " << value;
if (value.compare(0, strlen("self"), "self") == 0) {
// The prefix of value is "self.".
if (value.compare(0, strlen("self."), "self.") == 0) {
auto start = value.find_first_of('.') + 1;
if (start >= value.size()) {
MS_LOG(ERROR) << "Find invalid resolve symbol str: " << value;

View File

@ -68,6 +68,7 @@ const char PYTHON_MOD_IS_SUPPORTED_CREATE_INSTANCE_TYPE[] = "is_supported_create
const char PYTHON_MOD_GET_DATACLASS_ATTRS[] = "get_dataclass_attributes";
const char PYTHON_MOD_GET_DATACLASS_METHODS[] = "get_dataclass_methods";
const char PYTHON_MOD_GET_MODULE_NAMESPACE[] = "get_module_namespace";
const char PYTHON_MOD_GET_ATTR_NAMESPACE_SYMBOL[] = "get_class_attr_namespace_symbol";
const char PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL[] = "get_class_member_namespace_symbol";
const char PYTHON_MOD_GET_PARSE_METHOD[] = "get_parse_method_of_class";
const char PYTHON_MOD_GET_BPROP_METHOD[] = "get_bprop_method_of_class";

View File

@ -23,7 +23,7 @@ from .parser import (Parser, create_instance, is_supported_create_instance_type,
get_args, get_args_default_values, get_ast_namespace_symbol, get_operation_symbol,
get_operation_namespace_symbol, get_parse_method_of_class, get_scope_name, eval_script,
expand_expr_statement, is_class_member, parse_cb, resolve_symbol, convert_to_ms_tensor,
get_object_description)
get_object_description, get_class_attr_namespace_symbol)
__all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol',
'get_object_key', 'get_class_instance_type', 'is_class_member', 'get_ast_type', 'get_node_type',
@ -32,4 +32,4 @@ __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class',
'get_module_namespace', 'get_class_member_namespace_symbol', 'get_obj_id', 'Parser',
'get_dataclass_attributes', 'get_dataclass_methods', 'get_dataclass_methods', 'get_scope_name',
'eval_script', 'create_slice_obj', 'convert_to_ms_tensor', 'get_object_description', 'expand_expr_statement',
'generate_scope', 'get_operation_symbol']
'generate_scope', 'get_operation_symbol', 'get_class_attr_namespace_symbol']

View File

@ -120,3 +120,21 @@ class ClassMemberNamespace(Namespace):
except KeyError:
logger.info(f"'{d.__class__.__name__ }' object has no attribute or method: '{name}', so will return None.")
raise AttributeError(name)
class ClassAttrNamespace(Namespace):
"""
Namespace of a class.
Args:
obj (Object): A python class object.
"""
def __init__(self, obj):
name = f'{obj.__module__}..<{obj.__class__.__name__}::{id(obj)}>'
super().__init__(name, obj)
def __getattr__(self, name):
for d in self.dicts:
if hasattr(d, name):
return getattr(d, name)
raise NameError(name)

View File

@ -32,7 +32,7 @@ from mindspore import nn
from mindspore import ops
from mindspore.common.api import _MindsporeFunctionExecutor
from mindspore.common.dtype import pytype_to_dtype
from .namespace import CellNamespace, ClosureNamespace, ClassMemberNamespace
from .namespace import CellNamespace, ClosureNamespace, ClassMemberNamespace, ClassAttrNamespace
from .resources import parse_object_map, ops_symbol_map, convert_object_map, trope_ns, SYMBOL_UNDEFINE, NO_IMPLEMENT
# define return value
@ -377,6 +377,14 @@ def get_module_namespace(obj):
return mod_namespace
def get_class_attr_namespace_symbol(obj):
"""Get class namespace."""
logger.debug("get class namespace, object: %r", obj)
class_namespace = ClassAttrNamespace(obj)
logger.debug("class namespace: %r", class_namespace)
return class_namespace
def get_class_member_namespace_symbol(obj):
"""Get obj class member type."""
logger.debug("get class instance namespace, object: %r", obj)

View File

@ -223,3 +223,114 @@ def test_context():
net = ContextNet()
out = net()
print(out)
def test_self_attr():
"""
Feature: JIT Fallback
Description: Test self.attr in graph.
Expectation: No exception.
"""
class Network(nn.Cell):
def __init__(self):
super(Network, self).__init__()
self.dim = 1
def construct(self, x):
batch = x.shape[0]
one = Tensor(np.ones([batch, self.dim]), mstype.float16)
return one * x
net = Network()
x = Tensor([1, 2], mstype.float32)
out = net(x)
print(out)
def test_self_attr_2():
"""
Feature: JIT Fallback
Description: Test self.attr in graph.
Expectation: No exception.
"""
class Network(nn.Cell):
def __init__(self, fn):
super(Network, self).__init__()
self.fn = fn
def construct(self):
x = np.array([1, 2, 3])
y = np.array([3, 4, 5])
out = Tensor(self.fn(x, y))
return out
def fn(x, y):
return x + y
net = Network(fn)
out = net()
print(out)
def test_self_attr_3():
"""
Feature: JIT Fallback
Description: Test self.attr in graph.
Expectation: No exception.
"""
class Network(nn.Cell):
def __init__(self):
super(Network, self).__init__()
self.value = [2, 2, 3]
def construct(self):
x = np.array(self.value.count(2))
return Tensor(x)
net = Network()
out = net()
print(out)
def test_self_method():
"""
Feature: JIT Fallback
Description: Test self.method in graph.
Expectation: No exception.
"""
class Network(nn.Cell):
def construct(self):
x = np.array([1, 2, 3])
y = np.array([3, 4, 5])
out = Tensor(self.fn(x, y))
return out
def fn(self, x, y):
return x + y
net = Network()
out = net()
print(out)
@pytest.mark.skip(reason='Not support in graph jit fallback feature yet')
def test_self_method_2():
"""
Feature: JIT Fallback
Description: Test self.method in graph.
Expectation: No exception.
"""
class Network(nn.Cell):
def construct(self):
x = np.array([1, 2, 3])
y = np.array([3, 4, 5])
z = self.fn(x, y)
out = Tensor(z)
return out
def fn(self, x, y):
return x + y
net = Network()
out = net()
print(out)