forked from mindspore-Ecosystem/mindspore
!26690 Add supoort resolving outer lambda function for ops.Partial.
Merge pull request !26690 from hezhenhao1/add_lambda
This commit is contained in:
commit
d1e4e674ab
|
@ -175,7 +175,19 @@ void CheckFuncReturn(const FuncGraphPtr &fn, const std::shared_ptr<ParseFunction
|
|||
FuncGraphPtr Parser::ParseFuncGraph() {
|
||||
// Get ast FunctionDef node
|
||||
py::object node = ast_->GetAstNode();
|
||||
FunctionBlockPtr fn_block = ParseFunction(node);
|
||||
constexpr char function_def_name[] = "FunctionDef";
|
||||
constexpr char lambda_name[] = "Lambda";
|
||||
FunctionBlockPtr fn_block = nullptr;
|
||||
if (ast_->GetNodeType(node)->node_name() == function_def_name) {
|
||||
fn_block = ParseDefFunction(node);
|
||||
} else {
|
||||
auto lambda_node = python_adapter::GetPyObjAttr(node, "value");
|
||||
if (py::isinstance<py::none>(lambda_node) || ast_->GetNodeType(lambda_node)->node_name() != lambda_name) {
|
||||
MS_EXCEPTION(TypeError) << "Parse Lambda Function Fail. Node type must be Lambda, but got "
|
||||
<< ast_->GetNodeType(lambda_node)->node_name() << ".";
|
||||
}
|
||||
fn_block = ParseLambdaFunction(lambda_node);
|
||||
}
|
||||
if (errcode() != PARSE_SUCCESS) {
|
||||
MS_LOG(ERROR) << "Parse function error, code is " << errcode();
|
||||
return nullptr;
|
||||
|
@ -259,7 +271,7 @@ ScopePtr Parser::GetScopeForParseFunction() {
|
|||
return scope;
|
||||
}
|
||||
|
||||
FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlockPtr &block) {
|
||||
FunctionBlockPtr Parser::ParseDefFunction(const py::object &node, const FunctionBlockPtr &block) {
|
||||
ScopePtr scope = GetScopeForParseFunction();
|
||||
// The node created in the parsefunction context, will inherit the scope created using scope_guard
|
||||
ScopeGuard scope_guard(scope);
|
||||
|
@ -323,6 +335,33 @@ FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlo
|
|||
return func_block;
|
||||
}
|
||||
|
||||
FunctionBlockPtr Parser::ParseLambdaFunction(const py::object &node, const FunctionBlockPtr &block) {
|
||||
MS_EXCEPTION_IF_NULL(ast_);
|
||||
ScopePtr scope = GetScopeForParseFunction();
|
||||
ScopeGuard scope_guard(scope);
|
||||
TraceGuard trace_guard(data_converter::GetObjKey(ast_->obj())[0], GetLocation(node));
|
||||
|
||||
FunctionBlockPtr func_block = MakeFunctionBlock(*this);
|
||||
if (block != nullptr) {
|
||||
func_block->AddPrevBlock(block);
|
||||
} else {
|
||||
func_graph_ = func_block->func_graph();
|
||||
}
|
||||
func_block->Mature();
|
||||
auto current_fg = func_block->func_graph();
|
||||
|
||||
auto function_name = ast_->function_name();
|
||||
MS_LOG(DEBUG) << "The function name is " << function_name;
|
||||
current_fg->debug_info()->set_name(function_name);
|
||||
GenerateArgsNodeForFunction(func_block, node);
|
||||
|
||||
py::object body_node = python_adapter::GetPyObjAttr(node, "body");
|
||||
AnfNodePtr lambda_body_node = ParseExprNode(func_block, body_node);
|
||||
current_fg->set_output(lambda_body_node);
|
||||
GenerateArgsDefaultValueForFunction(func_block, node);
|
||||
return func_block;
|
||||
}
|
||||
|
||||
FunctionBlockPtr Parser::ParseStatements(FunctionBlockPtr block, const py::object &nodes) {
|
||||
auto node_list = py::cast<py::list>(nodes);
|
||||
size_t count = py::len(node_list);
|
||||
|
@ -909,7 +948,7 @@ AnfNodePtr Parser::ParseBoolOp(const FunctionBlockPtr &block, const py::object &
|
|||
// Process a function def
|
||||
FunctionBlockPtr Parser::ParseFunctionDef(const FunctionBlockPtr &block, const py::object &node) {
|
||||
MS_LOG(DEBUG) << "Process ast FunctionDef";
|
||||
FunctionBlockPtr function_block = ParseFunction(node, block);
|
||||
FunctionBlockPtr function_block = ParseDefFunction(node, block);
|
||||
MS_EXCEPTION_IF_NULL(function_block);
|
||||
|
||||
// Get function name
|
||||
|
@ -923,26 +962,10 @@ FunctionBlockPtr Parser::ParseFunctionDef(const FunctionBlockPtr &block, const p
|
|||
// Process a lambda expression . like lambda x,y: x + y
|
||||
AnfNodePtr Parser::ParseLambda(const FunctionBlockPtr &block, const py::object &node) {
|
||||
MS_LOG(DEBUG) << "Process ast Lambda";
|
||||
FunctionBlockPtr func_block = MakeFunctionBlock(*this);
|
||||
func_block->AddPrevBlock(block);
|
||||
func_block->Mature();
|
||||
FunctionBlockPtr function_block = ParseLambdaFunction(node, block);
|
||||
MS_EXCEPTION_IF_NULL(function_block);
|
||||
|
||||
// Get lambda args
|
||||
py::list args = ast_->GetArgs(node);
|
||||
auto block_fg = func_block->func_graph();
|
||||
for (std::size_t i = 0; i < args.size(); i++) {
|
||||
std::string arg_name = py::cast<std::string>(args[i].attr("arg"));
|
||||
TraceGuard guard(GetLocation(args[i]));
|
||||
auto para_node = std::make_shared<Parameter>(block_fg);
|
||||
para_node->debug_info()->set_name(arg_name);
|
||||
block_fg->add_parameter(para_node);
|
||||
func_block->WriteVariable(arg_name, para_node);
|
||||
MS_LOG(DEBUG) << "The arg[" << i << "] is " << arg_name;
|
||||
}
|
||||
|
||||
py::object body_node = python_adapter::GetPyObjAttr(node, "body");
|
||||
AnfNodePtr lambda_body_node = ParseExprNode(func_block, body_node);
|
||||
block_fg->set_output(lambda_body_node);
|
||||
auto block_fg = function_block->func_graph();
|
||||
ValueNodePtr const_graph = NewValueNode(block_fg);
|
||||
return const_graph;
|
||||
}
|
||||
|
|
|
@ -196,7 +196,9 @@ class Parser {
|
|||
// Generate argument default value for ast function node
|
||||
void GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block, const py::object &function_node);
|
||||
// Parse ast function node
|
||||
FunctionBlockPtr ParseFunction(const py::object &function_node, const FunctionBlockPtr &block = nullptr);
|
||||
FunctionBlockPtr ParseDefFunction(const py::object &function_node, const FunctionBlockPtr &block = nullptr);
|
||||
// Parse lambda function node
|
||||
FunctionBlockPtr ParseLambdaFunction(const py::object &function_node, const FunctionBlockPtr &block = nullptr);
|
||||
// Parse ast statements
|
||||
FunctionBlockPtr ParseStatements(FunctionBlockPtr block, const py::object &stmt_node);
|
||||
// Parse one ast statement node
|
||||
|
|
|
@ -22,7 +22,14 @@ from mindspore import nn, Tensor, context
|
|||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
def test_partial_pos_arg():
|
||||
"""
|
||||
Feature: ALL TO ALL
|
||||
Description: test cases for partial_pos_arg
|
||||
Expectation: the result match given one
|
||||
"""
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
@ -35,13 +42,31 @@ def test_partial_pos_arg():
|
|||
ret = f(y, z)
|
||||
return ret
|
||||
|
||||
class Net2(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net2, self).__init__()
|
||||
self.show = lambda x, y, z: (x, y, z)
|
||||
|
||||
def construct(self, x, y, z):
|
||||
f = partial(self.show, x)
|
||||
ret = f(y, z)
|
||||
return ret
|
||||
|
||||
x = Tensor(np.arange(3).reshape((3,)).astype(np.float32))
|
||||
y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32))
|
||||
z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32))
|
||||
net = Net()
|
||||
net(x, y, z)
|
||||
|
||||
for net in [Net(), Net2()]:
|
||||
net(x, y, z)
|
||||
|
||||
|
||||
def test_partial_key_ward_arg():
|
||||
"""
|
||||
Feature: ALL TO ALL
|
||||
Description: test cases for partial_key_ward_arg
|
||||
Expectation: the result match given one
|
||||
"""
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
@ -54,13 +79,31 @@ def test_partial_key_ward_arg():
|
|||
ret = f(y=y, z=z)
|
||||
return ret
|
||||
|
||||
class Net2(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net2, self).__init__()
|
||||
self.show = lambda x, y, z: (x, y, z)
|
||||
|
||||
def construct(self, x, y, z):
|
||||
f = partial(self.show, x=x)
|
||||
ret = f(y=y, z=z)
|
||||
return ret
|
||||
|
||||
x = Tensor(np.arange(3).reshape((3,)).astype(np.float32))
|
||||
y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32))
|
||||
z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32))
|
||||
net = Net()
|
||||
net(x, y, z)
|
||||
|
||||
for net in [Net(), Net2()]:
|
||||
net(x, y, z)
|
||||
|
||||
|
||||
def test_partial_key_ward_arg_update():
|
||||
"""
|
||||
Feature: ALL TO ALL
|
||||
Description: test cases for partial_key_ward_arg_update
|
||||
Expectation: the result match given one
|
||||
"""
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
@ -73,14 +116,31 @@ def test_partial_key_ward_arg_update():
|
|||
ret = f(y=y, z=z)
|
||||
return ret
|
||||
|
||||
class Net2(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net2, self).__init__()
|
||||
self.show = lambda x, y, z: (x, y, z)
|
||||
|
||||
def construct(self, x, y, z):
|
||||
f = partial(self.show, x=x, y=y)
|
||||
ret = f(y=y, z=z)
|
||||
return ret
|
||||
|
||||
x = Tensor(np.arange(3).reshape((3,)).astype(np.float32))
|
||||
y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32))
|
||||
z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32))
|
||||
net = Net()
|
||||
net(x, y, z)
|
||||
|
||||
for net in [Net(), Net2()]:
|
||||
net(x, y, z)
|
||||
|
||||
|
||||
def test_partial_key_ward_arg_and_pos_arg():
|
||||
"""
|
||||
Feature: ALL TO ALL
|
||||
Description: test cases for partial_key_ward_arg_and_pos_arg
|
||||
Expectation: the result match given one
|
||||
"""
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
@ -93,14 +153,31 @@ def test_partial_key_ward_arg_and_pos_arg():
|
|||
ret = f(2, z=z)
|
||||
return ret
|
||||
|
||||
class Net2(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net2, self).__init__()
|
||||
self.show = lambda x, y, z: (x, y, z)
|
||||
|
||||
def construct(self, x, y, z):
|
||||
f = partial(self.show, y=y)
|
||||
ret = f(2, z=z)
|
||||
return ret
|
||||
|
||||
x = Tensor(np.arange(3).reshape((3,)).astype(np.float32))
|
||||
y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32))
|
||||
z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32))
|
||||
net = Net()
|
||||
net(x, y, z)
|
||||
|
||||
for net in [Net(), Net2()]:
|
||||
net(x, y, z)
|
||||
|
||||
|
||||
def test_partial_pos_arg_const():
|
||||
"""
|
||||
Feature: ALL TO ALL
|
||||
Description: test cases for partial_pos_arg_const
|
||||
Expectation: the result match given one
|
||||
"""
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
@ -113,10 +190,27 @@ def test_partial_pos_arg_const():
|
|||
ret = f(2, 3)
|
||||
return ret
|
||||
|
||||
net = Net()
|
||||
assert net() == (1, 2, 3)
|
||||
class Net2(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net2, self).__init__()
|
||||
self.show = lambda x, y, z: (x, y, z)
|
||||
|
||||
def construct(self):
|
||||
f = partial(self.show, 1)
|
||||
ret = f(2, 3)
|
||||
return ret
|
||||
|
||||
for net in [Net(), Net2()]:
|
||||
assert net() == (1, 2, 3)
|
||||
|
||||
|
||||
def test_partial_key_ward_arg_const():
|
||||
"""
|
||||
Feature: ALL TO ALL
|
||||
Description: test cases for partial_key_ward_arg_const
|
||||
Expectation: the result match given one
|
||||
"""
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
@ -129,10 +223,27 @@ def test_partial_key_ward_arg_const():
|
|||
ret = f(y=2, z=3)
|
||||
return ret
|
||||
|
||||
net = Net()
|
||||
assert net() == (1, 2, 3)
|
||||
class Net2(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net2, self).__init__()
|
||||
self.show = lambda x, y, z: (x, y, z)
|
||||
|
||||
def construct(self):
|
||||
f = partial(self.show, x=1)
|
||||
ret = f(y=2, z=3)
|
||||
return ret
|
||||
|
||||
for net in [Net(), Net2()]:
|
||||
assert net() == (1, 2, 3)
|
||||
|
||||
|
||||
def test_partial_key_ward_arg_update_const():
|
||||
"""
|
||||
Feature: ALL TO ALL
|
||||
Description: test cases for partial_key_ward_arg_update_const
|
||||
Expectation: the result match given one
|
||||
"""
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
@ -145,11 +256,27 @@ def test_partial_key_ward_arg_update_const():
|
|||
ret = f(y=3, z=4)
|
||||
return ret
|
||||
|
||||
net = Net()
|
||||
assert net() == (1, 3, 4)
|
||||
class Net2(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net2, self).__init__()
|
||||
self.show = lambda x, y, z: (x, y, z)
|
||||
|
||||
def construct(self):
|
||||
f = partial(self.show, x=1, y=2)
|
||||
ret = f(y=3, z=4)
|
||||
return ret
|
||||
|
||||
for net in [Net(), Net2()]:
|
||||
assert net() == (1, 3, 4)
|
||||
|
||||
|
||||
def test_partial_key_ward_arg_and_pos_arg_const():
|
||||
"""
|
||||
Feature: ALL TO ALL
|
||||
Description: test cases for partial_key_ward_arg_and_pos_arg_const
|
||||
Expectation: the result match given one
|
||||
"""
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
@ -162,11 +289,27 @@ def test_partial_key_ward_arg_and_pos_arg_const():
|
|||
ret = f(1, z=3)
|
||||
return ret
|
||||
|
||||
net = Net()
|
||||
assert net() == (1, 2, 3)
|
||||
class Net2(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net2, self).__init__()
|
||||
self.show = lambda x, y, z: (x, y, z)
|
||||
|
||||
def construct(self):
|
||||
f = partial(self.show, y=2)
|
||||
ret = f(1, z=3)
|
||||
return ret
|
||||
|
||||
for net in [Net(), Net2()]:
|
||||
assert net() == (1, 2, 3)
|
||||
|
||||
|
||||
def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_x():
|
||||
"""
|
||||
Feature: ALL TO ALL
|
||||
Description: test cases for partial_key_ward_arg_and_pos_arg_const_multi_assign_x
|
||||
Expectation: the result match given one
|
||||
"""
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
@ -179,13 +322,29 @@ def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_x():
|
|||
ret = f(1, 2, 3)
|
||||
return ret
|
||||
|
||||
net = Net()
|
||||
with pytest.raises(TypeError) as ex:
|
||||
net()
|
||||
assert "Multiply values for specific argument: x" in str(ex.value)
|
||||
class Net2(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net2, self).__init__()
|
||||
self.show = lambda x, y, z: (x, y, z)
|
||||
|
||||
def construct(self):
|
||||
f = partial(self.show, x=1)
|
||||
ret = f(1, 2, 3)
|
||||
return ret
|
||||
|
||||
for net in [Net(), Net2()]:
|
||||
with pytest.raises(TypeError) as ex:
|
||||
net()
|
||||
assert "Multiply values for specific argument: x" in str(ex.value)
|
||||
|
||||
|
||||
def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_y():
|
||||
"""
|
||||
Feature: ALL TO ALL
|
||||
Description: test cases for partial_key_ward_arg_and_pos_arg_const_multi_assign_y
|
||||
Expectation: the result match given one
|
||||
"""
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
@ -198,13 +357,29 @@ def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_y():
|
|||
ret = f(1, 2, z=3)
|
||||
return ret
|
||||
|
||||
net = Net()
|
||||
with pytest.raises(TypeError) as ex:
|
||||
net()
|
||||
assert "Multiply values for specific argument: y" in str(ex.value)
|
||||
class Net2(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net2, self).__init__()
|
||||
self.show = lambda x, y, z: (x, y, z)
|
||||
|
||||
def construct(self):
|
||||
f = partial(self.show, y=2)
|
||||
ret = f(1, 2, z=3)
|
||||
return ret
|
||||
|
||||
for net in [Net(), Net2()]:
|
||||
with pytest.raises(TypeError) as ex:
|
||||
net()
|
||||
assert "Multiply values for specific argument: y" in str(ex.value)
|
||||
|
||||
|
||||
def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_z():
|
||||
"""
|
||||
Feature: ALL TO ALL
|
||||
Description: test cases for partial_key_ward_arg_and_pos_arg_const_multi_assign_z
|
||||
Expectation: the result match given one
|
||||
"""
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
@ -217,7 +392,17 @@ def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_z():
|
|||
ret = f(1, 2, 3)
|
||||
return ret
|
||||
|
||||
net = Net()
|
||||
with pytest.raises(TypeError) as ex:
|
||||
net()
|
||||
assert "Multiply values for specific argument: z" in str(ex.value)
|
||||
class Net2(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net2, self).__init__()
|
||||
self.show = lambda x, y, z: (x, y, z)
|
||||
|
||||
def construct(self):
|
||||
f = partial(self.show, z=1)
|
||||
ret = f(1, 2, 3)
|
||||
return ret
|
||||
|
||||
for net in [Net(), Net2()]:
|
||||
with pytest.raises(TypeError) as ex:
|
||||
net()
|
||||
assert "Multiply values for specific argument: z" in str(ex.value)
|
||||
|
|
Loading…
Reference in New Issue