diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
index 9d9b2094af1..c7cb2cb3127 100644
--- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
+++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
@@ -1031,7 +1031,8 @@ PynativeExecutor::PynativeExecutor() {
 
 void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) {
   auto cell_id = GetCellId(cell, args);
-  if (cell_graph_map_.count(cell_id) != 0) {
+  // check graph_context_.empty() so that a separate graph is created for every cell except the top one
+  if (cell_graph_map_.count(cell_id) != 0 && graph_context_.empty()) {
     if (cell_resource_map_.find(cell_id) != cell_resource_map_.end()) {
       resource_ = cell_resource_map_[cell_id];
     }
@@ -1040,21 +1041,24 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg
   }
 
   auto g = std::make_shared<FuncGraph>();
-
-  if (top_g_ == nullptr) {
+  if (graph_context_.empty()) {
+    // a df builder is built for every top function graph
+    df_builder_ = std::make_shared<FuncGraph>();
+    df_builder_map_[cell_id] = df_builder_;
     top_g_ = curr_g_ = g;
     resource_ = std::make_shared<pipeline::Resource>();
     resource_->results()[pipeline::kPynativeGraphId] = graph_id_++;
     cell_resource_map_[cell_id] = resource_;
-    df_builder_ = std::make_shared<FuncGraph>();
     MS_LOG(DEBUG) << "First new graph" << top_g_.get();
     first_grad_step_ = true;
     top_graph_cells_.insert(cell_id);
-    Pushp();
   } else {
-    Pushp();
+    if (df_builder_ == nullptr) {
+      MS_LOG(EXCEPTION) << "In NewGraphInner, the df builder is nullptr";
+    }
     curr_g_ = g;
   }
+  Pushp();
   if (graph_info_map_.count(g) == 0) {
     graph_info_map_[g] = GraphInfo();
   }
@@ -1171,22 +1175,25 @@ void PynativeExecutor::SetTupleParam(const py::object &obj, const AnfNodePtr &pa
   }
 }
 
-void PynativeExecutor::Pushp() { graph_p_.push(curr_g_); }
+void PynativeExecutor::Pushp() { graph_context_.push(curr_g_); }
 
 void PynativeExecutor::Popp() {
-  if (graph_p_.empty()) {
-    MS_LOG(EXCEPTION) << "Stack graph_p_ is empty";
+  if (graph_context_.empty()) {
+    MS_LOG(EXCEPTION) << "Stack graph_context_ is empty";
+  }
+  graph_context_.pop();
+  if (!graph_context_.empty()) {
+    curr_g_ = graph_context_.top();
   }
-  curr_g_ = graph_p_.top();
-  graph_p_.pop();
 }
 
 void PynativeExecutor::EndGraphInner(const py::object &cell, const py::object &out, const py::args &args) {
   auto cell_id = GetCellId(cell, args);
-  if (cell_graph_map_.count(cell_id) != 0) {
+  if (cell_graph_map_.count(cell_id) != 0 && graph_context_.empty()) {
     MS_LOG(DEBUG) << "Endgraph already compiled";
     return;
   }
+  cell_graph_map_[cell_id] = curr_g_;
   auto out_id = GetId(out);
   if (!graph_info_map_[curr_g_].obj_node_map.count(out_id) && !graph_info_map_[curr_g_].param_map.count(out_id)) {
@@ -1246,7 +1253,7 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje
       (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(curr_g_)));
     }
   }
-  auto newfg = ad::Grad(curr_g_, resource_, curr_g_ == top_g_);
+  auto newfg = ad::Grad(curr_g_, resource_, graph_context_.size() == 1);
   if (need_replace_param) {
     auto params = newfg->parameters();
     auto manager = Manage({newfg}, false);
@@ -1257,26 +1264,29 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje
     }
   }
   graph_info_map_.erase(curr_g_);
-  if (curr_g_ != top_g_) {
+  if (graph_context_.size() > 1) {
     Popp();
+    // connect the inner graph to the previous (outer) graph
+    auto graph_prev = graph_context_.top();
     std::vector<AnfNodePtr> inputs{NewValueNode(newfg)};
     for (size_t i = 0; i < args.size(); i++) {
       auto input = GetInput(args[i], false);
       inputs.push_back(input);
     }
-    auto out_cnode = curr_g_->NewCNode(inputs);
-    set_pyobj(curr_g_, GetCellId(cell, args));
+    auto out_cnode = graph_prev->NewCNode(inputs);
+    set_pyobj(graph_prev, GetCellId(cell, args));
     if (py::isinstance<py::tuple>(out)) {
       auto out_list = py::cast<py::tuple>(out);
       auto out_size = static_cast<int>(out_list.size());
       for (int i = 0; i < out_size; i++) {
-        set_obj_node_map(curr_g_, GetId(out_list[i]), out_cnode, i);
+        set_obj_node_map(graph_prev, GetId(out_list[i]), out_cnode, i);
         SetTupleOutput(out_list[i], out_cnode, std::vector<int>{i});
       }
     }
-    set_obj_node_map(curr_g_, GetId(out), out_cnode);
+    set_obj_node_map(graph_prev, GetId(out), out_cnode);
   } else {
     parse::ResolveFuncGraph(newfg, resource_);
     resource_->set_func_graph(newfg);
+    Popp();
   }
 }
@@ -1348,14 +1358,37 @@ abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args
 
 void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
                                     const py::args &args) {
   MS_LOG(INFO) << "GradNet start" << args.size();
-  std::size_t size = args.size();
   std::string cell_id = GetCellId(cell, args);
   if (graph_map_.count(cell_id) != 0) {
     MS_LOG(DEBUG) << "GradNet already compiled";
     return;
   }
+  size_t forward_args_count = args.size();
+  if (grad->sens_param()) {
+    forward_args_count = forward_args_count - 1;
+  }
+  py::tuple forward_args(forward_args_count);
+  for (size_t i = 0; i < forward_args_count; i++) {
+    forward_args[i] = args[i];
+  }
+  std::string forward_cell_id = GetCellId(cell, forward_args);
+  MS_LOG(DEBUG) << "Forward cell_id:" << forward_cell_id;
+  if (df_builder_map_.find(forward_cell_id) == df_builder_map_.end()) {
+    MS_LOG(EXCEPTION) << "Cannot find df builder for " << forward_cell_id;
+  }
+  df_builder_ = df_builder_map_[forward_cell_id];
+  if (df_builder_ == nullptr) {
+    MS_LOG(EXCEPTION) << "Got unexpected null df builder";
+  }
+
+  if (cell_resource_map_.find(forward_cell_id) == cell_resource_map_.end()) {
+    MS_LOG(EXCEPTION) << "Cannot find resource for " << forward_cell_id;
+  }
   MS_LOG(DEBUG) << "GradNet first compiled";
+  resource_ = cell_resource_map_[forward_cell_id];
+  std::size_t size = args.size();
+
   std::vector<AnfNodePtr> new_params;
   for (size_t i = 0; i < size; i++) {
     ParameterPtr p = std::make_shared<Parameter>(df_builder_);
@@ -1368,6 +1401,10 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje
 
   std::vector<AnfNodePtr> w_args = GetWeightsArgs(weights);
   MS_EXCEPTION_IF_NULL(resource_->func_graph());
+  if (cell_graph_map_.find(forward_cell_id) == cell_graph_map_.end()) {
+    MS_LOG(EXCEPTION) << "Could not find top graph by cell id: " << forward_cell_id;
+  }
+  top_g_ = cell_graph_map_[forward_cell_id];
   auto g = GradGraph(resource_->func_graph(), grad, w_args, size);
   resource_->set_func_graph(g);
   resource_->manager()->KeepRoots({g});
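
Note on the GradNetInner hunks above: when GradOperation is constructed with sens_param, the grad call receives one extra sensitivity input that was not present during the forward run, so the executor strips it before recomputing the cell id that keys df_builder_map_ and cell_resource_map_. A minimal Python sketch of that lookup, with illustrative helper names rather than the real MindSpore internals:

    def lookup_forward_cell_id(get_cell_id, cell, args, sens_param):
        # Drop the trailing sens input so the id matches the one recorded
        # by NewGraphInner during the forward run.
        forward_args = args[:-1] if sens_param else args
        return get_cell_id(cell, forward_args)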
@@ -1409,6 +1446,7 @@ void PynativeExecutor::Clear(const std::string &flag) {
     MapClear<std::unordered_map<std::string, FuncGraphPtr>>(&graph_map_, flag);
     MapClear<std::unordered_map<std::string, FuncGraphPtr>>(&cell_graph_map_, flag);
     MapClear<std::unordered_map<std::string, ResourcePtr>>(&cell_resource_map_, flag);
+    MapClear<std::unordered_map<std::string, FuncGraphPtr>>(&df_builder_map_, flag);
     Clean();
     // Maybe exit in the pynative runing op, so need reset pynative flag.
     auto ms_context = MsContext::GetInstance();
@@ -1431,7 +1469,7 @@ void PynativeExecutor::Clear(const std::string &flag) {
 
   graph_info_map_.clear();
   op_id_map_.clear();
   obj_to_forward_id_.clear();
-  std::stack<FuncGraphPtr>().swap(graph_p_);
+  std::stack<FuncGraphPtr>().swap(graph_context_);
   ConfigManager::GetInstance().ResetIterNum();
 }
@@ -1509,7 +1547,6 @@ py::object PynativeExecutor::Run(const py::tuple &args, const py::object &phase)
   }
 
   std::string backend = MsContext::GetInstance()->backend_policy();
-  MS_LOG(DEBUG) << "Eval run" << backend;
   BaseRef value = (*run)(arg_list);
   MS_LOG(DEBUG) << "Run end" << value.ToString();
diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h
index 726e60033cd..01af466a63b 100644
--- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h
+++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h
@@ -155,7 +155,9 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   std::unordered_map<std::string, int> op_id_map_;
   std::unordered_map<std::string, std::string> obj_to_forward_id_;
   std::unordered_map<std::string, abstract::AbstractBasePtr> node_abs_map_;
-  std::stack<FuncGraphPtr> graph_p_;
+  std::unordered_map<std::string, FuncGraphPtr> df_builder_map_;
+  // stack that records the graph-creation context; the bottom element is the top graph
+  std::stack<FuncGraphPtr> graph_context_;
   FuncGraphPtr top_g_;
   FuncGraphPtr df_builder_;
   FuncGraphPtr curr_g_;
diff --git a/mindspore/ops/composite/base.py b/mindspore/ops/composite/base.py
index 99c37c6988d..0ba8cff2687 100644
--- a/mindspore/ops/composite/base.py
+++ b/mindspore/ops/composite/base.py
@@ -21,7 +21,7 @@ from types import FunctionType
 from mindspore import context
 from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, \
-    TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_
+    TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_
 from ...common import dtype as mstype
 from ...common.api import ms_function, _pynative_exec, _wrap_func
 from .. import functional as F
@@ -475,6 +475,7 @@ class _ListAppend(ListAppend_):
     Args:
         name (str): The name of the metafuncgraph object.
     """
+
     def __init__(self, name):
         ListAppend_.__init__(self, name)
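
The new unit tests below exercise the per-top-graph df builder added above: grads of two independent cells, repeated grads of the same cell, and grads of nested cells. The user-visible pattern they protect looks roughly like this (NetA and NetB stand in for any two Cells; this sketch is an illustration, not part of the patch):

    # PyNative mode: each top cell gets its own df builder, so taking the
    # grad of a second net no longer reuses the first net's builder state.
    net_a, net_b = NetA(), NetB()
    net_a.set_grad()
    net_b.set_grad()
    out_a = net_a(x, y)    # forward runs record one df_builder per cell id
    out_b = net_b(x, y)
    grad_op = C.GradOperation(get_all=True, sens_param=True)
    grads_a = grad_op(net_a)(x, y, sens)    # looks up net_a's df_builder
    grads_b = grad_op(net_b)(x, y, sens)    # looks up net_b's df_builder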
diff --git a/tests/ut/python/pynative_mode/test_multi_grad.py b/tests/ut/python/pynative_mode/test_multi_grad.py
new file mode 100644
index 00000000000..a59bc9f3bed
--- /dev/null
+++ b/tests/ut/python/pynative_mode/test_multi_grad.py
@@ -0,0 +1,207 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+from mindspore import context, nn, Tensor, Parameter, ParameterTuple
+from mindspore.common import dtype as mstype
+from mindspore.ops import composite as C
+
+
+def setup_module():
+    context.set_context(mode=context.PYNATIVE_MODE, enable_sparse=False)
+
+
+class _Grad(nn.Cell):
+    def __init__(self, grad, network, wrt_params=False, real_inputs_count=None):
+        super().__init__()
+        self.network = network
+        self.grad = grad
+        self.sens_param = self.grad.sens_param
+        self.wrt_params = wrt_params
+        self.real_inputs_count = real_inputs_count
+        if self.wrt_params:
+            self.params = ParameterTuple(self.network.trainable_params())
+
+    def construct(self, *inputs):
+        if self.wrt_params:
+            if self.real_inputs_count is None or self.sens_param is False:
+                return self.grad(self.network, self.params)(*inputs)
+            real_inputs = inputs[:self.real_inputs_count]
+            sense_param_inputs = inputs[self.real_inputs_count:]
+            return self.grad(self.network, self.params)(*real_inputs, sense_param_inputs)
+
+        if self.real_inputs_count is None or self.sens_param is False:
+            return self.grad(self.network)(*inputs)
+        real_inputs = inputs[:self.real_inputs_count]
+        sense_param_inputs = inputs[self.real_inputs_count:]
+        return self.grad(self.network)(*real_inputs, sense_param_inputs)
+
+
+class GradOfFirstInput(_Grad):
+    """
+    get grad of the first input
+    """
+
+    def __init__(self, network, sens_param=True, real_inputs_count=None):
+        super().__init__(grad=C.GradOperation(sens_param=sens_param),
+                         network=network, real_inputs_count=real_inputs_count)
+
+
+class GradOfAllInputs(_Grad):
+    """
+    get grads of all inputs
+    """
+
+    def __init__(self, network, sens_param=True, real_inputs_count=None):
+        super().__init__(grad=C.GradOperation(get_all=True, sens_param=sens_param),
+                         network=network, real_inputs_count=real_inputs_count)
+
+
+def test_multi_grad():
+    class ForwardNetMul(nn.Cell):
+        def __init__(self):
+            super().__init__()
+
+        def construct(self, x, y):
+            a = x * x
+            b = y * y
+            return a * b
+
+    class ForwardNetAdd(nn.Cell):
+        def __init__(self):
+            super().__init__()
+
+        def construct(self, x, y):
+            a = x + x + x
+            b = y + y
+            return a * b
+    mulnet = ForwardNetMul()
+    addnet = ForwardNetAdd()
+    x = Tensor(np.ones([32]), dtype=mstype.float32)
+    y = Tensor(np.ones([32]) * 2, dtype=mstype.float32)
+    sens = Tensor(np.ones([32]), dtype=mstype.float32)
+    mulnet.set_grad()
+    addnet.set_grad()
+    out1 = mulnet(x, y)
+    out2 = addnet(x, y)
+    grad_mul = GradOfAllInputs(mulnet)
+    grad_add = GradOfAllInputs(addnet)
+    grad_mul(x, y, sens)
+    grad_add(x, y, sens)
+
+
+def test_multi_same_grad():
+    class ForwardNetMul(nn.Cell):
+        def __init__(self):
+            super().__init__()
+
+        def construct(self, x, y):
+            a = x * x
+            b = y * y
+            return a * b
+
+    class ForwardNetAdd(nn.Cell):
+        def __init__(self):
+            super().__init__()
+
+        def construct(self, x, y):
+            a = x * 3
+            b = y * 2
+            return a + b
+    mulnet = ForwardNetMul()
+    addnet = ForwardNetAdd()
+    x = Tensor(np.ones([32]), dtype=mstype.float32)
+    y = Tensor(np.ones([32]), dtype=mstype.float32)
+    sens = Tensor(np.ones([32]), dtype=mstype.float32)
+    mulnet.set_grad()
+    addnet.set_grad()
+    out1 = mulnet(x, y)
+    out2 = addnet(x, y)
+    grad_mul = GradOfAllInputs(mulnet)
+    grad_add = GradOfFirstInput(mulnet)
+    grad_mul(x, y, sens)
+    grad_add(x, y, sens)
+
+
+def test_net_inner_grad():
+    class ForwardNetMul(nn.Cell):
+        def __init__(self):
+            super().__init__()
+
+        def construct(self, x, y):
+            a = x * x
+            b = y * y
+            return a * b
+
+    class ForwardNetAdd(nn.Cell):
+        def __init__(self, net):
+            super().__init__()
+            self.net = net
+
+        def construct(self, x, y):
+            a = x + x
+            b = y + y
+            res = self.net(a, b)
+            return res
+    mulnet = ForwardNetMul()
+    addnet = ForwardNetAdd(mulnet)
+    x = Tensor(np.ones([32]), dtype=mstype.float32)
+    y = Tensor(np.ones([32]), dtype=mstype.float32)
+    sens = Tensor(np.ones([32]), dtype=mstype.float32)
+    mulnet.set_grad()
+    addnet.set_grad()
+    out1 = mulnet(x, y)
+    out2 = addnet(x, y)
+    grad_mul = GradOfAllInputs(addnet)
+    grad_add = GradOfAllInputs(mulnet)
+    grad_mul(x, y, sens)
+    grad_add(x, y, sens)
+
+
+def test_net_inner_first_run_grad():
+    class ForwardNetMul(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.z1 = Parameter(Tensor(np.ones([32]) * 2, dtype=mstype.float32), name='z1')
+
+        def construct(self, x, y):
+            a = x * self.z1
+            b = y * y
+            return a * b
+
+    class ForwardNetAdd(nn.Cell):
+        def __init__(self, net):
+            super().__init__()
+            self.net = net
+            self.z2 = Parameter(Tensor(np.ones([32]), dtype=mstype.float32), name='z2')
+            self.z3 = Parameter(Tensor(np.ones([32]), dtype=mstype.float32), name='z3')
+
+        def construct(self, x, y):
+            a = x + x * self.z3
+            b = y + y * self.z2
+            res = self.net(a, b)
+            return res
+    mulnet = ForwardNetMul()
+    addnet = ForwardNetAdd(mulnet)
+    x = Tensor(np.ones([32]), dtype=mstype.float32)
+    y = Tensor(np.ones([32]), dtype=mstype.float32)
+    sens = Tensor(np.ones([32]), dtype=mstype.float32)
+    mulnet.set_grad()
+    addnet.set_grad()
+    out1 = mulnet(x, y)
+    out2 = addnet(x, y)
+    grad_mul = GradOfAllInputs(addnet)
+    grad_add = GradOfFirstInput(mulnet)
+    grad_mul(x, y, sens)
+    grad_add(x, y, sens)
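
Taken together, the graph_context_ stack replaces the old graph_p_ stack plus the top_g_ pointer comparisons: NewGraphInner now always pushes, EndGraphInner pops, and the bottom of the stack is the top graph, so "is this the top cell?" becomes graph_context_.size() == 1 (or graph_context_.empty() before the push). An illustrative Python model of that invariant, not MindSpore code:

    class GraphContext:
        """Toy model of graph_context_: the bottom entry is the top graph."""

        def __init__(self):
            self._stack = []

        def new_graph(self, graph):
            # Pushp() runs for the top graph and nested graphs alike.
            self._stack.append(graph)

        def end_graph(self):
            # Mirrors ad::Grad(curr_g_, resource_, graph_context_.size() == 1):
            # the graph is the top one iff it is the only entry left.
            is_top = len(self._stack) == 1
            graph = self._stack.pop()
            return graph, is_top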