forked from mindspore-Ecosystem/mindspore
improve the way passing ags of partial
This commit is contained in:
parent
5864cbc29d
commit
559c741cce
|
@ -915,7 +915,7 @@ AnfNodePtr Parser::ParseTuple(const FunctionBlockPtr &block, const py::object &n
|
|||
AnfNodePtr Parser::ParseList(const FunctionBlockPtr &block, const py::object &node) {
|
||||
MS_LOG(DEBUG) << "Process ast List";
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
py::tuple elts = python_adapter::GetPyObjAttr(node, "elts");
|
||||
py::list elts = python_adapter::GetPyObjAttr(node, "elts");
|
||||
if (elts.size() == 0) {
|
||||
auto empty_list = std::vector<ValuePtr>();
|
||||
return NewValueNode(std::make_shared<ValueList>(empty_list));
|
||||
|
|
|
@ -126,6 +126,7 @@ void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph,
|
|||
std::vector<AnfNodePtr> kwarg_keys_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)};
|
||||
std::vector<AnfNodePtr> kwarg_values_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)};
|
||||
|
||||
std::set<AnfNodePtr> key_ward_para_nodes;
|
||||
for (const auto &kwarg : kwarg_list) {
|
||||
MS_EXCEPTION_IF_NULL(kwarg);
|
||||
std::string kw_param_name = kwarg->get_key();
|
||||
|
@ -146,7 +147,7 @@ void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph,
|
|||
return param != nullptr && param->name() == param_name;
|
||||
});
|
||||
if (find_kw_arg_in_list) {
|
||||
MS_LOG(EXCEPTION) << "Multiply values for keyword argument:" << kw_param_name;
|
||||
MS_EXCEPTION(TypeError) << "Multiply values for keyword argument: " << kw_param_name;
|
||||
}
|
||||
p->set_name(param_name);
|
||||
p->debug_info()->set_name(param_name);
|
||||
|
@ -159,12 +160,14 @@ void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph,
|
|||
} else {
|
||||
auto node_itr = std::find(specialized_parameter_list->begin(), specialized_parameter_list->end(), param_node);
|
||||
// multiply values found given for parameter
|
||||
if (node_itr != specialized_parameter_list->end()) {
|
||||
MS_LOG(EXCEPTION) << "Multiply values for specific argument:" << kw_param_name;
|
||||
if (node_itr != specialized_parameter_list->end() &&
|
||||
key_ward_para_nodes.find(param_node) == key_ward_para_nodes.end()) {
|
||||
MS_EXCEPTION(TypeError) << "Multiply values for specific argument: " << kw_param_name;
|
||||
} else {
|
||||
specialized_parameter_list->push_back(param_node);
|
||||
auto extract_node = specialized_graph->NewCNode(
|
||||
{NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kw_param_name), param_node});
|
||||
key_ward_para_nodes.insert(param_node);
|
||||
(void)repl_nodes->emplace(param_node, extract_node);
|
||||
}
|
||||
}
|
||||
|
@ -199,10 +202,7 @@ bool FuncGraph::NeedGenerate(const std::vector<abstract::AbstractKeywordArgPtr>
|
|||
}
|
||||
|
||||
// if the graph is generated for specific input, do not need to generate again
|
||||
if (is_generated()) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
return !is_generated();
|
||||
}
|
||||
|
||||
void FuncGraph::GenerateDefaultValue(const FuncGraphPtr &specialized_graph,
|
||||
|
@ -232,20 +232,23 @@ void FuncGraph::GenerateDefaultValue(const FuncGraphPtr &specialized_graph,
|
|||
|
||||
FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list) {
|
||||
std::vector<abstract::AbstractKeywordArgPtr> kwarg_list;
|
||||
std::vector<size_t> pos_arg_indexes;
|
||||
size_t arguments_count = args_spec_list.size();
|
||||
for (const auto &arg : args_spec_list) {
|
||||
// if it is a keyword argument
|
||||
MS_EXCEPTION_IF_NULL(arg);
|
||||
if (arg->isa<abstract::AbstractKeywordArg>()) {
|
||||
kwarg_list.push_back(dyn_cast<abstract::AbstractKeywordArg>(arg));
|
||||
for (size_t i = 0; i < arguments_count - hyper_param_count_; i++) {
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[i]);
|
||||
if (args_spec_list[i]->isa<abstract::AbstractKeywordArg>()) {
|
||||
kwarg_list.push_back(args_spec_list[i]->cast<abstract::AbstractKeywordArgPtr>());
|
||||
} else {
|
||||
pos_arg_indexes.push_back(i);
|
||||
}
|
||||
}
|
||||
|
||||
if (!NeedGenerate(kwarg_list)) {
|
||||
return shared_from_base<FuncGraph>();
|
||||
}
|
||||
FuncGraphPtr specialized_graph = BasicClone(shared_from_base<FuncGraph>());
|
||||
size_t kwarg_count = kwarg_list.size();
|
||||
int pos_args_input_count = SizeToInt(arguments_count - kwarg_count - hyper_param_count());
|
||||
int pos_args_input_count = SizeToInt(arguments_count - kwarg_count - hyper_param_count_);
|
||||
int pos_args_count = std::min(pos_args_input_count, this->GetPositionalArgsCount());
|
||||
int variable_args_count = pos_args_input_count - pos_args_count;
|
||||
std::vector<AnfNodePtr> specialized_parameter_list;
|
||||
|
@ -265,8 +268,14 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list)
|
|||
// append hyper parameter to specialized_parameter_list
|
||||
MS_EXCEPTION_IF_NULL(specialized_graph);
|
||||
auto params = specialized_graph->parameters();
|
||||
(void)std::transform(params.end() - SizeToInt(hyper_param_count()), params.end(),
|
||||
std::back_inserter(specialized_parameter_list), [](const AnfNodePtr &node) { return node; });
|
||||
specialized_parameter_list.insert(specialized_parameter_list.end(), params.end() - SizeToInt(hyper_param_count_),
|
||||
params.end());
|
||||
std::vector<AnfNodePtr> specialized_parameter_list_update(specialized_parameter_list.begin() + pos_arg_indexes.size(),
|
||||
specialized_parameter_list.end());
|
||||
for (size_t i = 0; i < pos_arg_indexes.size(); i++) {
|
||||
specialized_parameter_list_update.insert(specialized_parameter_list_update.begin() + pos_arg_indexes[i],
|
||||
specialized_parameter_list[i]);
|
||||
}
|
||||
|
||||
std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(specialized_graph, false);
|
||||
auto tr = manager->Transact();
|
||||
|
@ -275,7 +284,7 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list)
|
|||
<< node_pair.second->DebugString();
|
||||
(void)tr.Replace(node_pair.first, node_pair.second);
|
||||
}
|
||||
tr.SetParameters(specialized_graph, specialized_parameter_list);
|
||||
tr.SetParameters(specialized_graph, specialized_parameter_list_update);
|
||||
tr.Commit();
|
||||
specialized_graph->set_has_kwarg(false);
|
||||
specialized_graph->set_has_vararg(false);
|
||||
|
|
|
@ -0,0 +1,223 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test partial"""
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mindspore import nn, Tensor, context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
def test_partial_pos_arg():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
||||
def show(self, x, y, z):
|
||||
return x, y, z
|
||||
|
||||
def construct(self, x, y, z):
|
||||
f = partial(self.show, x)
|
||||
ret = f(y, z)
|
||||
return ret
|
||||
|
||||
x = Tensor(np.arange(3).reshape((3,)).astype(np.float32))
|
||||
y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32))
|
||||
z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32))
|
||||
net = Net()
|
||||
net(x, y, z)
|
||||
|
||||
def test_partial_key_ward_arg():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
||||
def show(self, x, y, z):
|
||||
return x, y, z
|
||||
|
||||
def construct(self, x, y, z):
|
||||
f = partial(self.show, x=x)
|
||||
ret = f(y=y, z=z)
|
||||
return ret
|
||||
|
||||
x = Tensor(np.arange(3).reshape((3,)).astype(np.float32))
|
||||
y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32))
|
||||
z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32))
|
||||
net = Net()
|
||||
net(x, y, z)
|
||||
|
||||
def test_partial_key_ward_arg_update():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
||||
def show(self, x, y, z):
|
||||
return x, y, z
|
||||
|
||||
def construct(self, x, y, z):
|
||||
f = partial(self.show, x=x, y=y)
|
||||
ret = f(y=y, z=z)
|
||||
return ret
|
||||
|
||||
x = Tensor(np.arange(3).reshape((3,)).astype(np.float32))
|
||||
y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32))
|
||||
z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32))
|
||||
net = Net()
|
||||
net(x, y, z)
|
||||
|
||||
|
||||
def test_partial_key_ward_arg_and_pos_arg():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
||||
def show(self, x, y, z):
|
||||
return x, y, z
|
||||
|
||||
def construct(self, x, y, z):
|
||||
f = partial(self.show, y=y)
|
||||
ret = f(2, z=z)
|
||||
return ret
|
||||
|
||||
x = Tensor(np.arange(3).reshape((3,)).astype(np.float32))
|
||||
y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32))
|
||||
z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32))
|
||||
net = Net()
|
||||
net(x, y, z)
|
||||
|
||||
|
||||
def test_partial_pos_arg_const():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
||||
def show(self, x, y, z):
|
||||
return x, y, z
|
||||
|
||||
def construct(self):
|
||||
f = partial(self.show, 1)
|
||||
ret = f(2, 3)
|
||||
return ret
|
||||
|
||||
net = Net()
|
||||
assert net() == (1, 2, 3)
|
||||
|
||||
def test_partial_key_ward_arg_const():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
||||
def show(self, x, y, z):
|
||||
return x, y, z
|
||||
|
||||
def construct(self):
|
||||
f = partial(self.show, x=1)
|
||||
ret = f(y=2, z=3)
|
||||
return ret
|
||||
|
||||
net = Net()
|
||||
assert net() == (1, 2, 3)
|
||||
|
||||
def test_partial_key_ward_arg_update_const():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
||||
def show(self, x, y, z):
|
||||
return x, y, z
|
||||
|
||||
def construct(self):
|
||||
f = partial(self.show, x=1, y=2)
|
||||
ret = f(y=3, z=4)
|
||||
return ret
|
||||
|
||||
net = Net()
|
||||
assert net() == (1, 3, 4)
|
||||
|
||||
|
||||
def test_partial_key_ward_arg_and_pos_arg_const():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
||||
def show(self, x, y, z):
|
||||
return x, y, z
|
||||
|
||||
def construct(self):
|
||||
f = partial(self.show, y=2)
|
||||
ret = f(1, z=3)
|
||||
return ret
|
||||
|
||||
net = Net()
|
||||
assert net() == (1, 2, 3)
|
||||
|
||||
|
||||
def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_x():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
||||
def show(self, x, y, z):
|
||||
return x, y, z
|
||||
|
||||
def construct(self):
|
||||
f = partial(self.show, x=1)
|
||||
ret = f(1, 2, 3)
|
||||
return ret
|
||||
|
||||
net = Net()
|
||||
with pytest.raises(TypeError) as ex:
|
||||
net()
|
||||
assert "Multiply values for specific argument: x" in str(ex.value)
|
||||
|
||||
|
||||
def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_y():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
||||
def show(self, x, y, z):
|
||||
return x, y, z
|
||||
|
||||
def construct(self):
|
||||
f = partial(self.show, y=2)
|
||||
ret = f(1, 2, z=3)
|
||||
return ret
|
||||
|
||||
net = Net()
|
||||
with pytest.raises(TypeError) as ex:
|
||||
net()
|
||||
assert "Multiply values for specific argument: y" in str(ex.value)
|
||||
|
||||
|
||||
def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_z():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
||||
def show(self, x, y, z):
|
||||
return x, y, z
|
||||
|
||||
def construct(self):
|
||||
f = partial(self.show, z=1)
|
||||
ret = f(1, 2, 3)
|
||||
return ret
|
||||
|
||||
net = Net()
|
||||
with pytest.raises(TypeError) as ex:
|
||||
net()
|
||||
assert "Multiply values for specific argument: z" in str(ex.value)
|
Loading…
Reference in New Issue