!26690 Add supoort resolving outer lambda function for ops.Partial.

Merge pull request !26690 from hezhenhao1/add_lambda
This commit is contained in:
i-robot 2021-11-26 09:21:34 +00:00 committed by Gitee
commit d1e4e674ab
3 changed files with 261 additions and 51 deletions

View File

@ -175,7 +175,19 @@ void CheckFuncReturn(const FuncGraphPtr &fn, const std::shared_ptr<ParseFunction
FuncGraphPtr Parser::ParseFuncGraph() { FuncGraphPtr Parser::ParseFuncGraph() {
// Get ast FunctionDef node // Get ast FunctionDef node
py::object node = ast_->GetAstNode(); py::object node = ast_->GetAstNode();
FunctionBlockPtr fn_block = ParseFunction(node); constexpr char function_def_name[] = "FunctionDef";
constexpr char lambda_name[] = "Lambda";
FunctionBlockPtr fn_block = nullptr;
if (ast_->GetNodeType(node)->node_name() == function_def_name) {
fn_block = ParseDefFunction(node);
} else {
auto lambda_node = python_adapter::GetPyObjAttr(node, "value");
if (py::isinstance<py::none>(lambda_node) || ast_->GetNodeType(lambda_node)->node_name() != lambda_name) {
MS_EXCEPTION(TypeError) << "Parse Lambda Function Fail. Node type must be Lambda, but got "
<< ast_->GetNodeType(lambda_node)->node_name() << ".";
}
fn_block = ParseLambdaFunction(lambda_node);
}
if (errcode() != PARSE_SUCCESS) { if (errcode() != PARSE_SUCCESS) {
MS_LOG(ERROR) << "Parse function error, code is " << errcode(); MS_LOG(ERROR) << "Parse function error, code is " << errcode();
return nullptr; return nullptr;
@ -259,7 +271,7 @@ ScopePtr Parser::GetScopeForParseFunction() {
return scope; return scope;
} }
FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlockPtr &block) { FunctionBlockPtr Parser::ParseDefFunction(const py::object &node, const FunctionBlockPtr &block) {
ScopePtr scope = GetScopeForParseFunction(); ScopePtr scope = GetScopeForParseFunction();
// The node created in the parsefunction context, will inherit the scope created using scope_guard // The node created in the parsefunction context, will inherit the scope created using scope_guard
ScopeGuard scope_guard(scope); ScopeGuard scope_guard(scope);
@ -323,6 +335,33 @@ FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlo
return func_block; return func_block;
} }
FunctionBlockPtr Parser::ParseLambdaFunction(const py::object &node, const FunctionBlockPtr &block) {
MS_EXCEPTION_IF_NULL(ast_);
ScopePtr scope = GetScopeForParseFunction();
ScopeGuard scope_guard(scope);
TraceGuard trace_guard(data_converter::GetObjKey(ast_->obj())[0], GetLocation(node));
FunctionBlockPtr func_block = MakeFunctionBlock(*this);
if (block != nullptr) {
func_block->AddPrevBlock(block);
} else {
func_graph_ = func_block->func_graph();
}
func_block->Mature();
auto current_fg = func_block->func_graph();
auto function_name = ast_->function_name();
MS_LOG(DEBUG) << "The function name is " << function_name;
current_fg->debug_info()->set_name(function_name);
GenerateArgsNodeForFunction(func_block, node);
py::object body_node = python_adapter::GetPyObjAttr(node, "body");
AnfNodePtr lambda_body_node = ParseExprNode(func_block, body_node);
current_fg->set_output(lambda_body_node);
GenerateArgsDefaultValueForFunction(func_block, node);
return func_block;
}
FunctionBlockPtr Parser::ParseStatements(FunctionBlockPtr block, const py::object &nodes) { FunctionBlockPtr Parser::ParseStatements(FunctionBlockPtr block, const py::object &nodes) {
auto node_list = py::cast<py::list>(nodes); auto node_list = py::cast<py::list>(nodes);
size_t count = py::len(node_list); size_t count = py::len(node_list);
@ -909,7 +948,7 @@ AnfNodePtr Parser::ParseBoolOp(const FunctionBlockPtr &block, const py::object &
// Process a function def // Process a function def
FunctionBlockPtr Parser::ParseFunctionDef(const FunctionBlockPtr &block, const py::object &node) { FunctionBlockPtr Parser::ParseFunctionDef(const FunctionBlockPtr &block, const py::object &node) {
MS_LOG(DEBUG) << "Process ast FunctionDef"; MS_LOG(DEBUG) << "Process ast FunctionDef";
FunctionBlockPtr function_block = ParseFunction(node, block); FunctionBlockPtr function_block = ParseDefFunction(node, block);
MS_EXCEPTION_IF_NULL(function_block); MS_EXCEPTION_IF_NULL(function_block);
// Get function name // Get function name
@ -923,26 +962,10 @@ FunctionBlockPtr Parser::ParseFunctionDef(const FunctionBlockPtr &block, const p
// Process a lambda expression . like lambda x,y: x + y // Process a lambda expression . like lambda x,y: x + y
AnfNodePtr Parser::ParseLambda(const FunctionBlockPtr &block, const py::object &node) { AnfNodePtr Parser::ParseLambda(const FunctionBlockPtr &block, const py::object &node) {
MS_LOG(DEBUG) << "Process ast Lambda"; MS_LOG(DEBUG) << "Process ast Lambda";
FunctionBlockPtr func_block = MakeFunctionBlock(*this); FunctionBlockPtr function_block = ParseLambdaFunction(node, block);
func_block->AddPrevBlock(block); MS_EXCEPTION_IF_NULL(function_block);
func_block->Mature();
// Get lambda args auto block_fg = function_block->func_graph();
py::list args = ast_->GetArgs(node);
auto block_fg = func_block->func_graph();
for (std::size_t i = 0; i < args.size(); i++) {
std::string arg_name = py::cast<std::string>(args[i].attr("arg"));
TraceGuard guard(GetLocation(args[i]));
auto para_node = std::make_shared<Parameter>(block_fg);
para_node->debug_info()->set_name(arg_name);
block_fg->add_parameter(para_node);
func_block->WriteVariable(arg_name, para_node);
MS_LOG(DEBUG) << "The arg[" << i << "] is " << arg_name;
}
py::object body_node = python_adapter::GetPyObjAttr(node, "body");
AnfNodePtr lambda_body_node = ParseExprNode(func_block, body_node);
block_fg->set_output(lambda_body_node);
ValueNodePtr const_graph = NewValueNode(block_fg); ValueNodePtr const_graph = NewValueNode(block_fg);
return const_graph; return const_graph;
} }

View File

@ -196,7 +196,9 @@ class Parser {
// Generate argument default value for ast function node // Generate argument default value for ast function node
void GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block, const py::object &function_node); void GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block, const py::object &function_node);
// Parse ast function node // Parse ast function node
FunctionBlockPtr ParseFunction(const py::object &function_node, const FunctionBlockPtr &block = nullptr); FunctionBlockPtr ParseDefFunction(const py::object &function_node, const FunctionBlockPtr &block = nullptr);
// Parse lambda function node
FunctionBlockPtr ParseLambdaFunction(const py::object &function_node, const FunctionBlockPtr &block = nullptr);
// Parse ast statements // Parse ast statements
FunctionBlockPtr ParseStatements(FunctionBlockPtr block, const py::object &stmt_node); FunctionBlockPtr ParseStatements(FunctionBlockPtr block, const py::object &stmt_node);
// Parse one ast statement node // Parse one ast statement node

View File

@ -22,7 +22,14 @@ from mindspore import nn, Tensor, context
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE)
def test_partial_pos_arg(): def test_partial_pos_arg():
"""
Feature: ALL TO ALL
Description: test cases for partial_pos_arg
Expectation: the result match given one
"""
class Net(nn.Cell): class Net(nn.Cell):
def __init__(self): def __init__(self):
super(Net, self).__init__() super(Net, self).__init__()
@ -35,13 +42,31 @@ def test_partial_pos_arg():
ret = f(y, z) ret = f(y, z)
return ret return ret
class Net2(nn.Cell):
def __init__(self):
super(Net2, self).__init__()
self.show = lambda x, y, z: (x, y, z)
def construct(self, x, y, z):
f = partial(self.show, x)
ret = f(y, z)
return ret
x = Tensor(np.arange(3).reshape((3,)).astype(np.float32)) x = Tensor(np.arange(3).reshape((3,)).astype(np.float32))
y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32)) y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32))
z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32)) z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32))
net = Net()
net(x, y, z) for net in [Net(), Net2()]:
net(x, y, z)
def test_partial_key_ward_arg(): def test_partial_key_ward_arg():
"""
Feature: ALL TO ALL
Description: test cases for partial_key_ward_arg
Expectation: the result match given one
"""
class Net(nn.Cell): class Net(nn.Cell):
def __init__(self): def __init__(self):
super(Net, self).__init__() super(Net, self).__init__()
@ -54,13 +79,31 @@ def test_partial_key_ward_arg():
ret = f(y=y, z=z) ret = f(y=y, z=z)
return ret return ret
class Net2(nn.Cell):
def __init__(self):
super(Net2, self).__init__()
self.show = lambda x, y, z: (x, y, z)
def construct(self, x, y, z):
f = partial(self.show, x=x)
ret = f(y=y, z=z)
return ret
x = Tensor(np.arange(3).reshape((3,)).astype(np.float32)) x = Tensor(np.arange(3).reshape((3,)).astype(np.float32))
y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32)) y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32))
z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32)) z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32))
net = Net()
net(x, y, z) for net in [Net(), Net2()]:
net(x, y, z)
def test_partial_key_ward_arg_update(): def test_partial_key_ward_arg_update():
"""
Feature: ALL TO ALL
Description: test cases for partial_key_ward_arg_update
Expectation: the result match given one
"""
class Net(nn.Cell): class Net(nn.Cell):
def __init__(self): def __init__(self):
super(Net, self).__init__() super(Net, self).__init__()
@ -73,14 +116,31 @@ def test_partial_key_ward_arg_update():
ret = f(y=y, z=z) ret = f(y=y, z=z)
return ret return ret
class Net2(nn.Cell):
def __init__(self):
super(Net2, self).__init__()
self.show = lambda x, y, z: (x, y, z)
def construct(self, x, y, z):
f = partial(self.show, x=x, y=y)
ret = f(y=y, z=z)
return ret
x = Tensor(np.arange(3).reshape((3,)).astype(np.float32)) x = Tensor(np.arange(3).reshape((3,)).astype(np.float32))
y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32)) y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32))
z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32)) z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32))
net = Net()
net(x, y, z) for net in [Net(), Net2()]:
net(x, y, z)
def test_partial_key_ward_arg_and_pos_arg(): def test_partial_key_ward_arg_and_pos_arg():
"""
Feature: ALL TO ALL
Description: test cases for partial_key_ward_arg_and_pos_arg
Expectation: the result match given one
"""
class Net(nn.Cell): class Net(nn.Cell):
def __init__(self): def __init__(self):
super(Net, self).__init__() super(Net, self).__init__()
@ -93,14 +153,31 @@ def test_partial_key_ward_arg_and_pos_arg():
ret = f(2, z=z) ret = f(2, z=z)
return ret return ret
class Net2(nn.Cell):
def __init__(self):
super(Net2, self).__init__()
self.show = lambda x, y, z: (x, y, z)
def construct(self, x, y, z):
f = partial(self.show, y=y)
ret = f(2, z=z)
return ret
x = Tensor(np.arange(3).reshape((3,)).astype(np.float32)) x = Tensor(np.arange(3).reshape((3,)).astype(np.float32))
y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32)) y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32))
z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32)) z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32))
net = Net()
net(x, y, z) for net in [Net(), Net2()]:
net(x, y, z)
def test_partial_pos_arg_const(): def test_partial_pos_arg_const():
"""
Feature: ALL TO ALL
Description: test cases for partial_pos_arg_const
Expectation: the result match given one
"""
class Net(nn.Cell): class Net(nn.Cell):
def __init__(self): def __init__(self):
super(Net, self).__init__() super(Net, self).__init__()
@ -113,10 +190,27 @@ def test_partial_pos_arg_const():
ret = f(2, 3) ret = f(2, 3)
return ret return ret
net = Net() class Net2(nn.Cell):
assert net() == (1, 2, 3) def __init__(self):
super(Net2, self).__init__()
self.show = lambda x, y, z: (x, y, z)
def construct(self):
f = partial(self.show, 1)
ret = f(2, 3)
return ret
for net in [Net(), Net2()]:
assert net() == (1, 2, 3)
def test_partial_key_ward_arg_const(): def test_partial_key_ward_arg_const():
"""
Feature: ALL TO ALL
Description: test cases for partial_key_ward_arg_const
Expectation: the result match given one
"""
class Net(nn.Cell): class Net(nn.Cell):
def __init__(self): def __init__(self):
super(Net, self).__init__() super(Net, self).__init__()
@ -129,10 +223,27 @@ def test_partial_key_ward_arg_const():
ret = f(y=2, z=3) ret = f(y=2, z=3)
return ret return ret
net = Net() class Net2(nn.Cell):
assert net() == (1, 2, 3) def __init__(self):
super(Net2, self).__init__()
self.show = lambda x, y, z: (x, y, z)
def construct(self):
f = partial(self.show, x=1)
ret = f(y=2, z=3)
return ret
for net in [Net(), Net2()]:
assert net() == (1, 2, 3)
def test_partial_key_ward_arg_update_const(): def test_partial_key_ward_arg_update_const():
"""
Feature: ALL TO ALL
Description: test cases for partial_key_ward_arg_update_const
Expectation: the result match given one
"""
class Net(nn.Cell): class Net(nn.Cell):
def __init__(self): def __init__(self):
super(Net, self).__init__() super(Net, self).__init__()
@ -145,11 +256,27 @@ def test_partial_key_ward_arg_update_const():
ret = f(y=3, z=4) ret = f(y=3, z=4)
return ret return ret
net = Net() class Net2(nn.Cell):
assert net() == (1, 3, 4) def __init__(self):
super(Net2, self).__init__()
self.show = lambda x, y, z: (x, y, z)
def construct(self):
f = partial(self.show, x=1, y=2)
ret = f(y=3, z=4)
return ret
for net in [Net(), Net2()]:
assert net() == (1, 3, 4)
def test_partial_key_ward_arg_and_pos_arg_const(): def test_partial_key_ward_arg_and_pos_arg_const():
"""
Feature: ALL TO ALL
Description: test cases for partial_key_ward_arg_and_pos_arg_const
Expectation: the result match given one
"""
class Net(nn.Cell): class Net(nn.Cell):
def __init__(self): def __init__(self):
super(Net, self).__init__() super(Net, self).__init__()
@ -162,11 +289,27 @@ def test_partial_key_ward_arg_and_pos_arg_const():
ret = f(1, z=3) ret = f(1, z=3)
return ret return ret
net = Net() class Net2(nn.Cell):
assert net() == (1, 2, 3) def __init__(self):
super(Net2, self).__init__()
self.show = lambda x, y, z: (x, y, z)
def construct(self):
f = partial(self.show, y=2)
ret = f(1, z=3)
return ret
for net in [Net(), Net2()]:
assert net() == (1, 2, 3)
def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_x(): def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_x():
"""
Feature: ALL TO ALL
Description: test cases for partial_key_ward_arg_and_pos_arg_const_multi_assign_x
Expectation: the result match given one
"""
class Net(nn.Cell): class Net(nn.Cell):
def __init__(self): def __init__(self):
super(Net, self).__init__() super(Net, self).__init__()
@ -179,13 +322,29 @@ def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_x():
ret = f(1, 2, 3) ret = f(1, 2, 3)
return ret return ret
net = Net() class Net2(nn.Cell):
with pytest.raises(TypeError) as ex: def __init__(self):
net() super(Net2, self).__init__()
assert "Multiply values for specific argument: x" in str(ex.value) self.show = lambda x, y, z: (x, y, z)
def construct(self):
f = partial(self.show, x=1)
ret = f(1, 2, 3)
return ret
for net in [Net(), Net2()]:
with pytest.raises(TypeError) as ex:
net()
assert "Multiply values for specific argument: x" in str(ex.value)
def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_y(): def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_y():
"""
Feature: ALL TO ALL
Description: test cases for partial_key_ward_arg_and_pos_arg_const_multi_assign_y
Expectation: the result match given one
"""
class Net(nn.Cell): class Net(nn.Cell):
def __init__(self): def __init__(self):
super(Net, self).__init__() super(Net, self).__init__()
@ -198,13 +357,29 @@ def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_y():
ret = f(1, 2, z=3) ret = f(1, 2, z=3)
return ret return ret
net = Net() class Net2(nn.Cell):
with pytest.raises(TypeError) as ex: def __init__(self):
net() super(Net2, self).__init__()
assert "Multiply values for specific argument: y" in str(ex.value) self.show = lambda x, y, z: (x, y, z)
def construct(self):
f = partial(self.show, y=2)
ret = f(1, 2, z=3)
return ret
for net in [Net(), Net2()]:
with pytest.raises(TypeError) as ex:
net()
assert "Multiply values for specific argument: y" in str(ex.value)
def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_z(): def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_z():
"""
Feature: ALL TO ALL
Description: test cases for partial_key_ward_arg_and_pos_arg_const_multi_assign_z
Expectation: the result match given one
"""
class Net(nn.Cell): class Net(nn.Cell):
def __init__(self): def __init__(self):
super(Net, self).__init__() super(Net, self).__init__()
@ -217,7 +392,17 @@ def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_z():
ret = f(1, 2, 3) ret = f(1, 2, 3)
return ret return ret
net = Net() class Net2(nn.Cell):
with pytest.raises(TypeError) as ex: def __init__(self):
net() super(Net2, self).__init__()
assert "Multiply values for specific argument: z" in str(ex.value) self.show = lambda x, y, z: (x, y, z)
def construct(self):
f = partial(self.show, z=1)
ret = f(1, 2, 3)
return ret
for net in [Net(), Net2()]:
with pytest.raises(TypeError) as ex:
net()
assert "Multiply values for specific argument: z" in str(ex.value)