!6505 Set top graph parameters' names the same as the original graph parameters.

Merge pull request !6505 from 张清华/master2
mindspore-ci-bot 2020-09-21 20:06:59 +08:00 committed by Gitee
commit dfe77372f5
4 changed files with 78 additions and 25 deletions
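Before this change, MakeTopGraph wrapped every cell in a generic top(*args, **kwargs) graph, so name-keyed results such as Cell.parameter_layout_dict surfaced positional inputs under synthesized keys like 'args0' (visible in the test diffs below). After it, the wrapper copies the original construct parameters, names and defaults included. A plain-Python sketch of the visible effect; the construct signature here is hypothetical:

    import inspect

    def construct(self, x, w1=None):  # hypothetical construct signature
        return x

    # Before: the wrapper declared its own (*args, **kwargs) parameters, so
    # positional inputs surfaced under synthesized names.
    old_keys = ['args0', 'args1']
    # After: the wrapper copies the original parameters, names included.
    new_keys = list(inspect.signature(construct).parameters)[1:]  # skip `self`
    assert new_keys == ['x', 'w1']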


@@ -1790,38 +1790,92 @@ bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph) {
   return true;
 }
 
+// Generate and copy a ValueNode, or a CNode with its child nodes
+static AnfNodePtr CopyNodesFromParamDefaultValue(const FuncGraphPtr func_graph, const AnfNodePtr &param_node) {
+  MS_EXCEPTION_IF_NULL(param_node);
+  if (param_node->isa<ValueNode>()) {
+    return std::make_shared<ValueNode>(param_node->cast<ValueNodePtr>()->value());
+  }
+
+  // Parameter default value is CNode.
+  std::size_t index = 0;
+  std::vector<AnfNodePtr> old_cnodes;
+  old_cnodes.emplace_back(param_node);
+  auto res = func_graph->NewCNode({});
+  std::vector<CNodePtr> new_cnodes;
+  new_cnodes.emplace_back(res);
+  while (index < old_cnodes.size()) {
+    auto current = old_cnodes[index];
+    auto current_new_cnode = new_cnodes[index];
+    index++;
+    MS_EXCEPTION_IF_NULL(current);
+    if (current->isa<CNode>()) {
+      auto &inputs = current->cast<CNodePtr>()->inputs();
+      for (auto it = inputs.begin(); it != inputs.end(); it++) {
+        AnfNodePtr input = *it;
+        if (input != nullptr && input->isa<CNode>()) {
+          old_cnodes.emplace_back(input);
+          auto new_cnode = func_graph->NewCNode({});
+          new_cnodes.emplace_back(new_cnode);
+          current_new_cnode->add_input(new_cnode);
+        } else if (input->isa<ValueNode>()) {
+          current_new_cnode->add_input(std::make_shared<ValueNode>(input->cast<ValueNodePtr>()->value()));
+        } else {
+          MS_LOG(EXCEPTION) << "Wrong type item in default parameters: " << input->ToString();
+        }
+      }
+    }
+  }
+  return res;
+}
+
 FuncGraphPtr MakeTopGraph(const py::object &cell, const ValuePtr &cell_ptr) {
+  auto current_graph = dyn_cast<FuncGraph>(cell_ptr);
+  if (current_graph == nullptr) {
+    MS_LOG(EXCEPTION) << "Current graph cast failed from " << cell_ptr->ToString();
+  }
   auto func_graph = std::make_shared<FuncGraph>();
-  func_graph->debug_info()->set_name("top");
+  func_graph->debug_info()->set_name(current_graph->debug_info()->name() + "_wrapper");
 
-  // def top(*arg, *kwargs):
-  auto param_vargs = func_graph->add_parameter();
-  auto args_name = "args";
-  param_vargs->set_name(args_name);
-  param_vargs->debug_info()->set_name(args_name);
-  auto param_vkwargs = func_graph->add_parameter();
-  args_name = "kwargs";
-  param_vkwargs->set_name(args_name);
-  param_vkwargs->debug_info()->set_name(args_name);
-  func_graph->set_has_vararg(true);
-  func_graph->set_has_kwarg(true);
-  func_graph->set_kwonlyargs_count(0);
+  // Copy all parameters information
+  for (auto &para : current_graph->parameters()) {
+    auto param = func_graph->add_parameter();
+    auto orig_param = para->cast<ParameterPtr>();
+    auto name = orig_param->name();
+    param->set_name(name);
+    param->debug_info()->set_name(name);
+  }
+  func_graph->set_has_vararg(current_graph->has_vararg());
+  func_graph->set_has_kwarg(current_graph->has_kwarg());
+  func_graph->set_kwonlyargs_count(current_graph->kwonlyargs_count());
+  // Copy all default values
+  for (auto &d : current_graph->parameter_default_value()) {
+    func_graph->set_param_default_value(d.first, CopyNodesFromParamDefaultValue(func_graph, d.second));
+  }
   // cell_obj
   MS_LOG(DEBUG) << "add Flag for " << std::string(py::str(cell));
   parse::UpdateFuncGraphFlags(cell, func_graph);
   // top graph's construct flag
   if (py::hasattr(cell, "construct")) {
     parse::UpdateFuncGraphFlags(cell.attr("construct"), func_graph);
   }
-  // ret = cell_obj(*arg, *kwargs)
-  auto call_fn = MakeUnpackCall(func_graph, NewValueNode(cell_ptr), {param_vargs, param_vkwargs});
-  // return ret
-  func_graph->set_output(call_fn);
+  auto unpacking = func_graph->has_vararg() || func_graph->has_kwarg();
+  if (!unpacking) {
+    std::vector<AnfNodePtr> inputs;
+    inputs.emplace_back(NewValueNode(cell_ptr));
+    auto &params = func_graph->parameters();
+    (void)std::transform(params.begin(), params.end(), std::back_inserter(inputs),
+                         [](AnfNodePtr node) -> AnfNodePtr { return node; });
+    func_graph->set_output(func_graph->NewCNode(inputs));
+  } else {
+    // ret = cell_obj(*arg, *kwargs)
+    auto call_fn = MakeUnpackCall(func_graph, NewValueNode(cell_ptr), func_graph->parameters());
+    // return ret
+    func_graph->set_output(call_fn);
+  }
   return func_graph;
 }
 }  // namespace parse
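For readers who do not want to trace the C++: CopyNodesFromParamDefaultValue performs a breadth-first copy of a parameter's default-value expression so the wrapper graph owns its own nodes instead of sharing the original graph's. A minimal Python model of that traversal, using stand-in ValueNode/CNode classes rather than the real ANF types:

    from dataclasses import dataclass, field

    @dataclass
    class ValueNode:  # stand-in for MindSpore's ValueNode (a leaf constant)
        value: object

    @dataclass
    class CNode:      # stand-in for MindSpore's CNode (a composite node)
        inputs: list = field(default_factory=list)

    def copy_default_value(node):
        """Breadth-first copy of a default-value node, mirroring the C++ above."""
        if isinstance(node, ValueNode):
            return ValueNode(node.value)          # leaf: re-wrap the value
        old_nodes, new_nodes = [node], [CNode()]  # parallel worklists, old vs. new
        index = 0
        while index < len(old_nodes):
            current, current_new = old_nodes[index], new_nodes[index]
            index += 1
            for inp in current.inputs:
                if isinstance(inp, CNode):        # composite input: enqueue a copy
                    old_nodes.append(inp)
                    new_nodes.append(CNode())
                    current_new.inputs.append(new_nodes[-1])
                elif isinstance(inp, ValueNode):  # leaf input: copy in place
                    current_new.inputs.append(ValueNode(inp.value))
                else:
                    raise TypeError(f"Wrong type item in default parameters: {inp!r}")
        return new_nodes[0]

    # A default written as 1+2+3+4 parses into nested "add" CNodes:
    inner = CNode([ValueNode("add"), ValueNode(1), ValueNode(2)])
    expr = CNode([ValueNode("add"), inner, ValueNode(3)])
    copied = copy_default_value(expr)
    assert copied is not expr and copied.inputs[1] is not inner

The parallel-worklist design means old and new nodes stay index-aligned, so each copied CNode can be wired to its children as they are discovered, without recursion.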


@@ -51,7 +51,7 @@ def test_get_parameter_layout():
     exe.compile(net, x, phase='train', auto_parallel_mode=True)
     x_layout = [[2, 4], [1, -1], [16, 32], [0], [1]]  # device_arrangement = [2, 4], tensor_map = [1, -1]
     weight_layout = [[2, 4], [0, -1], [16, 32], [0], [1]]  # device_arrangement = [2, 4], tensor_map = [0, -1]
-    expect_dict = {'args0': x_layout, 'w1': weight_layout}
+    expect_dict = {'x': x_layout, 'w1': weight_layout}
     # to be resolved: static local variable count_p is used in step_parallel.cc, it needs to be reset between each ut
     assert net.parameter_layout_dict == expect_dict
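The renamed key is the user-visible effect: a layout is now looked up by the construct argument's real name. A plain-Python illustration using the values from this test (only the first two fields are labeled by the test's comments, so the rest are left uninterpreted):

    x_layout = [[2, 4], [1, -1], [16, 32], [0], [1]]
    layouts = {'x': x_layout}  # previously keyed as 'args0'
    device_arrangement, tensor_map = layouts['x'][0], layouts['x'][1]
    assert device_arrangement == [2, 4] and tensor_map == [1, -1]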


@@ -125,7 +125,7 @@ def test_grad_sens_parameter_type():
     y_layout = [[8, 8], [-1, 0], [32, 8], [0], [1]]
     b_layout = [[8, 8], [0, -1], [8, 64], [0], [1]]
     sens_layout = [[8, 8], [1, -1], [16, 64], [0], [1]]
-    expect_dict = {'args0': x_layout, 'args1': y_layout, 'args2': b_layout, 'args3': sens_layout}
+    expect_dict = {'x': x_layout, 'y': y_layout, 'b': b_layout, 'sens': sens_layout}
     assert net.parameter_layout_dict == expect_dict


@@ -45,7 +45,6 @@ log = logging.getLogger("test")
 log.setLevel(level=logging.ERROR)
 context.set_context(mode=context.GRAPH_MODE)
-
 # Test case: use the parse obj interface use default parameter
 class Net(nn.Cell):
     """ Net definition """
@@ -55,7 +54,7 @@ class Net(nn.Cell):
         self.softmax1 = nn.Softmax(dim)
         self.softmax2 = nn.Softmax(dim + 1)
 
-    def construct(self, input_data, input1=ms.Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))):
+    def construct(self, input_data, input1=1+2+3+4):
        return self.softmax1(input_data)
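The default value here changes from a Tensor call to an arithmetic expression. MindSpore parses construct from its source AST, so a default like 1+2+3+4 reaches the parser as an expression tree rather than a pre-evaluated constant, which is exactly the CNode shape the new CopyNodesFromParamDefaultValue has to duplicate. A sketch with the standard ast module, as an illustration only and not MindSpore's actual parser:

    import ast

    tree = ast.parse("def construct(self, input_data, input1=1+2+3+4): pass")
    default = tree.body[0].args.defaults[0]
    assert isinstance(default, ast.BinOp)  # ((1 + 2) + 3) + 4, a nested expression tree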