forked from mindspore-Ecosystem/mindspore
!26690 Add supoort resolving outer lambda function for ops.Partial.
Merge pull request !26690 from hezhenhao1/add_lambda
This commit is contained in:
commit
d1e4e674ab
|
@ -175,7 +175,19 @@ void CheckFuncReturn(const FuncGraphPtr &fn, const std::shared_ptr<ParseFunction
|
||||||
FuncGraphPtr Parser::ParseFuncGraph() {
|
FuncGraphPtr Parser::ParseFuncGraph() {
|
||||||
// Get ast FunctionDef node
|
// Get ast FunctionDef node
|
||||||
py::object node = ast_->GetAstNode();
|
py::object node = ast_->GetAstNode();
|
||||||
FunctionBlockPtr fn_block = ParseFunction(node);
|
constexpr char function_def_name[] = "FunctionDef";
|
||||||
|
constexpr char lambda_name[] = "Lambda";
|
||||||
|
FunctionBlockPtr fn_block = nullptr;
|
||||||
|
if (ast_->GetNodeType(node)->node_name() == function_def_name) {
|
||||||
|
fn_block = ParseDefFunction(node);
|
||||||
|
} else {
|
||||||
|
auto lambda_node = python_adapter::GetPyObjAttr(node, "value");
|
||||||
|
if (py::isinstance<py::none>(lambda_node) || ast_->GetNodeType(lambda_node)->node_name() != lambda_name) {
|
||||||
|
MS_EXCEPTION(TypeError) << "Parse Lambda Function Fail. Node type must be Lambda, but got "
|
||||||
|
<< ast_->GetNodeType(lambda_node)->node_name() << ".";
|
||||||
|
}
|
||||||
|
fn_block = ParseLambdaFunction(lambda_node);
|
||||||
|
}
|
||||||
if (errcode() != PARSE_SUCCESS) {
|
if (errcode() != PARSE_SUCCESS) {
|
||||||
MS_LOG(ERROR) << "Parse function error, code is " << errcode();
|
MS_LOG(ERROR) << "Parse function error, code is " << errcode();
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -259,7 +271,7 @@ ScopePtr Parser::GetScopeForParseFunction() {
|
||||||
return scope;
|
return scope;
|
||||||
}
|
}
|
||||||
|
|
||||||
FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlockPtr &block) {
|
FunctionBlockPtr Parser::ParseDefFunction(const py::object &node, const FunctionBlockPtr &block) {
|
||||||
ScopePtr scope = GetScopeForParseFunction();
|
ScopePtr scope = GetScopeForParseFunction();
|
||||||
// The node created in the parsefunction context, will inherit the scope created using scope_guard
|
// The node created in the parsefunction context, will inherit the scope created using scope_guard
|
||||||
ScopeGuard scope_guard(scope);
|
ScopeGuard scope_guard(scope);
|
||||||
|
@ -323,6 +335,33 @@ FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlo
|
||||||
return func_block;
|
return func_block;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
FunctionBlockPtr Parser::ParseLambdaFunction(const py::object &node, const FunctionBlockPtr &block) {
|
||||||
|
MS_EXCEPTION_IF_NULL(ast_);
|
||||||
|
ScopePtr scope = GetScopeForParseFunction();
|
||||||
|
ScopeGuard scope_guard(scope);
|
||||||
|
TraceGuard trace_guard(data_converter::GetObjKey(ast_->obj())[0], GetLocation(node));
|
||||||
|
|
||||||
|
FunctionBlockPtr func_block = MakeFunctionBlock(*this);
|
||||||
|
if (block != nullptr) {
|
||||||
|
func_block->AddPrevBlock(block);
|
||||||
|
} else {
|
||||||
|
func_graph_ = func_block->func_graph();
|
||||||
|
}
|
||||||
|
func_block->Mature();
|
||||||
|
auto current_fg = func_block->func_graph();
|
||||||
|
|
||||||
|
auto function_name = ast_->function_name();
|
||||||
|
MS_LOG(DEBUG) << "The function name is " << function_name;
|
||||||
|
current_fg->debug_info()->set_name(function_name);
|
||||||
|
GenerateArgsNodeForFunction(func_block, node);
|
||||||
|
|
||||||
|
py::object body_node = python_adapter::GetPyObjAttr(node, "body");
|
||||||
|
AnfNodePtr lambda_body_node = ParseExprNode(func_block, body_node);
|
||||||
|
current_fg->set_output(lambda_body_node);
|
||||||
|
GenerateArgsDefaultValueForFunction(func_block, node);
|
||||||
|
return func_block;
|
||||||
|
}
|
||||||
|
|
||||||
FunctionBlockPtr Parser::ParseStatements(FunctionBlockPtr block, const py::object &nodes) {
|
FunctionBlockPtr Parser::ParseStatements(FunctionBlockPtr block, const py::object &nodes) {
|
||||||
auto node_list = py::cast<py::list>(nodes);
|
auto node_list = py::cast<py::list>(nodes);
|
||||||
size_t count = py::len(node_list);
|
size_t count = py::len(node_list);
|
||||||
|
@ -909,7 +948,7 @@ AnfNodePtr Parser::ParseBoolOp(const FunctionBlockPtr &block, const py::object &
|
||||||
// Process a function def
|
// Process a function def
|
||||||
FunctionBlockPtr Parser::ParseFunctionDef(const FunctionBlockPtr &block, const py::object &node) {
|
FunctionBlockPtr Parser::ParseFunctionDef(const FunctionBlockPtr &block, const py::object &node) {
|
||||||
MS_LOG(DEBUG) << "Process ast FunctionDef";
|
MS_LOG(DEBUG) << "Process ast FunctionDef";
|
||||||
FunctionBlockPtr function_block = ParseFunction(node, block);
|
FunctionBlockPtr function_block = ParseDefFunction(node, block);
|
||||||
MS_EXCEPTION_IF_NULL(function_block);
|
MS_EXCEPTION_IF_NULL(function_block);
|
||||||
|
|
||||||
// Get function name
|
// Get function name
|
||||||
|
@ -923,26 +962,10 @@ FunctionBlockPtr Parser::ParseFunctionDef(const FunctionBlockPtr &block, const p
|
||||||
// Process a lambda expression . like lambda x,y: x + y
|
// Process a lambda expression . like lambda x,y: x + y
|
||||||
AnfNodePtr Parser::ParseLambda(const FunctionBlockPtr &block, const py::object &node) {
|
AnfNodePtr Parser::ParseLambda(const FunctionBlockPtr &block, const py::object &node) {
|
||||||
MS_LOG(DEBUG) << "Process ast Lambda";
|
MS_LOG(DEBUG) << "Process ast Lambda";
|
||||||
FunctionBlockPtr func_block = MakeFunctionBlock(*this);
|
FunctionBlockPtr function_block = ParseLambdaFunction(node, block);
|
||||||
func_block->AddPrevBlock(block);
|
MS_EXCEPTION_IF_NULL(function_block);
|
||||||
func_block->Mature();
|
|
||||||
|
|
||||||
// Get lambda args
|
auto block_fg = function_block->func_graph();
|
||||||
py::list args = ast_->GetArgs(node);
|
|
||||||
auto block_fg = func_block->func_graph();
|
|
||||||
for (std::size_t i = 0; i < args.size(); i++) {
|
|
||||||
std::string arg_name = py::cast<std::string>(args[i].attr("arg"));
|
|
||||||
TraceGuard guard(GetLocation(args[i]));
|
|
||||||
auto para_node = std::make_shared<Parameter>(block_fg);
|
|
||||||
para_node->debug_info()->set_name(arg_name);
|
|
||||||
block_fg->add_parameter(para_node);
|
|
||||||
func_block->WriteVariable(arg_name, para_node);
|
|
||||||
MS_LOG(DEBUG) << "The arg[" << i << "] is " << arg_name;
|
|
||||||
}
|
|
||||||
|
|
||||||
py::object body_node = python_adapter::GetPyObjAttr(node, "body");
|
|
||||||
AnfNodePtr lambda_body_node = ParseExprNode(func_block, body_node);
|
|
||||||
block_fg->set_output(lambda_body_node);
|
|
||||||
ValueNodePtr const_graph = NewValueNode(block_fg);
|
ValueNodePtr const_graph = NewValueNode(block_fg);
|
||||||
return const_graph;
|
return const_graph;
|
||||||
}
|
}
|
||||||
|
|
|
@ -196,7 +196,9 @@ class Parser {
|
||||||
// Generate argument default value for ast function node
|
// Generate argument default value for ast function node
|
||||||
void GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block, const py::object &function_node);
|
void GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block, const py::object &function_node);
|
||||||
// Parse ast function node
|
// Parse ast function node
|
||||||
FunctionBlockPtr ParseFunction(const py::object &function_node, const FunctionBlockPtr &block = nullptr);
|
FunctionBlockPtr ParseDefFunction(const py::object &function_node, const FunctionBlockPtr &block = nullptr);
|
||||||
|
// Parse lambda function node
|
||||||
|
FunctionBlockPtr ParseLambdaFunction(const py::object &function_node, const FunctionBlockPtr &block = nullptr);
|
||||||
// Parse ast statements
|
// Parse ast statements
|
||||||
FunctionBlockPtr ParseStatements(FunctionBlockPtr block, const py::object &stmt_node);
|
FunctionBlockPtr ParseStatements(FunctionBlockPtr block, const py::object &stmt_node);
|
||||||
// Parse one ast statement node
|
// Parse one ast statement node
|
||||||
|
|
|
@ -22,7 +22,14 @@ from mindspore import nn, Tensor, context
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE)
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
|
||||||
|
|
||||||
def test_partial_pos_arg():
|
def test_partial_pos_arg():
|
||||||
|
"""
|
||||||
|
Feature: ALL TO ALL
|
||||||
|
Description: test cases for partial_pos_arg
|
||||||
|
Expectation: the result match given one
|
||||||
|
"""
|
||||||
|
|
||||||
class Net(nn.Cell):
|
class Net(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Net, self).__init__()
|
super(Net, self).__init__()
|
||||||
|
@ -35,13 +42,31 @@ def test_partial_pos_arg():
|
||||||
ret = f(y, z)
|
ret = f(y, z)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
class Net2(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net2, self).__init__()
|
||||||
|
self.show = lambda x, y, z: (x, y, z)
|
||||||
|
|
||||||
|
def construct(self, x, y, z):
|
||||||
|
f = partial(self.show, x)
|
||||||
|
ret = f(y, z)
|
||||||
|
return ret
|
||||||
|
|
||||||
x = Tensor(np.arange(3).reshape((3,)).astype(np.float32))
|
x = Tensor(np.arange(3).reshape((3,)).astype(np.float32))
|
||||||
y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32))
|
y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32))
|
||||||
z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32))
|
z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32))
|
||||||
net = Net()
|
|
||||||
net(x, y, z)
|
for net in [Net(), Net2()]:
|
||||||
|
net(x, y, z)
|
||||||
|
|
||||||
|
|
||||||
def test_partial_key_ward_arg():
|
def test_partial_key_ward_arg():
|
||||||
|
"""
|
||||||
|
Feature: ALL TO ALL
|
||||||
|
Description: test cases for partial_key_ward_arg
|
||||||
|
Expectation: the result match given one
|
||||||
|
"""
|
||||||
|
|
||||||
class Net(nn.Cell):
|
class Net(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Net, self).__init__()
|
super(Net, self).__init__()
|
||||||
|
@ -54,13 +79,31 @@ def test_partial_key_ward_arg():
|
||||||
ret = f(y=y, z=z)
|
ret = f(y=y, z=z)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
class Net2(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net2, self).__init__()
|
||||||
|
self.show = lambda x, y, z: (x, y, z)
|
||||||
|
|
||||||
|
def construct(self, x, y, z):
|
||||||
|
f = partial(self.show, x=x)
|
||||||
|
ret = f(y=y, z=z)
|
||||||
|
return ret
|
||||||
|
|
||||||
x = Tensor(np.arange(3).reshape((3,)).astype(np.float32))
|
x = Tensor(np.arange(3).reshape((3,)).astype(np.float32))
|
||||||
y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32))
|
y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32))
|
||||||
z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32))
|
z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32))
|
||||||
net = Net()
|
|
||||||
net(x, y, z)
|
for net in [Net(), Net2()]:
|
||||||
|
net(x, y, z)
|
||||||
|
|
||||||
|
|
||||||
def test_partial_key_ward_arg_update():
|
def test_partial_key_ward_arg_update():
|
||||||
|
"""
|
||||||
|
Feature: ALL TO ALL
|
||||||
|
Description: test cases for partial_key_ward_arg_update
|
||||||
|
Expectation: the result match given one
|
||||||
|
"""
|
||||||
|
|
||||||
class Net(nn.Cell):
|
class Net(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Net, self).__init__()
|
super(Net, self).__init__()
|
||||||
|
@ -73,14 +116,31 @@ def test_partial_key_ward_arg_update():
|
||||||
ret = f(y=y, z=z)
|
ret = f(y=y, z=z)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
class Net2(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net2, self).__init__()
|
||||||
|
self.show = lambda x, y, z: (x, y, z)
|
||||||
|
|
||||||
|
def construct(self, x, y, z):
|
||||||
|
f = partial(self.show, x=x, y=y)
|
||||||
|
ret = f(y=y, z=z)
|
||||||
|
return ret
|
||||||
|
|
||||||
x = Tensor(np.arange(3).reshape((3,)).astype(np.float32))
|
x = Tensor(np.arange(3).reshape((3,)).astype(np.float32))
|
||||||
y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32))
|
y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32))
|
||||||
z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32))
|
z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32))
|
||||||
net = Net()
|
|
||||||
net(x, y, z)
|
for net in [Net(), Net2()]:
|
||||||
|
net(x, y, z)
|
||||||
|
|
||||||
|
|
||||||
def test_partial_key_ward_arg_and_pos_arg():
|
def test_partial_key_ward_arg_and_pos_arg():
|
||||||
|
"""
|
||||||
|
Feature: ALL TO ALL
|
||||||
|
Description: test cases for partial_key_ward_arg_and_pos_arg
|
||||||
|
Expectation: the result match given one
|
||||||
|
"""
|
||||||
|
|
||||||
class Net(nn.Cell):
|
class Net(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Net, self).__init__()
|
super(Net, self).__init__()
|
||||||
|
@ -93,14 +153,31 @@ def test_partial_key_ward_arg_and_pos_arg():
|
||||||
ret = f(2, z=z)
|
ret = f(2, z=z)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
class Net2(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net2, self).__init__()
|
||||||
|
self.show = lambda x, y, z: (x, y, z)
|
||||||
|
|
||||||
|
def construct(self, x, y, z):
|
||||||
|
f = partial(self.show, y=y)
|
||||||
|
ret = f(2, z=z)
|
||||||
|
return ret
|
||||||
|
|
||||||
x = Tensor(np.arange(3).reshape((3,)).astype(np.float32))
|
x = Tensor(np.arange(3).reshape((3,)).astype(np.float32))
|
||||||
y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32))
|
y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32))
|
||||||
z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32))
|
z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32))
|
||||||
net = Net()
|
|
||||||
net(x, y, z)
|
for net in [Net(), Net2()]:
|
||||||
|
net(x, y, z)
|
||||||
|
|
||||||
|
|
||||||
def test_partial_pos_arg_const():
|
def test_partial_pos_arg_const():
|
||||||
|
"""
|
||||||
|
Feature: ALL TO ALL
|
||||||
|
Description: test cases for partial_pos_arg_const
|
||||||
|
Expectation: the result match given one
|
||||||
|
"""
|
||||||
|
|
||||||
class Net(nn.Cell):
|
class Net(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Net, self).__init__()
|
super(Net, self).__init__()
|
||||||
|
@ -113,10 +190,27 @@ def test_partial_pos_arg_const():
|
||||||
ret = f(2, 3)
|
ret = f(2, 3)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
net = Net()
|
class Net2(nn.Cell):
|
||||||
assert net() == (1, 2, 3)
|
def __init__(self):
|
||||||
|
super(Net2, self).__init__()
|
||||||
|
self.show = lambda x, y, z: (x, y, z)
|
||||||
|
|
||||||
|
def construct(self):
|
||||||
|
f = partial(self.show, 1)
|
||||||
|
ret = f(2, 3)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
for net in [Net(), Net2()]:
|
||||||
|
assert net() == (1, 2, 3)
|
||||||
|
|
||||||
|
|
||||||
def test_partial_key_ward_arg_const():
|
def test_partial_key_ward_arg_const():
|
||||||
|
"""
|
||||||
|
Feature: ALL TO ALL
|
||||||
|
Description: test cases for partial_key_ward_arg_const
|
||||||
|
Expectation: the result match given one
|
||||||
|
"""
|
||||||
|
|
||||||
class Net(nn.Cell):
|
class Net(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Net, self).__init__()
|
super(Net, self).__init__()
|
||||||
|
@ -129,10 +223,27 @@ def test_partial_key_ward_arg_const():
|
||||||
ret = f(y=2, z=3)
|
ret = f(y=2, z=3)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
net = Net()
|
class Net2(nn.Cell):
|
||||||
assert net() == (1, 2, 3)
|
def __init__(self):
|
||||||
|
super(Net2, self).__init__()
|
||||||
|
self.show = lambda x, y, z: (x, y, z)
|
||||||
|
|
||||||
|
def construct(self):
|
||||||
|
f = partial(self.show, x=1)
|
||||||
|
ret = f(y=2, z=3)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
for net in [Net(), Net2()]:
|
||||||
|
assert net() == (1, 2, 3)
|
||||||
|
|
||||||
|
|
||||||
def test_partial_key_ward_arg_update_const():
|
def test_partial_key_ward_arg_update_const():
|
||||||
|
"""
|
||||||
|
Feature: ALL TO ALL
|
||||||
|
Description: test cases for partial_key_ward_arg_update_const
|
||||||
|
Expectation: the result match given one
|
||||||
|
"""
|
||||||
|
|
||||||
class Net(nn.Cell):
|
class Net(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Net, self).__init__()
|
super(Net, self).__init__()
|
||||||
|
@ -145,11 +256,27 @@ def test_partial_key_ward_arg_update_const():
|
||||||
ret = f(y=3, z=4)
|
ret = f(y=3, z=4)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
net = Net()
|
class Net2(nn.Cell):
|
||||||
assert net() == (1, 3, 4)
|
def __init__(self):
|
||||||
|
super(Net2, self).__init__()
|
||||||
|
self.show = lambda x, y, z: (x, y, z)
|
||||||
|
|
||||||
|
def construct(self):
|
||||||
|
f = partial(self.show, x=1, y=2)
|
||||||
|
ret = f(y=3, z=4)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
for net in [Net(), Net2()]:
|
||||||
|
assert net() == (1, 3, 4)
|
||||||
|
|
||||||
|
|
||||||
def test_partial_key_ward_arg_and_pos_arg_const():
|
def test_partial_key_ward_arg_and_pos_arg_const():
|
||||||
|
"""
|
||||||
|
Feature: ALL TO ALL
|
||||||
|
Description: test cases for partial_key_ward_arg_and_pos_arg_const
|
||||||
|
Expectation: the result match given one
|
||||||
|
"""
|
||||||
|
|
||||||
class Net(nn.Cell):
|
class Net(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Net, self).__init__()
|
super(Net, self).__init__()
|
||||||
|
@ -162,11 +289,27 @@ def test_partial_key_ward_arg_and_pos_arg_const():
|
||||||
ret = f(1, z=3)
|
ret = f(1, z=3)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
net = Net()
|
class Net2(nn.Cell):
|
||||||
assert net() == (1, 2, 3)
|
def __init__(self):
|
||||||
|
super(Net2, self).__init__()
|
||||||
|
self.show = lambda x, y, z: (x, y, z)
|
||||||
|
|
||||||
|
def construct(self):
|
||||||
|
f = partial(self.show, y=2)
|
||||||
|
ret = f(1, z=3)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
for net in [Net(), Net2()]:
|
||||||
|
assert net() == (1, 2, 3)
|
||||||
|
|
||||||
|
|
||||||
def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_x():
|
def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_x():
|
||||||
|
"""
|
||||||
|
Feature: ALL TO ALL
|
||||||
|
Description: test cases for partial_key_ward_arg_and_pos_arg_const_multi_assign_x
|
||||||
|
Expectation: the result match given one
|
||||||
|
"""
|
||||||
|
|
||||||
class Net(nn.Cell):
|
class Net(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Net, self).__init__()
|
super(Net, self).__init__()
|
||||||
|
@ -179,13 +322,29 @@ def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_x():
|
||||||
ret = f(1, 2, 3)
|
ret = f(1, 2, 3)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
net = Net()
|
class Net2(nn.Cell):
|
||||||
with pytest.raises(TypeError) as ex:
|
def __init__(self):
|
||||||
net()
|
super(Net2, self).__init__()
|
||||||
assert "Multiply values for specific argument: x" in str(ex.value)
|
self.show = lambda x, y, z: (x, y, z)
|
||||||
|
|
||||||
|
def construct(self):
|
||||||
|
f = partial(self.show, x=1)
|
||||||
|
ret = f(1, 2, 3)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
for net in [Net(), Net2()]:
|
||||||
|
with pytest.raises(TypeError) as ex:
|
||||||
|
net()
|
||||||
|
assert "Multiply values for specific argument: x" in str(ex.value)
|
||||||
|
|
||||||
|
|
||||||
def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_y():
|
def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_y():
|
||||||
|
"""
|
||||||
|
Feature: ALL TO ALL
|
||||||
|
Description: test cases for partial_key_ward_arg_and_pos_arg_const_multi_assign_y
|
||||||
|
Expectation: the result match given one
|
||||||
|
"""
|
||||||
|
|
||||||
class Net(nn.Cell):
|
class Net(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Net, self).__init__()
|
super(Net, self).__init__()
|
||||||
|
@ -198,13 +357,29 @@ def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_y():
|
||||||
ret = f(1, 2, z=3)
|
ret = f(1, 2, z=3)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
net = Net()
|
class Net2(nn.Cell):
|
||||||
with pytest.raises(TypeError) as ex:
|
def __init__(self):
|
||||||
net()
|
super(Net2, self).__init__()
|
||||||
assert "Multiply values for specific argument: y" in str(ex.value)
|
self.show = lambda x, y, z: (x, y, z)
|
||||||
|
|
||||||
|
def construct(self):
|
||||||
|
f = partial(self.show, y=2)
|
||||||
|
ret = f(1, 2, z=3)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
for net in [Net(), Net2()]:
|
||||||
|
with pytest.raises(TypeError) as ex:
|
||||||
|
net()
|
||||||
|
assert "Multiply values for specific argument: y" in str(ex.value)
|
||||||
|
|
||||||
|
|
||||||
def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_z():
|
def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_z():
|
||||||
|
"""
|
||||||
|
Feature: ALL TO ALL
|
||||||
|
Description: test cases for partial_key_ward_arg_and_pos_arg_const_multi_assign_z
|
||||||
|
Expectation: the result match given one
|
||||||
|
"""
|
||||||
|
|
||||||
class Net(nn.Cell):
|
class Net(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Net, self).__init__()
|
super(Net, self).__init__()
|
||||||
|
@ -217,7 +392,17 @@ def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_z():
|
||||||
ret = f(1, 2, 3)
|
ret = f(1, 2, 3)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
net = Net()
|
class Net2(nn.Cell):
|
||||||
with pytest.raises(TypeError) as ex:
|
def __init__(self):
|
||||||
net()
|
super(Net2, self).__init__()
|
||||||
assert "Multiply values for specific argument: z" in str(ex.value)
|
self.show = lambda x, y, z: (x, y, z)
|
||||||
|
|
||||||
|
def construct(self):
|
||||||
|
f = partial(self.show, z=1)
|
||||||
|
ret = f(1, 2, 3)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
for net in [Net(), Net2()]:
|
||||||
|
with pytest.raises(TypeError) as ex:
|
||||||
|
net()
|
||||||
|
assert "Multiply values for specific argument: z" in str(ex.value)
|
||||||
|
|
Loading…
Reference in New Issue