forked from mindspore-Ecosystem/mindspore
!28384 [Fallback] Support the use of self
Merge pull request !28384 from huangbingjian/fallback_self
This commit is contained in:
commit
dba67422f9
|
@ -188,6 +188,14 @@ AnfNodePtr FunctionBlock::MakeResolveAstOp(const py::object &op) {
|
|||
AnfNodePtr FunctionBlock::MakeResolveClassMember(const std::string &attr) {
|
||||
auto ast = parser_.ast();
|
||||
MS_EXCEPTION_IF_NULL(ast);
|
||||
// The fallback feature is enabled in default.
|
||||
// Not support change the flag during the process is alive.
|
||||
static const auto use_fallback = (parser_.support_fallback() != "0");
|
||||
if (use_fallback && !global_py_params().contains("self")) {
|
||||
py::object self_namespace = ast->CallParseModFunction(PYTHON_MOD_GET_ATTR_NAMESPACE_SYMBOL, ast->obj());
|
||||
AddGlobalPyParam("self", self_namespace);
|
||||
}
|
||||
|
||||
py::object namespace_var = ast->CallParseModFunction(PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, ast->obj());
|
||||
NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
|
||||
SymbolPtr symbol = std::make_shared<Symbol>(attr);
|
||||
|
@ -261,7 +269,8 @@ AnfNodePtr FunctionBlock::HandleBuiltinNamespaceInfo(const py::tuple &info) {
|
|||
// Make a resolve node for symbol string
|
||||
AnfNodePtr FunctionBlock::MakeResolveSymbol(const std::string &value) {
|
||||
MS_LOG(DEBUG) << "value: " << value;
|
||||
if (value.compare(0, strlen("self"), "self") == 0) {
|
||||
// The prefix of value is "self.".
|
||||
if (value.compare(0, strlen("self."), "self.") == 0) {
|
||||
auto start = value.find_first_of('.') + 1;
|
||||
if (start >= value.size()) {
|
||||
MS_LOG(ERROR) << "Find invalid resolve symbol str: " << value;
|
||||
|
|
|
@ -68,6 +68,7 @@ const char PYTHON_MOD_IS_SUPPORTED_CREATE_INSTANCE_TYPE[] = "is_supported_create
|
|||
const char PYTHON_MOD_GET_DATACLASS_ATTRS[] = "get_dataclass_attributes";
|
||||
const char PYTHON_MOD_GET_DATACLASS_METHODS[] = "get_dataclass_methods";
|
||||
const char PYTHON_MOD_GET_MODULE_NAMESPACE[] = "get_module_namespace";
|
||||
const char PYTHON_MOD_GET_ATTR_NAMESPACE_SYMBOL[] = "get_class_attr_namespace_symbol";
|
||||
const char PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL[] = "get_class_member_namespace_symbol";
|
||||
const char PYTHON_MOD_GET_PARSE_METHOD[] = "get_parse_method_of_class";
|
||||
const char PYTHON_MOD_GET_BPROP_METHOD[] = "get_bprop_method_of_class";
|
||||
|
|
|
@ -23,7 +23,7 @@ from .parser import (Parser, create_instance, is_supported_create_instance_type,
|
|||
get_args, get_args_default_values, get_ast_namespace_symbol, get_operation_symbol,
|
||||
get_operation_namespace_symbol, get_parse_method_of_class, get_scope_name, eval_script,
|
||||
expand_expr_statement, is_class_member, parse_cb, resolve_symbol, convert_to_ms_tensor,
|
||||
get_object_description)
|
||||
get_object_description, get_class_attr_namespace_symbol)
|
||||
|
||||
__all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol',
|
||||
'get_object_key', 'get_class_instance_type', 'is_class_member', 'get_ast_type', 'get_node_type',
|
||||
|
@ -32,4 +32,4 @@ __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class',
|
|||
'get_module_namespace', 'get_class_member_namespace_symbol', 'get_obj_id', 'Parser',
|
||||
'get_dataclass_attributes', 'get_dataclass_methods', 'get_dataclass_methods', 'get_scope_name',
|
||||
'eval_script', 'create_slice_obj', 'convert_to_ms_tensor', 'get_object_description', 'expand_expr_statement',
|
||||
'generate_scope', 'get_operation_symbol']
|
||||
'generate_scope', 'get_operation_symbol', 'get_class_attr_namespace_symbol']
|
||||
|
|
|
@ -120,3 +120,21 @@ class ClassMemberNamespace(Namespace):
|
|||
except KeyError:
|
||||
logger.info(f"'{d.__class__.__name__ }' object has no attribute or method: '{name}', so will return None.")
|
||||
raise AttributeError(name)
|
||||
|
||||
|
||||
class ClassAttrNamespace(Namespace):
|
||||
"""
|
||||
Namespace of a class.
|
||||
|
||||
Args:
|
||||
obj (Object): A python class object.
|
||||
"""
|
||||
def __init__(self, obj):
|
||||
name = f'{obj.__module__}..<{obj.__class__.__name__}::{id(obj)}>'
|
||||
super().__init__(name, obj)
|
||||
|
||||
def __getattr__(self, name):
|
||||
for d in self.dicts:
|
||||
if hasattr(d, name):
|
||||
return getattr(d, name)
|
||||
raise NameError(name)
|
||||
|
|
|
@ -32,7 +32,7 @@ from mindspore import nn
|
|||
from mindspore import ops
|
||||
from mindspore.common.api import _MindsporeFunctionExecutor
|
||||
from mindspore.common.dtype import pytype_to_dtype
|
||||
from .namespace import CellNamespace, ClosureNamespace, ClassMemberNamespace
|
||||
from .namespace import CellNamespace, ClosureNamespace, ClassMemberNamespace, ClassAttrNamespace
|
||||
from .resources import parse_object_map, ops_symbol_map, convert_object_map, trope_ns, SYMBOL_UNDEFINE, NO_IMPLEMENT
|
||||
|
||||
# define return value
|
||||
|
@ -377,6 +377,14 @@ def get_module_namespace(obj):
|
|||
return mod_namespace
|
||||
|
||||
|
||||
def get_class_attr_namespace_symbol(obj):
|
||||
"""Get class namespace."""
|
||||
logger.debug("get class namespace, object: %r", obj)
|
||||
class_namespace = ClassAttrNamespace(obj)
|
||||
logger.debug("class namespace: %r", class_namespace)
|
||||
return class_namespace
|
||||
|
||||
|
||||
def get_class_member_namespace_symbol(obj):
|
||||
"""Get obj class member type."""
|
||||
logger.debug("get class instance namespace, object: %r", obj)
|
||||
|
|
|
@ -223,3 +223,114 @@ def test_context():
|
|||
net = ContextNet()
|
||||
out = net()
|
||||
print(out)
|
||||
|
||||
|
||||
def test_self_attr():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test self.attr in graph.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class Network(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Network, self).__init__()
|
||||
self.dim = 1
|
||||
|
||||
def construct(self, x):
|
||||
batch = x.shape[0]
|
||||
one = Tensor(np.ones([batch, self.dim]), mstype.float16)
|
||||
return one * x
|
||||
|
||||
net = Network()
|
||||
x = Tensor([1, 2], mstype.float32)
|
||||
out = net(x)
|
||||
print(out)
|
||||
|
||||
|
||||
def test_self_attr_2():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test self.attr in graph.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class Network(nn.Cell):
|
||||
def __init__(self, fn):
|
||||
super(Network, self).__init__()
|
||||
self.fn = fn
|
||||
|
||||
def construct(self):
|
||||
x = np.array([1, 2, 3])
|
||||
y = np.array([3, 4, 5])
|
||||
out = Tensor(self.fn(x, y))
|
||||
return out
|
||||
|
||||
def fn(x, y):
|
||||
return x + y
|
||||
|
||||
net = Network(fn)
|
||||
out = net()
|
||||
print(out)
|
||||
|
||||
|
||||
def test_self_attr_3():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test self.attr in graph.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class Network(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Network, self).__init__()
|
||||
self.value = [2, 2, 3]
|
||||
|
||||
def construct(self):
|
||||
x = np.array(self.value.count(2))
|
||||
return Tensor(x)
|
||||
|
||||
net = Network()
|
||||
out = net()
|
||||
print(out)
|
||||
|
||||
|
||||
def test_self_method():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test self.method in graph.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class Network(nn.Cell):
|
||||
def construct(self):
|
||||
x = np.array([1, 2, 3])
|
||||
y = np.array([3, 4, 5])
|
||||
out = Tensor(self.fn(x, y))
|
||||
return out
|
||||
|
||||
def fn(self, x, y):
|
||||
return x + y
|
||||
|
||||
net = Network()
|
||||
out = net()
|
||||
print(out)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support in graph jit fallback feature yet')
|
||||
def test_self_method_2():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test self.method in graph.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class Network(nn.Cell):
|
||||
def construct(self):
|
||||
x = np.array([1, 2, 3])
|
||||
y = np.array([3, 4, 5])
|
||||
z = self.fn(x, y)
|
||||
out = Tensor(z)
|
||||
return out
|
||||
|
||||
def fn(self, x, y):
|
||||
return x + y
|
||||
|
||||
net = Network()
|
||||
out = net()
|
||||
print(out)
|
||||
|
|
Loading…
Reference in New Issue