From 4af312d17e11859a3f1cd9f6d5b6047fa0d2ab83 Mon Sep 17 00:00:00 2001 From: hezhenhao1 Date: Tue, 23 Nov 2021 20:01:50 +0800 Subject: [PATCH] Add supoort resolving outer lambda function for ops.Partial. --- mindspore/ccsrc/pipeline/jit/parse/parse.cc | 67 +++-- mindspore/ccsrc/pipeline/jit/parse/parse.h | 4 +- .../ut/python/pipeline/parse/test_partial.py | 241 ++++++++++++++++-- 3 files changed, 261 insertions(+), 51 deletions(-) diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse.cc b/mindspore/ccsrc/pipeline/jit/parse/parse.cc index a777a98440e..e37d3630c2b 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/parse.cc @@ -175,7 +175,19 @@ void CheckFuncReturn(const FuncGraphPtr &fn, const std::shared_ptrGetAstNode(); - FunctionBlockPtr fn_block = ParseFunction(node); + constexpr char function_def_name[] = "FunctionDef"; + constexpr char lambda_name[] = "Lambda"; + FunctionBlockPtr fn_block = nullptr; + if (ast_->GetNodeType(node)->node_name() == function_def_name) { + fn_block = ParseDefFunction(node); + } else { + auto lambda_node = python_adapter::GetPyObjAttr(node, "value"); + if (py::isinstance(lambda_node) || ast_->GetNodeType(lambda_node)->node_name() != lambda_name) { + MS_EXCEPTION(TypeError) << "Parse Lambda Function Fail. Node type must be Lambda, but got " + << ast_->GetNodeType(lambda_node)->node_name() << "."; + } + fn_block = ParseLambdaFunction(lambda_node); + } if (errcode() != PARSE_SUCCESS) { MS_LOG(ERROR) << "Parse function error, code is " << errcode(); return nullptr; @@ -259,7 +271,7 @@ ScopePtr Parser::GetScopeForParseFunction() { return scope; } -FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlockPtr &block) { +FunctionBlockPtr Parser::ParseDefFunction(const py::object &node, const FunctionBlockPtr &block) { ScopePtr scope = GetScopeForParseFunction(); // The node created in the parsefunction context, will inherit the scope created using scope_guard ScopeGuard scope_guard(scope); @@ -323,6 +335,33 @@ FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlo return func_block; } +FunctionBlockPtr Parser::ParseLambdaFunction(const py::object &node, const FunctionBlockPtr &block) { + MS_EXCEPTION_IF_NULL(ast_); + ScopePtr scope = GetScopeForParseFunction(); + ScopeGuard scope_guard(scope); + TraceGuard trace_guard(data_converter::GetObjKey(ast_->obj())[0], GetLocation(node)); + + FunctionBlockPtr func_block = MakeFunctionBlock(*this); + if (block != nullptr) { + func_block->AddPrevBlock(block); + } else { + func_graph_ = func_block->func_graph(); + } + func_block->Mature(); + auto current_fg = func_block->func_graph(); + + auto function_name = ast_->function_name(); + MS_LOG(DEBUG) << "The function name is " << function_name; + current_fg->debug_info()->set_name(function_name); + GenerateArgsNodeForFunction(func_block, node); + + py::object body_node = python_adapter::GetPyObjAttr(node, "body"); + AnfNodePtr lambda_body_node = ParseExprNode(func_block, body_node); + current_fg->set_output(lambda_body_node); + GenerateArgsDefaultValueForFunction(func_block, node); + return func_block; +} + FunctionBlockPtr Parser::ParseStatements(FunctionBlockPtr block, const py::object &nodes) { auto node_list = py::cast(nodes); size_t count = py::len(node_list); @@ -919,7 +958,7 @@ AnfNodePtr Parser::ParseBoolOp(const FunctionBlockPtr &block, const py::object & // Process a function def FunctionBlockPtr Parser::ParseFunctionDef(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast FunctionDef"; - FunctionBlockPtr function_block = ParseFunction(node, block); + FunctionBlockPtr function_block = ParseDefFunction(node, block); MS_EXCEPTION_IF_NULL(function_block); // Get function name @@ -933,26 +972,10 @@ FunctionBlockPtr Parser::ParseFunctionDef(const FunctionBlockPtr &block, const p // Process a lambda expression . like lambda x,y: x + y AnfNodePtr Parser::ParseLambda(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast Lambda"; - FunctionBlockPtr func_block = MakeFunctionBlock(*this); - func_block->AddPrevBlock(block); - func_block->Mature(); + FunctionBlockPtr function_block = ParseLambdaFunction(node, block); + MS_EXCEPTION_IF_NULL(function_block); - // Get lambda args - py::list args = ast_->GetArgs(node); - auto block_fg = func_block->func_graph(); - for (std::size_t i = 0; i < args.size(); i++) { - std::string arg_name = py::cast(args[i].attr("arg")); - TraceGuard guard(GetLocation(args[i])); - auto para_node = std::make_shared(block_fg); - para_node->debug_info()->set_name(arg_name); - block_fg->add_parameter(para_node); - func_block->WriteVariable(arg_name, para_node); - MS_LOG(DEBUG) << "The arg[" << i << "] is " << arg_name; - } - - py::object body_node = python_adapter::GetPyObjAttr(node, "body"); - AnfNodePtr lambda_body_node = ParseExprNode(func_block, body_node); - block_fg->set_output(lambda_body_node); + auto block_fg = function_block->func_graph(); ValueNodePtr const_graph = NewValueNode(block_fg); return const_graph; } diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse.h b/mindspore/ccsrc/pipeline/jit/parse/parse.h index 54617b1fdb9..dd648554880 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse.h +++ b/mindspore/ccsrc/pipeline/jit/parse/parse.h @@ -196,7 +196,9 @@ class Parser { // Generate argument default value for ast function node void GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block, const py::object &function_node); // Parse ast function node - FunctionBlockPtr ParseFunction(const py::object &function_node, const FunctionBlockPtr &block = nullptr); + FunctionBlockPtr ParseDefFunction(const py::object &function_node, const FunctionBlockPtr &block = nullptr); + // Parse lambda function node + FunctionBlockPtr ParseLambdaFunction(const py::object &function_node, const FunctionBlockPtr &block = nullptr); // Parse ast statements FunctionBlockPtr ParseStatements(FunctionBlockPtr block, const py::object &stmt_node); // Parse one ast statement node diff --git a/tests/ut/python/pipeline/parse/test_partial.py b/tests/ut/python/pipeline/parse/test_partial.py index ac7d50d6585..11c1bdf4cb5 100644 --- a/tests/ut/python/pipeline/parse/test_partial.py +++ b/tests/ut/python/pipeline/parse/test_partial.py @@ -22,7 +22,14 @@ from mindspore import nn, Tensor, context context.set_context(mode=context.GRAPH_MODE) + def test_partial_pos_arg(): + """ + Feature: ALL TO ALL + Description: test cases for partial_pos_arg + Expectation: the result match given one + """ + class Net(nn.Cell): def __init__(self): super(Net, self).__init__() @@ -35,13 +42,31 @@ def test_partial_pos_arg(): ret = f(y, z) return ret + class Net2(nn.Cell): + def __init__(self): + super(Net2, self).__init__() + self.show = lambda x, y, z: (x, y, z) + + def construct(self, x, y, z): + f = partial(self.show, x) + ret = f(y, z) + return ret + x = Tensor(np.arange(3).reshape((3,)).astype(np.float32)) y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32)) z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32)) - net = Net() - net(x, y, z) + + for net in [Net(), Net2()]: + net(x, y, z) + def test_partial_key_ward_arg(): + """ + Feature: ALL TO ALL + Description: test cases for partial_key_ward_arg + Expectation: the result match given one + """ + class Net(nn.Cell): def __init__(self): super(Net, self).__init__() @@ -54,13 +79,31 @@ def test_partial_key_ward_arg(): ret = f(y=y, z=z) return ret + class Net2(nn.Cell): + def __init__(self): + super(Net2, self).__init__() + self.show = lambda x, y, z: (x, y, z) + + def construct(self, x, y, z): + f = partial(self.show, x=x) + ret = f(y=y, z=z) + return ret + x = Tensor(np.arange(3).reshape((3,)).astype(np.float32)) y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32)) z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32)) - net = Net() - net(x, y, z) + + for net in [Net(), Net2()]: + net(x, y, z) + def test_partial_key_ward_arg_update(): + """ + Feature: ALL TO ALL + Description: test cases for partial_key_ward_arg_update + Expectation: the result match given one + """ + class Net(nn.Cell): def __init__(self): super(Net, self).__init__() @@ -73,14 +116,31 @@ def test_partial_key_ward_arg_update(): ret = f(y=y, z=z) return ret + class Net2(nn.Cell): + def __init__(self): + super(Net2, self).__init__() + self.show = lambda x, y, z: (x, y, z) + + def construct(self, x, y, z): + f = partial(self.show, x=x, y=y) + ret = f(y=y, z=z) + return ret + x = Tensor(np.arange(3).reshape((3,)).astype(np.float32)) y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32)) z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32)) - net = Net() - net(x, y, z) + + for net in [Net(), Net2()]: + net(x, y, z) def test_partial_key_ward_arg_and_pos_arg(): + """ + Feature: ALL TO ALL + Description: test cases for partial_key_ward_arg_and_pos_arg + Expectation: the result match given one + """ + class Net(nn.Cell): def __init__(self): super(Net, self).__init__() @@ -93,14 +153,31 @@ def test_partial_key_ward_arg_and_pos_arg(): ret = f(2, z=z) return ret + class Net2(nn.Cell): + def __init__(self): + super(Net2, self).__init__() + self.show = lambda x, y, z: (x, y, z) + + def construct(self, x, y, z): + f = partial(self.show, y=y) + ret = f(2, z=z) + return ret + x = Tensor(np.arange(3).reshape((3,)).astype(np.float32)) y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32)) z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32)) - net = Net() - net(x, y, z) + + for net in [Net(), Net2()]: + net(x, y, z) def test_partial_pos_arg_const(): + """ + Feature: ALL TO ALL + Description: test cases for partial_pos_arg_const + Expectation: the result match given one + """ + class Net(nn.Cell): def __init__(self): super(Net, self).__init__() @@ -113,10 +190,27 @@ def test_partial_pos_arg_const(): ret = f(2, 3) return ret - net = Net() - assert net() == (1, 2, 3) + class Net2(nn.Cell): + def __init__(self): + super(Net2, self).__init__() + self.show = lambda x, y, z: (x, y, z) + + def construct(self): + f = partial(self.show, 1) + ret = f(2, 3) + return ret + + for net in [Net(), Net2()]: + assert net() == (1, 2, 3) + def test_partial_key_ward_arg_const(): + """ + Feature: ALL TO ALL + Description: test cases for partial_key_ward_arg_const + Expectation: the result match given one + """ + class Net(nn.Cell): def __init__(self): super(Net, self).__init__() @@ -129,10 +223,27 @@ def test_partial_key_ward_arg_const(): ret = f(y=2, z=3) return ret - net = Net() - assert net() == (1, 2, 3) + class Net2(nn.Cell): + def __init__(self): + super(Net2, self).__init__() + self.show = lambda x, y, z: (x, y, z) + + def construct(self): + f = partial(self.show, x=1) + ret = f(y=2, z=3) + return ret + + for net in [Net(), Net2()]: + assert net() == (1, 2, 3) + def test_partial_key_ward_arg_update_const(): + """ + Feature: ALL TO ALL + Description: test cases for partial_key_ward_arg_update_const + Expectation: the result match given one + """ + class Net(nn.Cell): def __init__(self): super(Net, self).__init__() @@ -145,11 +256,27 @@ def test_partial_key_ward_arg_update_const(): ret = f(y=3, z=4) return ret - net = Net() - assert net() == (1, 3, 4) + class Net2(nn.Cell): + def __init__(self): + super(Net2, self).__init__() + self.show = lambda x, y, z: (x, y, z) + + def construct(self): + f = partial(self.show, x=1, y=2) + ret = f(y=3, z=4) + return ret + + for net in [Net(), Net2()]: + assert net() == (1, 3, 4) def test_partial_key_ward_arg_and_pos_arg_const(): + """ + Feature: ALL TO ALL + Description: test cases for partial_key_ward_arg_and_pos_arg_const + Expectation: the result match given one + """ + class Net(nn.Cell): def __init__(self): super(Net, self).__init__() @@ -162,11 +289,27 @@ def test_partial_key_ward_arg_and_pos_arg_const(): ret = f(1, z=3) return ret - net = Net() - assert net() == (1, 2, 3) + class Net2(nn.Cell): + def __init__(self): + super(Net2, self).__init__() + self.show = lambda x, y, z: (x, y, z) + + def construct(self): + f = partial(self.show, y=2) + ret = f(1, z=3) + return ret + + for net in [Net(), Net2()]: + assert net() == (1, 2, 3) def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_x(): + """ + Feature: ALL TO ALL + Description: test cases for partial_key_ward_arg_and_pos_arg_const_multi_assign_x + Expectation: the result match given one + """ + class Net(nn.Cell): def __init__(self): super(Net, self).__init__() @@ -179,13 +322,29 @@ def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_x(): ret = f(1, 2, 3) return ret - net = Net() - with pytest.raises(TypeError) as ex: - net() - assert "Multiply values for specific argument: x" in str(ex.value) + class Net2(nn.Cell): + def __init__(self): + super(Net2, self).__init__() + self.show = lambda x, y, z: (x, y, z) + + def construct(self): + f = partial(self.show, x=1) + ret = f(1, 2, 3) + return ret + + for net in [Net(), Net2()]: + with pytest.raises(TypeError) as ex: + net() + assert "Multiply values for specific argument: x" in str(ex.value) def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_y(): + """ + Feature: ALL TO ALL + Description: test cases for partial_key_ward_arg_and_pos_arg_const_multi_assign_y + Expectation: the result match given one + """ + class Net(nn.Cell): def __init__(self): super(Net, self).__init__() @@ -198,13 +357,29 @@ def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_y(): ret = f(1, 2, z=3) return ret - net = Net() - with pytest.raises(TypeError) as ex: - net() - assert "Multiply values for specific argument: y" in str(ex.value) + class Net2(nn.Cell): + def __init__(self): + super(Net2, self).__init__() + self.show = lambda x, y, z: (x, y, z) + + def construct(self): + f = partial(self.show, y=2) + ret = f(1, 2, z=3) + return ret + + for net in [Net(), Net2()]: + with pytest.raises(TypeError) as ex: + net() + assert "Multiply values for specific argument: y" in str(ex.value) def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_z(): + """ + Feature: ALL TO ALL + Description: test cases for partial_key_ward_arg_and_pos_arg_const_multi_assign_z + Expectation: the result match given one + """ + class Net(nn.Cell): def __init__(self): super(Net, self).__init__() @@ -217,7 +392,17 @@ def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_z(): ret = f(1, 2, 3) return ret - net = Net() - with pytest.raises(TypeError) as ex: - net() - assert "Multiply values for specific argument: z" in str(ex.value) + class Net2(nn.Cell): + def __init__(self): + super(Net2, self).__init__() + self.show = lambda x, y, z: (x, y, z) + + def construct(self): + f = partial(self.show, z=1) + ret = f(1, 2, 3) + return ret + + for net in [Net(), Net2()]: + with pytest.raises(TypeError) as ex: + net() + assert "Multiply values for specific argument: z" in str(ex.value)