From 287bcdd78386718d46345d234f3ab5908da86ef9 Mon Sep 17 00:00:00 2001
From: chenzhuo
Date: Thu, 2 Dec 2021 17:04:27 +0800
Subject: [PATCH] fix grad twice position bug

---
 .../frontend/operator/composite/composite.h    |  3 +++
 .../pipeline/pynative/pynative_execute.cc      |  7 ++++++-
 .../ccsrc/pipeline/pynative/pynative_execute.h |  1 +
 mindspore/common/api.py                        |  3 +++
 mindspore/ops/composite/base.py                |  5 +++--
 mindspore/ops/functional.py                    |  2 ++
 tests/st/gradient/test_grad_graph.py           | 18 ++++++++++++++++++
 tests/st/gradient/test_grad_pynative.py        | 18 ++++++++++++++++++
 8 files changed, 54 insertions(+), 3 deletions(-)

diff --git a/mindspore/ccsrc/frontend/operator/composite/composite.h b/mindspore/ccsrc/frontend/operator/composite/composite.h
index 5444663915d..5df4fb3039f 100644
--- a/mindspore/ccsrc/frontend/operator/composite/composite.h
+++ b/mindspore/ccsrc/frontend/operator/composite/composite.h
@@ -153,11 +153,14 @@ class GradOperation : public MetaFuncGraph {
                        const std::vector<AnfNodePtr> &weight_args = {});

   FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
+
+  void set_grad_position(const std::string &grad_position) { grad_position_ = grad_position; }
   bool sens_param() const { return sens_param_; }
   bool get_all_;
   bool get_by_list_;
   bool sens_param_;
   bool get_by_position_;
+  std::string grad_position_;

  private:
   void GradByParameter(const FuncGraphPtr &k_child, const AnfNodePtr &f_app, const AnfNodePtr &bprop,
diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
index 62f38099883..adaef343b82 100644
--- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
+++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
@@ -2844,7 +2844,7 @@ py::object GradExecutor::CheckAlreadyRun(const prim::GradOperationPtr &grad, con
   bool forward_run = false;
   // Get cell id and input args info
   const auto &cell_id = GetCellId(cell, args);
-  grad_operation_ = std::to_string(grad->get_all_) + std::to_string(grad->get_by_list_);
+  grad_operation_ = std::to_string(grad->get_all_) + std::to_string(grad->get_by_list_) + grad->grad_position_;

   std::string input_args_id;
   for (size_t i = 0; i < args.size(); ++i) {
@@ -3269,6 +3269,10 @@ py::object PynativeExecutor::CheckGraph(const py::object &cell, const py::args &
   return grad_executor()->CheckGraph(cell, args);
 }

+void PynativeExecutor::set_grad_position(const prim::GradOperationPtr &grad, const py::object &grad_position) {
+  grad->set_grad_position(std::string(py::str(grad_position)));
+}
+
 py::object PynativeExecutor::CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &cell,
                                              const py::args &args) {
   return grad_executor()->CheckAlreadyRun(grad, cell, args);
@@ -3420,6 +3424,7 @@ REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) {
                            .def("__call__", &PynativeExecutor::Run, "pynative executor run grad graph.")
                            .def("set_graph_phase", &PynativeExecutor::set_graph_phase, "pynative set graph phase")
                            .def("grad_flag", &PynativeExecutor::grad_flag, "pynative grad flag")
+                           .def("set_grad_position", &PynativeExecutor::set_grad_position, "set pynative grad position")
                            .def("set_grad_flag", &PynativeExecutor::set_grad_flag, py::arg("flag") = py::bool_(false),
                                 "Executor set grad flag.")
                            .def("set_py_exe_path", &PynativeExecutor::set_py_exe_path,
diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h
index d8ad78fa0f9..1a050834b1a 100644
--- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h
+++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h
@@ -399,6 +399,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   py::object GradMsFunction(const py::object &out, const py::args &args);
   py::object CheckGraph(const py::object &cell, const py::args &args);
   py::object CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &cell, const py::args &args);
+  void set_grad_position(const prim::GradOperationPtr &grad, const py::object &grad_position);
   py::object Run(const py::object &cell, const py::tuple &args);

   // Used by graph clean
diff --git a/mindspore/common/api.py b/mindspore/common/api.py
index c2921c7b4f2..5215c7182fc 100644
--- a/mindspore/common/api.py
+++ b/mindspore/common/api.py
@@ -482,6 +482,9 @@ class _PynativeExecutor:
     def check_run(self, grad, obj, *args, **kwargs):
         return self._executor.check_run(grad, obj, *args, *(kwargs.values()))

+    def set_grad_position(self, grad, grad_position):
+        return self._executor.set_grad_position(grad, grad_position)
+
     def grad(self, grad, obj, weights, grad_position, *args, **kwargs):
         self._executor.grad_net(grad, obj, weights, grad_position, *args, *(kwargs.values()))

diff --git a/mindspore/ops/composite/base.py b/mindspore/ops/composite/base.py
index 8b843aae659..72db5b5d9f9 100644
--- a/mindspore/ops/composite/base.py
+++ b/mindspore/ops/composite/base.py
@@ -424,7 +424,7 @@ class _Grad(GradOperation_):
         self.pynative_ = False
         self.grad_position = None

-    def _pynative_forward_run(self, grad, args, kwargs, fn):
+    def _pynative_forward_run(self, grad, args, kwargs, fn, grad_position):
         """Pynative forward run to build grad graph."""
         new_kwargs = kwargs
         if self.sens_param:
@@ -469,11 +469,12 @@ class _Grad(GradOperation_):
             def after_grad(*args):
                 return grad_(fn)(*args)
         elif self.pynative_:
+            _pynative_executor.set_grad_position(grad_, grad_position)
             @_wrap_func
             def after_grad(*args, **kwargs):
                 if _pynative_executor.check_graph(fn, *args, **kwargs):
                     print("Another grad step is running")
-                self._pynative_forward_run(grad_, args, kwargs, fn)
+                self._pynative_forward_run(grad_, args, kwargs, fn, grad_position)
                 _pynative_executor.grad(grad_, fn, weights, grad_position, *args, **kwargs)
                 out = _pynative_executor(fn, *args, **kwargs)
                 _pynative_executor.clear_grad(fn, *args, **kwargs)
diff --git a/mindspore/ops/functional.py b/mindspore/ops/functional.py
index 164f38954e3..8a97c0fcb78 100644
--- a/mindspore/ops/functional.py
+++ b/mindspore/ops/functional.py
@@ -265,6 +265,7 @@ def jvp(fn, inputs, v):
          [28. 49.]]
     """
     jvp_inner = _JvpInner()
+
     @ms_function
     def _wrap_container(*arg):
         args = arg[1:]
@@ -318,6 +319,7 @@ def vjp(fn, inputs, v):
         [ 1.00000000e+00,  1.00000000e+00]]))
     """
     vjp_inner = _VjpInner()
+
     @ms_function
     def wrap_container(*arg):
         args = arg[:-1]
diff --git a/tests/st/gradient/test_grad_graph.py b/tests/st/gradient/test_grad_graph.py
index fa50ac09e40..294e375466e 100644
--- a/tests/st/gradient/test_grad_graph.py
+++ b/tests/st/gradient/test_grad_graph.py
@@ -183,3 +183,21 @@ def test_grad_warp_with_msfunction_graph():
     expect_grad = Tensor(np.array([[2, 13], [1, 6]]).astype(np.float32))
     real_grad = grad_warp_with_msfunction(x, y, z)
     assert np.allclose(real_grad.asnumpy(), expect_grad.asnumpy())
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_grad_with_grad_position_twice_graph():
+    """
+    Features: Function grad.
+    Description: Test F.grad with function setting grad_position twice in graph mode.
+    Expectation: No exception.
+    """
+    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
+    y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
+    z = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
+    net = MultipleInputsSingleOutputNet()
+    out1 = grad(net, grad_position=0)(x, y, z)
+    out2 = grad(net, grad_position=(0, 1))(x, y, z)
+    assert isinstance(out1, Tensor)
+    assert isinstance(out2, tuple)
diff --git a/tests/st/gradient/test_grad_pynative.py b/tests/st/gradient/test_grad_pynative.py
index 437a77dac6a..2dc47b827bb 100644
--- a/tests/st/gradient/test_grad_pynative.py
+++ b/tests/st/gradient/test_grad_pynative.py
@@ -184,3 +184,21 @@ def test_grad_warp_with_msfunction_pynative():
     expect_grad = Tensor(np.array([[2, 13], [1, 6]]).astype(np.float32))
     real_grad = grad_warp_with_msfunction(x, y, z)
     assert np.allclose(real_grad.asnumpy(), expect_grad.asnumpy())
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_grad_with_grad_position_twice_pynative():
+    """
+    Features: Function grad.
+    Description: Test F.grad with function setting grad_position twice in pynative mode.
+    Expectation: No exception.
+    """
+    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
+    y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
+    z = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
+    net = MultipleInputsSingleOutputNet()
+    out1 = grad(net, grad_position=0)(x, y, z)
+    out2 = grad(net, grad_position=(0, 1))(x, y, z)
+    assert isinstance(out1, Tensor)
+    assert isinstance(out2, tuple)
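
Note (not part of the patch): a minimal sketch of the user-facing scenario this change addresses, mirroring the new ST cases above. In PyNative mode the cached grad key (grad_operation_) was built only from get_all_ and get_by_list_, so calling grad on the same network twice with different grad_position values could reuse the graph built for the first call; with grad_position folded into the key, each call gets the gradients for its own positions. The Net class, input values, and import style below are illustrative assumptions, not taken from the patch.

import numpy as np
from mindspore import Tensor, context, nn
from mindspore.ops import functional as F

context.set_context(mode=context.PYNATIVE_MODE)


class Net(nn.Cell):
    """Illustrative two-input network (assumed, not from the patch)."""
    def construct(self, x, y):
        return x * y


x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[5, 6], [7, 8]]).astype(np.float32))
net = Net()

# Gradient w.r.t. input 0 only: a single Tensor.
out1 = F.grad(net, grad_position=0)(x, y)
# Same net, different grad_position: before the fix the cached grad graph for
# position 0 could be reused; with grad_position in the cache key this call
# returns a tuple of gradients for inputs 0 and 1.
out2 = F.grad(net, grad_position=(0, 1))(x, y)
assert isinstance(out1, Tensor)
assert isinstance(out2, tuple)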