diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse.cc b/mindspore/ccsrc/pipeline/jit/parse/parse.cc
index e44c6a5b3b..1dcbc0814b 100644
--- a/mindspore/ccsrc/pipeline/jit/parse/parse.cc
+++ b/mindspore/ccsrc/pipeline/jit/parse/parse.cc
@@ -1790,38 +1790,92 @@ bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph) {
   return true;
 }
 
+// Generate and copy a ValueNode, or a CNode with its child nodes
+static AnfNodePtr CopyNodesFromParamDefaultValue(const FuncGraphPtr func_graph, const AnfNodePtr &param_node) {
+  MS_EXCEPTION_IF_NULL(param_node);
+  if (param_node->isa<ValueNode>()) {
+    return std::make_shared<ValueNode>(param_node->cast<ValueNodePtr>()->value());
+  }
+
+  // Parameter default value is CNode.
+  std::size_t index = 0;
+  std::vector<AnfNodePtr> old_cnodes;
+  old_cnodes.emplace_back(param_node);
+  auto res = func_graph->NewCNode({});
+  std::vector<CNodePtr> new_cnodes;
+  new_cnodes.emplace_back(res);
+  while (index < old_cnodes.size()) {
+    auto current = old_cnodes[index];
+    auto current_new_cnode = new_cnodes[index];
+    index++;
+    MS_EXCEPTION_IF_NULL(current);
+    if (current->isa<CNode>()) {
+      auto &inputs = current->cast<CNodePtr>()->inputs();
+      for (auto it = inputs.begin(); it != inputs.end(); it++) {
+        AnfNodePtr input = *it;
+        if (input != nullptr && input->isa<CNode>()) {
+          old_cnodes.emplace_back(input);
+          auto new_cnode = func_graph->NewCNode({});
+          new_cnodes.emplace_back(new_cnode);
+          current_new_cnode->add_input(new_cnode);
+        } else if (input->isa<ValueNode>()) {
+          current_new_cnode->add_input(std::make_shared<ValueNode>(input->cast<ValueNodePtr>()->value()));
+        } else {
+          MS_LOG(EXCEPTION) << "Wrong type item in default parameters: " << input->ToString();
+        }
+      }
+    }
+  }
+  return res;
+}
+
 FuncGraphPtr MakeTopGraph(const py::object &cell, const ValuePtr &cell_ptr) {
+  auto current_graph = dyn_cast<FuncGraph>(cell_ptr);
+  if (current_graph == nullptr) {
+    MS_LOG(EXCEPTION) << "Current graph cast failed from " << cell_ptr->ToString();
+  }
+
   auto func_graph = std::make_shared<FuncGraph>();
-  func_graph->debug_info()->set_name("top");
+  func_graph->debug_info()->set_name(current_graph->debug_info()->name() + "_wrapper");
 
-  // def top(*arg, *kwargs):
-  auto param_vargs = func_graph->add_parameter();
-  auto args_name = "args";
-  param_vargs->set_name(args_name);
-  param_vargs->debug_info()->set_name(args_name);
-
-  auto param_vkwargs = func_graph->add_parameter();
-  args_name = "kwargs";
-  param_vkwargs->set_name(args_name);
-  param_vkwargs->debug_info()->set_name(args_name);
-
-  func_graph->set_has_vararg(true);
-  func_graph->set_has_kwarg(true);
-  func_graph->set_kwonlyargs_count(0);
+  // Copy all parameters information
+  for (auto &para : current_graph->parameters()) {
+    auto param = func_graph->add_parameter();
+    auto orig_param = para->cast<ParameterPtr>();
+    auto name = orig_param->name();
+    param->set_name(name);
+    param->debug_info()->set_name(name);
+  }
+  func_graph->set_has_vararg(current_graph->has_vararg());
+  func_graph->set_has_kwarg(current_graph->has_kwarg());
+  func_graph->set_kwonlyargs_count(current_graph->kwonlyargs_count());
+  // Copy all default values
+  for (auto &d : current_graph->parameter_default_value()) {
+    func_graph->set_param_default_value(d.first, CopyNodesFromParamDefaultValue(func_graph, d.second));
+  }
 
   // cell_obj
+  MS_LOG(DEBUG) << "add Flag for " << std::string(py::str(cell));
   parse::UpdateFuncGraphFlags(cell, func_graph);
   // top graph's construct flag
   if (py::hasattr(cell, "construct")) {
     parse::UpdateFuncGraphFlags(cell.attr("construct"), func_graph);
   }
 
-  // ret = cell_obj(*arg, *kwargs)
-  auto call_fn = MakeUnpackCall(func_graph, NewValueNode(cell_ptr), {param_vargs, param_vkwargs});
-
-  // return ret
-  func_graph->set_output(call_fn);
-  MS_LOG(DEBUG) << "add Flag for " << std::string(py::str(cell));
+  auto unpacking = func_graph->has_vararg() || func_graph->has_kwarg();
+  if (!unpacking) {
+    std::vector<AnfNodePtr> inputs;
+    inputs.emplace_back(NewValueNode(cell_ptr));
+    auto &params = func_graph->parameters();
+    (void)std::transform(params.begin(), params.end(), std::back_inserter(inputs),
+                         [](AnfNodePtr node) -> AnfNodePtr { return node; });
+    func_graph->set_output(func_graph->NewCNode(inputs));
+  } else {
+    // ret = cell_obj(*arg, *kwargs)
+    auto call_fn = MakeUnpackCall(func_graph, NewValueNode(cell_ptr), func_graph->parameters());
+    // return ret
+    func_graph->set_output(call_fn);
+  }
   return func_graph;
 }
 }  // namespace parse
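The net effect of the MakeTopGraph change is easiest to see from the Python
side. The following sketch is illustrative only; it is not generated code, and
`cell` merely stands for the parsed construct graph:

    # Before: every cell was wrapped in one synthetic catch-all signature, so
    # downstream passes only ever saw parameter names 'args0', 'args1', ...
    def top(*args, **kwargs):
        return cell(*args, **kwargs)

    # After: the wrapper copies construct's own parameter names, defaults and
    # vararg/kwarg flags. For `def construct(self, x, sens)` it acts like:
    def construct_wrapper(x, sens):
        return cell(x, sens)

    # MakeUnpackCall is still emitted, but only when construct itself
    # declares *args or **kwargs.

This is why the parallel layout tests below now key the expected layouts by
the real parameter names instead of 'args0'..'argsN'.
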
diff --git a/tests/ut/python/parallel/test_get_parameter_layout.py b/tests/ut/python/parallel/test_get_parameter_layout.py
index 0d4aa1722d..c494161b90 100644
--- a/tests/ut/python/parallel/test_get_parameter_layout.py
+++ b/tests/ut/python/parallel/test_get_parameter_layout.py
@@ -51,7 +51,7 @@ def test_get_parameter_layout():
     exe.compile(net, x, phase='train', auto_parallel_mode=True)
     x_layout = [[2, 4], [1, -1], [16, 32], [0], [1]]  # device_arrangement = [2, 4], tensor_map = [1, -1]
     weight_layout = [[2, 4], [0, -1], [16, 32], [0], [1]]  # device_arrangement = [2, 4], tensor_map = [0, -1]
-    expect_dict = {'args0': x_layout, 'w1': weight_layout}
+    expect_dict = {'x': x_layout, 'w1': weight_layout}
     # to be resovled: static local variable count_p is used in step_parallel.cc, it needs to be reset between each ut
     assert net.parameter_layout_dict == expect_dict
 
diff --git a/tests/ut/python/parallel/test_split_grad_sens.py b/tests/ut/python/parallel/test_split_grad_sens.py
index e0e01adcb7..ee5d1c48d0 100644
--- a/tests/ut/python/parallel/test_split_grad_sens.py
+++ b/tests/ut/python/parallel/test_split_grad_sens.py
@@ -125,7 +125,7 @@ def test_grad_sens_parameter_type():
     y_layout = [[8, 8], [-1, 0], [32, 8], [0], [1]]
     b_layout = [[8, 8], [0, -1], [8, 64], [0], [1]]
     sens_layout = [[8, 8], [1, -1], [16, 64], [0], [1]]
-    expect_dict = {'args0': x_layout, 'args1': y_layout, 'args2': b_layout, 'args3': sens_layout}
+    expect_dict = {'x': x_layout, 'y': y_layout, 'b': b_layout, 'sens': sens_layout}
 
     assert net.parameter_layout_dict == expect_dict
 
diff --git a/tests/ut/python/pipeline/parse/test_parse.py b/tests/ut/python/pipeline/parse/test_parse.py
index 5183d1d55f..fa265d2d24 100644
--- a/tests/ut/python/pipeline/parse/test_parse.py
+++ b/tests/ut/python/pipeline/parse/test_parse.py
@@ -45,7 +45,6 @@ log = logging.getLogger("test")
 log.setLevel(level=logging.ERROR)
 context.set_context(mode=context.GRAPH_MODE)
 
-
 # Test case: use the parse obj interface use default parameter
 class Net(nn.Cell):
     """ Net definition """
@@ -55,7 +54,7 @@ class Net(nn.Cell):
         self.softmax1 = nn.Softmax(dim)
         self.softmax2 = nn.Softmax(dim + 1)
 
-    def construct(self, input_data, input1=ms.Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))):
+    def construct(self, input_data, input1=1+2+3+4):
         return self.softmax1(input_data)
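
The test_parse.py change reflects the constraint that the new default-value
copying introduces: CopyNodesFromParamDefaultValue only clones ValueNode/CNode
trees, so a construct default should be an expression the parser can turn into
such a tree, which is presumably why the Tensor default built from numpy is
replaced with a constant expression. A hedged sketch of the supported pattern
(class name and dim value are illustrative, not taken from the patch):

    import mindspore.nn as nn

    class SoftmaxNet(nn.Cell):  # hypothetical stand-in for the test's Net
        def __init__(self, dim=0):
            super().__init__()
            self.softmax = nn.Softmax(dim)

        # The default parses into a small CNode tree of scalar additions,
        # which the copy helper can clone into the wrapper graph; it is
        # expected to fold to the constant 10 during compilation.
        def construct(self, input_data, input1=1 + 2 + 3 + 4):
            return self.softmax(input_data)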