improve the way passing ags of partial

This commit is contained in:
buxue 2020-09-25 17:02:11 +08:00
parent 5864cbc29d
commit 559c741cce
3 changed files with 249 additions and 17 deletions

View File

@ -915,7 +915,7 @@ AnfNodePtr Parser::ParseTuple(const FunctionBlockPtr &block, const py::object &n
AnfNodePtr Parser::ParseList(const FunctionBlockPtr &block, const py::object &node) {
MS_LOG(DEBUG) << "Process ast List";
MS_EXCEPTION_IF_NULL(block);
py::tuple elts = python_adapter::GetPyObjAttr(node, "elts");
py::list elts = python_adapter::GetPyObjAttr(node, "elts");
if (elts.size() == 0) {
auto empty_list = std::vector<ValuePtr>();
return NewValueNode(std::make_shared<ValueList>(empty_list));

View File

@ -126,6 +126,7 @@ void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph,
std::vector<AnfNodePtr> kwarg_keys_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)};
std::vector<AnfNodePtr> kwarg_values_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)};
std::set<AnfNodePtr> key_ward_para_nodes;
for (const auto &kwarg : kwarg_list) {
MS_EXCEPTION_IF_NULL(kwarg);
std::string kw_param_name = kwarg->get_key();
@ -146,7 +147,7 @@ void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph,
return param != nullptr && param->name() == param_name;
});
if (find_kw_arg_in_list) {
MS_LOG(EXCEPTION) << "Multiply values for keyword argument:" << kw_param_name;
MS_EXCEPTION(TypeError) << "Multiply values for keyword argument: " << kw_param_name;
}
p->set_name(param_name);
p->debug_info()->set_name(param_name);
@ -159,12 +160,14 @@ void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph,
} else {
auto node_itr = std::find(specialized_parameter_list->begin(), specialized_parameter_list->end(), param_node);
// multiply values found given for parameter
if (node_itr != specialized_parameter_list->end()) {
MS_LOG(EXCEPTION) << "Multiply values for specific argument:" << kw_param_name;
if (node_itr != specialized_parameter_list->end() &&
key_ward_para_nodes.find(param_node) == key_ward_para_nodes.end()) {
MS_EXCEPTION(TypeError) << "Multiply values for specific argument: " << kw_param_name;
} else {
specialized_parameter_list->push_back(param_node);
auto extract_node = specialized_graph->NewCNode(
{NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kw_param_name), param_node});
key_ward_para_nodes.insert(param_node);
(void)repl_nodes->emplace(param_node, extract_node);
}
}
@ -199,10 +202,7 @@ bool FuncGraph::NeedGenerate(const std::vector<abstract::AbstractKeywordArgPtr>
}
// if the graph is generated for specific input, do not need to generate again
if (is_generated()) {
return false;
}
return true;
return !is_generated();
}
void FuncGraph::GenerateDefaultValue(const FuncGraphPtr &specialized_graph,
@ -232,20 +232,23 @@ void FuncGraph::GenerateDefaultValue(const FuncGraphPtr &specialized_graph,
FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list) {
std::vector<abstract::AbstractKeywordArgPtr> kwarg_list;
std::vector<size_t> pos_arg_indexes;
size_t arguments_count = args_spec_list.size();
for (const auto &arg : args_spec_list) {
// if it is a keyword argument
MS_EXCEPTION_IF_NULL(arg);
if (arg->isa<abstract::AbstractKeywordArg>()) {
kwarg_list.push_back(dyn_cast<abstract::AbstractKeywordArg>(arg));
for (size_t i = 0; i < arguments_count - hyper_param_count_; i++) {
MS_EXCEPTION_IF_NULL(args_spec_list[i]);
if (args_spec_list[i]->isa<abstract::AbstractKeywordArg>()) {
kwarg_list.push_back(args_spec_list[i]->cast<abstract::AbstractKeywordArgPtr>());
} else {
pos_arg_indexes.push_back(i);
}
}
if (!NeedGenerate(kwarg_list)) {
return shared_from_base<FuncGraph>();
}
FuncGraphPtr specialized_graph = BasicClone(shared_from_base<FuncGraph>());
size_t kwarg_count = kwarg_list.size();
int pos_args_input_count = SizeToInt(arguments_count - kwarg_count - hyper_param_count());
int pos_args_input_count = SizeToInt(arguments_count - kwarg_count - hyper_param_count_);
int pos_args_count = std::min(pos_args_input_count, this->GetPositionalArgsCount());
int variable_args_count = pos_args_input_count - pos_args_count;
std::vector<AnfNodePtr> specialized_parameter_list;
@ -265,8 +268,14 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list)
// append hyper parameter to specialized_parameter_list
MS_EXCEPTION_IF_NULL(specialized_graph);
auto params = specialized_graph->parameters();
(void)std::transform(params.end() - SizeToInt(hyper_param_count()), params.end(),
std::back_inserter(specialized_parameter_list), [](const AnfNodePtr &node) { return node; });
specialized_parameter_list.insert(specialized_parameter_list.end(), params.end() - SizeToInt(hyper_param_count_),
params.end());
std::vector<AnfNodePtr> specialized_parameter_list_update(specialized_parameter_list.begin() + pos_arg_indexes.size(),
specialized_parameter_list.end());
for (size_t i = 0; i < pos_arg_indexes.size(); i++) {
specialized_parameter_list_update.insert(specialized_parameter_list_update.begin() + pos_arg_indexes[i],
specialized_parameter_list[i]);
}
std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(specialized_graph, false);
auto tr = manager->Transact();
@ -275,7 +284,7 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list)
<< node_pair.second->DebugString();
(void)tr.Replace(node_pair.first, node_pair.second);
}
tr.SetParameters(specialized_graph, specialized_parameter_list);
tr.SetParameters(specialized_graph, specialized_parameter_list_update);
tr.Commit();
specialized_graph->set_has_kwarg(false);
specialized_graph->set_has_vararg(false);

View File

@ -0,0 +1,223 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test partial"""
from functools import partial
import numpy as np
import pytest
from mindspore import nn, Tensor, context
context.set_context(mode=context.GRAPH_MODE)
def test_partial_pos_arg():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def show(self, x, y, z):
return x, y, z
def construct(self, x, y, z):
f = partial(self.show, x)
ret = f(y, z)
return ret
x = Tensor(np.arange(3).reshape((3,)).astype(np.float32))
y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32))
z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32))
net = Net()
net(x, y, z)
def test_partial_key_ward_arg():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def show(self, x, y, z):
return x, y, z
def construct(self, x, y, z):
f = partial(self.show, x=x)
ret = f(y=y, z=z)
return ret
x = Tensor(np.arange(3).reshape((3,)).astype(np.float32))
y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32))
z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32))
net = Net()
net(x, y, z)
def test_partial_key_ward_arg_update():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def show(self, x, y, z):
return x, y, z
def construct(self, x, y, z):
f = partial(self.show, x=x, y=y)
ret = f(y=y, z=z)
return ret
x = Tensor(np.arange(3).reshape((3,)).astype(np.float32))
y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32))
z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32))
net = Net()
net(x, y, z)
def test_partial_key_ward_arg_and_pos_arg():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def show(self, x, y, z):
return x, y, z
def construct(self, x, y, z):
f = partial(self.show, y=y)
ret = f(2, z=z)
return ret
x = Tensor(np.arange(3).reshape((3,)).astype(np.float32))
y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32))
z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32))
net = Net()
net(x, y, z)
def test_partial_pos_arg_const():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def show(self, x, y, z):
return x, y, z
def construct(self):
f = partial(self.show, 1)
ret = f(2, 3)
return ret
net = Net()
assert net() == (1, 2, 3)
def test_partial_key_ward_arg_const():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def show(self, x, y, z):
return x, y, z
def construct(self):
f = partial(self.show, x=1)
ret = f(y=2, z=3)
return ret
net = Net()
assert net() == (1, 2, 3)
def test_partial_key_ward_arg_update_const():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def show(self, x, y, z):
return x, y, z
def construct(self):
f = partial(self.show, x=1, y=2)
ret = f(y=3, z=4)
return ret
net = Net()
assert net() == (1, 3, 4)
def test_partial_key_ward_arg_and_pos_arg_const():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def show(self, x, y, z):
return x, y, z
def construct(self):
f = partial(self.show, y=2)
ret = f(1, z=3)
return ret
net = Net()
assert net() == (1, 2, 3)
def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_x():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def show(self, x, y, z):
return x, y, z
def construct(self):
f = partial(self.show, x=1)
ret = f(1, 2, 3)
return ret
net = Net()
with pytest.raises(TypeError) as ex:
net()
assert "Multiply values for specific argument: x" in str(ex.value)
def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_y():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def show(self, x, y, z):
return x, y, z
def construct(self):
f = partial(self.show, y=2)
ret = f(1, 2, z=3)
return ret
net = Net()
with pytest.raises(TypeError) as ex:
net()
assert "Multiply values for specific argument: y" in str(ex.value)
def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_z():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def show(self, x, y, z):
return x, y, z
def construct(self):
f = partial(self.show, z=1)
ret = f(1, 2, 3)
return ret
net = Net()
with pytest.raises(TypeError) as ex:
net()
assert "Multiply values for specific argument: z" in str(ex.value)