forked from OSSInnovation/mindspore
!6505 Set top graph parameters' name the same as original graph parameters.
Merge pull request !6505 from 张清华/master2
This commit is contained in:
commit
dfe77372f5
|
@ -1790,38 +1790,92 @@ bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph) {
|
|||
return true;
|
||||
}
|
||||
|
||||
// Generate and copy a ValueNode, or a CNode with its child nodes
|
||||
static AnfNodePtr CopyNodesFromParamDefaultValue(const FuncGraphPtr func_graph, const AnfNodePtr ¶m_node) {
|
||||
MS_EXCEPTION_IF_NULL(param_node);
|
||||
if (param_node->isa<ValueNode>()) {
|
||||
return std::make_shared<ValueNode>(param_node->cast<ValueNodePtr>()->value());
|
||||
}
|
||||
|
||||
// Parameter default value is CNode.
|
||||
std::size_t index = 0;
|
||||
std::vector<AnfNodePtr> old_cnodes;
|
||||
old_cnodes.emplace_back(param_node);
|
||||
auto res = func_graph->NewCNode({});
|
||||
std::vector<CNodePtr> new_cnodes;
|
||||
new_cnodes.emplace_back(res);
|
||||
while (index < old_cnodes.size()) {
|
||||
auto current = old_cnodes[index];
|
||||
auto current_new_cnode = new_cnodes[index];
|
||||
index++;
|
||||
MS_EXCEPTION_IF_NULL(current);
|
||||
if (current->isa<CNode>()) {
|
||||
auto &inputs = current->cast<CNodePtr>()->inputs();
|
||||
for (auto it = inputs.begin(); it != inputs.end(); it++) {
|
||||
AnfNodePtr input = *it;
|
||||
if (input != nullptr && input->isa<CNode>()) {
|
||||
old_cnodes.emplace_back(input);
|
||||
auto new_cnode = func_graph->NewCNode({});
|
||||
new_cnodes.emplace_back(new_cnode);
|
||||
current_new_cnode->add_input(new_cnode);
|
||||
} else if (input->isa<ValueNode>()) {
|
||||
current_new_cnode->add_input(std::make_shared<ValueNode>(input->cast<ValueNodePtr>()->value()));
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Wrong type item in default parameters: " << input->ToString();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
FuncGraphPtr MakeTopGraph(const py::object &cell, const ValuePtr &cell_ptr) {
|
||||
auto current_graph = dyn_cast<FuncGraph>(cell_ptr);
|
||||
if (current_graph == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Current graph cast failed from " << cell_ptr->ToString();
|
||||
}
|
||||
|
||||
auto func_graph = std::make_shared<FuncGraph>();
|
||||
func_graph->debug_info()->set_name("top");
|
||||
func_graph->debug_info()->set_name(current_graph->debug_info()->name() + "_wrapper");
|
||||
|
||||
// def top(*arg, *kwargs):
|
||||
auto param_vargs = func_graph->add_parameter();
|
||||
auto args_name = "args";
|
||||
param_vargs->set_name(args_name);
|
||||
param_vargs->debug_info()->set_name(args_name);
|
||||
|
||||
auto param_vkwargs = func_graph->add_parameter();
|
||||
args_name = "kwargs";
|
||||
param_vkwargs->set_name(args_name);
|
||||
param_vkwargs->debug_info()->set_name(args_name);
|
||||
|
||||
func_graph->set_has_vararg(true);
|
||||
func_graph->set_has_kwarg(true);
|
||||
func_graph->set_kwonlyargs_count(0);
|
||||
// Copy all parameters information
|
||||
for (auto ¶ : current_graph->parameters()) {
|
||||
auto param = func_graph->add_parameter();
|
||||
auto orig_param = para->cast<ParameterPtr>();
|
||||
auto name = orig_param->name();
|
||||
param->set_name(name);
|
||||
param->debug_info()->set_name(name);
|
||||
}
|
||||
func_graph->set_has_vararg(current_graph->has_vararg());
|
||||
func_graph->set_has_kwarg(current_graph->has_kwarg());
|
||||
func_graph->set_kwonlyargs_count(current_graph->kwonlyargs_count());
|
||||
// Copy all default values
|
||||
for (auto &d : current_graph->parameter_default_value()) {
|
||||
func_graph->set_param_default_value(d.first, CopyNodesFromParamDefaultValue(func_graph, d.second));
|
||||
}
|
||||
|
||||
// cell_obj
|
||||
MS_LOG(DEBUG) << "add Flag for " << std::string(py::str(cell));
|
||||
parse::UpdateFuncGraphFlags(cell, func_graph);
|
||||
// top graph's construct flag
|
||||
if (py::hasattr(cell, "construct")) {
|
||||
parse::UpdateFuncGraphFlags(cell.attr("construct"), func_graph);
|
||||
}
|
||||
|
||||
// ret = cell_obj(*arg, *kwargs)
|
||||
auto call_fn = MakeUnpackCall(func_graph, NewValueNode(cell_ptr), {param_vargs, param_vkwargs});
|
||||
|
||||
// return ret
|
||||
func_graph->set_output(call_fn);
|
||||
MS_LOG(DEBUG) << "add Flag for " << std::string(py::str(cell));
|
||||
auto unpacking = func_graph->has_vararg() || func_graph->has_kwarg();
|
||||
if (!unpacking) {
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
inputs.emplace_back(NewValueNode(cell_ptr));
|
||||
auto ¶ms = func_graph->parameters();
|
||||
(void)std::transform(params.begin(), params.end(), std::back_inserter(inputs),
|
||||
[](AnfNodePtr node) -> AnfNodePtr { return node; });
|
||||
func_graph->set_output(func_graph->NewCNode(inputs));
|
||||
} else {
|
||||
// ret = cell_obj(*arg, *kwargs)
|
||||
auto call_fn = MakeUnpackCall(func_graph, NewValueNode(cell_ptr), func_graph->parameters());
|
||||
// return ret
|
||||
func_graph->set_output(call_fn);
|
||||
}
|
||||
return func_graph;
|
||||
}
|
||||
} // namespace parse
|
||||
|
|
|
@ -51,7 +51,7 @@ def test_get_parameter_layout():
|
|||
exe.compile(net, x, phase='train', auto_parallel_mode=True)
|
||||
x_layout = [[2, 4], [1, -1], [16, 32], [0], [1]] # device_arrangement = [2, 4], tensor_map = [1, -1]
|
||||
weight_layout = [[2, 4], [0, -1], [16, 32], [0], [1]] # device_arrangement = [2, 4], tensor_map = [0, -1]
|
||||
expect_dict = {'args0': x_layout, 'w1': weight_layout}
|
||||
expect_dict = {'x': x_layout, 'w1': weight_layout}
|
||||
# to be resovled: static local variable count_p is used in step_parallel.cc, it needs to be reset between each ut
|
||||
assert net.parameter_layout_dict == expect_dict
|
||||
|
||||
|
|
|
@ -125,7 +125,7 @@ def test_grad_sens_parameter_type():
|
|||
y_layout = [[8, 8], [-1, 0], [32, 8], [0], [1]]
|
||||
b_layout = [[8, 8], [0, -1], [8, 64], [0], [1]]
|
||||
sens_layout = [[8, 8], [1, -1], [16, 64], [0], [1]]
|
||||
expect_dict = {'args0': x_layout, 'args1': y_layout, 'args2': b_layout, 'args3': sens_layout}
|
||||
expect_dict = {'x': x_layout, 'y': y_layout, 'b': b_layout, 'sens': sens_layout}
|
||||
assert net.parameter_layout_dict == expect_dict
|
||||
|
||||
|
||||
|
|
|
@ -45,7 +45,6 @@ log = logging.getLogger("test")
|
|||
log.setLevel(level=logging.ERROR)
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
# Test case: use the parse obj interface use default parameter
|
||||
class Net(nn.Cell):
|
||||
""" Net definition """
|
||||
|
@ -55,7 +54,7 @@ class Net(nn.Cell):
|
|||
self.softmax1 = nn.Softmax(dim)
|
||||
self.softmax2 = nn.Softmax(dim + 1)
|
||||
|
||||
def construct(self, input_data, input1=ms.Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))):
|
||||
def construct(self, input_data, input1=1+2+3+4):
|
||||
return self.softmax1(input_data)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue