!26690 Add supoort resolving outer lambda function for ops.Partial.

Merge pull request !26690 from hezhenhao1/add_lambda
This commit is contained in:
i-robot 2021-11-26 09:21:34 +00:00 committed by Gitee
commit d1e4e674ab
3 changed files with 261 additions and 51 deletions

View File

@ -175,7 +175,19 @@ void CheckFuncReturn(const FuncGraphPtr &fn, const std::shared_ptr<ParseFunction
FuncGraphPtr Parser::ParseFuncGraph() {
// Get ast FunctionDef node
py::object node = ast_->GetAstNode();
FunctionBlockPtr fn_block = ParseFunction(node);
constexpr char function_def_name[] = "FunctionDef";
constexpr char lambda_name[] = "Lambda";
FunctionBlockPtr fn_block = nullptr;
if (ast_->GetNodeType(node)->node_name() == function_def_name) {
fn_block = ParseDefFunction(node);
} else {
auto lambda_node = python_adapter::GetPyObjAttr(node, "value");
if (py::isinstance<py::none>(lambda_node) || ast_->GetNodeType(lambda_node)->node_name() != lambda_name) {
MS_EXCEPTION(TypeError) << "Parse Lambda Function Fail. Node type must be Lambda, but got "
<< ast_->GetNodeType(lambda_node)->node_name() << ".";
}
fn_block = ParseLambdaFunction(lambda_node);
}
if (errcode() != PARSE_SUCCESS) {
MS_LOG(ERROR) << "Parse function error, code is " << errcode();
return nullptr;
@ -259,7 +271,7 @@ ScopePtr Parser::GetScopeForParseFunction() {
return scope;
}
FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlockPtr &block) {
FunctionBlockPtr Parser::ParseDefFunction(const py::object &node, const FunctionBlockPtr &block) {
ScopePtr scope = GetScopeForParseFunction();
// The node created in the parsefunction context, will inherit the scope created using scope_guard
ScopeGuard scope_guard(scope);
@ -323,6 +335,33 @@ FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlo
return func_block;
}
FunctionBlockPtr Parser::ParseLambdaFunction(const py::object &node, const FunctionBlockPtr &block) {
MS_EXCEPTION_IF_NULL(ast_);
ScopePtr scope = GetScopeForParseFunction();
ScopeGuard scope_guard(scope);
TraceGuard trace_guard(data_converter::GetObjKey(ast_->obj())[0], GetLocation(node));
FunctionBlockPtr func_block = MakeFunctionBlock(*this);
if (block != nullptr) {
func_block->AddPrevBlock(block);
} else {
func_graph_ = func_block->func_graph();
}
func_block->Mature();
auto current_fg = func_block->func_graph();
auto function_name = ast_->function_name();
MS_LOG(DEBUG) << "The function name is " << function_name;
current_fg->debug_info()->set_name(function_name);
GenerateArgsNodeForFunction(func_block, node);
py::object body_node = python_adapter::GetPyObjAttr(node, "body");
AnfNodePtr lambda_body_node = ParseExprNode(func_block, body_node);
current_fg->set_output(lambda_body_node);
GenerateArgsDefaultValueForFunction(func_block, node);
return func_block;
}
FunctionBlockPtr Parser::ParseStatements(FunctionBlockPtr block, const py::object &nodes) {
auto node_list = py::cast<py::list>(nodes);
size_t count = py::len(node_list);
@ -909,7 +948,7 @@ AnfNodePtr Parser::ParseBoolOp(const FunctionBlockPtr &block, const py::object &
// Process a function def
FunctionBlockPtr Parser::ParseFunctionDef(const FunctionBlockPtr &block, const py::object &node) {
MS_LOG(DEBUG) << "Process ast FunctionDef";
FunctionBlockPtr function_block = ParseFunction(node, block);
FunctionBlockPtr function_block = ParseDefFunction(node, block);
MS_EXCEPTION_IF_NULL(function_block);
// Get function name
@ -923,26 +962,10 @@ FunctionBlockPtr Parser::ParseFunctionDef(const FunctionBlockPtr &block, const p
// Process a lambda expression . like lambda x,y: x + y
AnfNodePtr Parser::ParseLambda(const FunctionBlockPtr &block, const py::object &node) {
MS_LOG(DEBUG) << "Process ast Lambda";
FunctionBlockPtr func_block = MakeFunctionBlock(*this);
func_block->AddPrevBlock(block);
func_block->Mature();
FunctionBlockPtr function_block = ParseLambdaFunction(node, block);
MS_EXCEPTION_IF_NULL(function_block);
// Get lambda args
py::list args = ast_->GetArgs(node);
auto block_fg = func_block->func_graph();
for (std::size_t i = 0; i < args.size(); i++) {
std::string arg_name = py::cast<std::string>(args[i].attr("arg"));
TraceGuard guard(GetLocation(args[i]));
auto para_node = std::make_shared<Parameter>(block_fg);
para_node->debug_info()->set_name(arg_name);
block_fg->add_parameter(para_node);
func_block->WriteVariable(arg_name, para_node);
MS_LOG(DEBUG) << "The arg[" << i << "] is " << arg_name;
}
py::object body_node = python_adapter::GetPyObjAttr(node, "body");
AnfNodePtr lambda_body_node = ParseExprNode(func_block, body_node);
block_fg->set_output(lambda_body_node);
auto block_fg = function_block->func_graph();
ValueNodePtr const_graph = NewValueNode(block_fg);
return const_graph;
}

View File

@ -196,7 +196,9 @@ class Parser {
// Generate argument default value for ast function node
void GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block, const py::object &function_node);
// Parse ast function node
FunctionBlockPtr ParseFunction(const py::object &function_node, const FunctionBlockPtr &block = nullptr);
FunctionBlockPtr ParseDefFunction(const py::object &function_node, const FunctionBlockPtr &block = nullptr);
// Parse lambda function node
FunctionBlockPtr ParseLambdaFunction(const py::object &function_node, const FunctionBlockPtr &block = nullptr);
// Parse ast statements
FunctionBlockPtr ParseStatements(FunctionBlockPtr block, const py::object &stmt_node);
// Parse one ast statement node

View File

@ -22,7 +22,14 @@ from mindspore import nn, Tensor, context
context.set_context(mode=context.GRAPH_MODE)
def test_partial_pos_arg():
"""
Feature: ALL TO ALL
Description: test cases for partial_pos_arg
Expectation: the result match given one
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
@ -35,13 +42,31 @@ def test_partial_pos_arg():
ret = f(y, z)
return ret
class Net2(nn.Cell):
def __init__(self):
super(Net2, self).__init__()
self.show = lambda x, y, z: (x, y, z)
def construct(self, x, y, z):
f = partial(self.show, x)
ret = f(y, z)
return ret
x = Tensor(np.arange(3).reshape((3,)).astype(np.float32))
y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32))
z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32))
net = Net()
net(x, y, z)
for net in [Net(), Net2()]:
net(x, y, z)
def test_partial_key_ward_arg():
"""
Feature: ALL TO ALL
Description: test cases for partial_key_ward_arg
Expectation: the result match given one
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
@ -54,13 +79,31 @@ def test_partial_key_ward_arg():
ret = f(y=y, z=z)
return ret
class Net2(nn.Cell):
def __init__(self):
super(Net2, self).__init__()
self.show = lambda x, y, z: (x, y, z)
def construct(self, x, y, z):
f = partial(self.show, x=x)
ret = f(y=y, z=z)
return ret
x = Tensor(np.arange(3).reshape((3,)).astype(np.float32))
y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32))
z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32))
net = Net()
net(x, y, z)
for net in [Net(), Net2()]:
net(x, y, z)
def test_partial_key_ward_arg_update():
"""
Feature: ALL TO ALL
Description: test cases for partial_key_ward_arg_update
Expectation: the result match given one
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
@ -73,14 +116,31 @@ def test_partial_key_ward_arg_update():
ret = f(y=y, z=z)
return ret
class Net2(nn.Cell):
def __init__(self):
super(Net2, self).__init__()
self.show = lambda x, y, z: (x, y, z)
def construct(self, x, y, z):
f = partial(self.show, x=x, y=y)
ret = f(y=y, z=z)
return ret
x = Tensor(np.arange(3).reshape((3,)).astype(np.float32))
y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32))
z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32))
net = Net()
net(x, y, z)
for net in [Net(), Net2()]:
net(x, y, z)
def test_partial_key_ward_arg_and_pos_arg():
"""
Feature: ALL TO ALL
Description: test cases for partial_key_ward_arg_and_pos_arg
Expectation: the result match given one
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
@ -93,14 +153,31 @@ def test_partial_key_ward_arg_and_pos_arg():
ret = f(2, z=z)
return ret
class Net2(nn.Cell):
def __init__(self):
super(Net2, self).__init__()
self.show = lambda x, y, z: (x, y, z)
def construct(self, x, y, z):
f = partial(self.show, y=y)
ret = f(2, z=z)
return ret
x = Tensor(np.arange(3).reshape((3,)).astype(np.float32))
y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32))
z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32))
net = Net()
net(x, y, z)
for net in [Net(), Net2()]:
net(x, y, z)
def test_partial_pos_arg_const():
"""
Feature: ALL TO ALL
Description: test cases for partial_pos_arg_const
Expectation: the result match given one
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
@ -113,10 +190,27 @@ def test_partial_pos_arg_const():
ret = f(2, 3)
return ret
net = Net()
assert net() == (1, 2, 3)
class Net2(nn.Cell):
def __init__(self):
super(Net2, self).__init__()
self.show = lambda x, y, z: (x, y, z)
def construct(self):
f = partial(self.show, 1)
ret = f(2, 3)
return ret
for net in [Net(), Net2()]:
assert net() == (1, 2, 3)
def test_partial_key_ward_arg_const():
"""
Feature: ALL TO ALL
Description: test cases for partial_key_ward_arg_const
Expectation: the result match given one
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
@ -129,10 +223,27 @@ def test_partial_key_ward_arg_const():
ret = f(y=2, z=3)
return ret
net = Net()
assert net() == (1, 2, 3)
class Net2(nn.Cell):
def __init__(self):
super(Net2, self).__init__()
self.show = lambda x, y, z: (x, y, z)
def construct(self):
f = partial(self.show, x=1)
ret = f(y=2, z=3)
return ret
for net in [Net(), Net2()]:
assert net() == (1, 2, 3)
def test_partial_key_ward_arg_update_const():
"""
Feature: ALL TO ALL
Description: test cases for partial_key_ward_arg_update_const
Expectation: the result match given one
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
@ -145,11 +256,27 @@ def test_partial_key_ward_arg_update_const():
ret = f(y=3, z=4)
return ret
net = Net()
assert net() == (1, 3, 4)
class Net2(nn.Cell):
def __init__(self):
super(Net2, self).__init__()
self.show = lambda x, y, z: (x, y, z)
def construct(self):
f = partial(self.show, x=1, y=2)
ret = f(y=3, z=4)
return ret
for net in [Net(), Net2()]:
assert net() == (1, 3, 4)
def test_partial_key_ward_arg_and_pos_arg_const():
"""
Feature: ALL TO ALL
Description: test cases for partial_key_ward_arg_and_pos_arg_const
Expectation: the result match given one
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
@ -162,11 +289,27 @@ def test_partial_key_ward_arg_and_pos_arg_const():
ret = f(1, z=3)
return ret
net = Net()
assert net() == (1, 2, 3)
class Net2(nn.Cell):
def __init__(self):
super(Net2, self).__init__()
self.show = lambda x, y, z: (x, y, z)
def construct(self):
f = partial(self.show, y=2)
ret = f(1, z=3)
return ret
for net in [Net(), Net2()]:
assert net() == (1, 2, 3)
def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_x():
"""
Feature: ALL TO ALL
Description: test cases for partial_key_ward_arg_and_pos_arg_const_multi_assign_x
Expectation: the result match given one
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
@ -179,13 +322,29 @@ def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_x():
ret = f(1, 2, 3)
return ret
net = Net()
with pytest.raises(TypeError) as ex:
net()
assert "Multiply values for specific argument: x" in str(ex.value)
class Net2(nn.Cell):
def __init__(self):
super(Net2, self).__init__()
self.show = lambda x, y, z: (x, y, z)
def construct(self):
f = partial(self.show, x=1)
ret = f(1, 2, 3)
return ret
for net in [Net(), Net2()]:
with pytest.raises(TypeError) as ex:
net()
assert "Multiply values for specific argument: x" in str(ex.value)
def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_y():
"""
Feature: ALL TO ALL
Description: test cases for partial_key_ward_arg_and_pos_arg_const_multi_assign_y
Expectation: the result match given one
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
@ -198,13 +357,29 @@ def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_y():
ret = f(1, 2, z=3)
return ret
net = Net()
with pytest.raises(TypeError) as ex:
net()
assert "Multiply values for specific argument: y" in str(ex.value)
class Net2(nn.Cell):
def __init__(self):
super(Net2, self).__init__()
self.show = lambda x, y, z: (x, y, z)
def construct(self):
f = partial(self.show, y=2)
ret = f(1, 2, z=3)
return ret
for net in [Net(), Net2()]:
with pytest.raises(TypeError) as ex:
net()
assert "Multiply values for specific argument: y" in str(ex.value)
def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_z():
"""
Feature: ALL TO ALL
Description: test cases for partial_key_ward_arg_and_pos_arg_const_multi_assign_z
Expectation: the result match given one
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
@ -217,7 +392,17 @@ def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_z():
ret = f(1, 2, 3)
return ret
net = Net()
with pytest.raises(TypeError) as ex:
net()
assert "Multiply values for specific argument: z" in str(ex.value)
class Net2(nn.Cell):
def __init__(self):
super(Net2, self).__init__()
self.show = lambda x, y, z: (x, y, z)
def construct(self):
f = partial(self.show, z=1)
ret = f(1, 2, 3)
return ret
for net in [Net(), Net2()]:
with pytest.raises(TypeError) as ex:
net()
assert "Multiply values for specific argument: z" in str(ex.value)