!27152 [bug]fix functional grad with grad_position twice bug

Merge pull request !27152 from chenzhuo/grad
This commit is contained in:
i-robot 2021-12-06 10:56:46 +00:00 committed by Gitee
commit aeb2cccbe6
8 changed files with 54 additions and 3 deletions

View File

@ -153,11 +153,14 @@ class GradOperation : public MetaFuncGraph {
const std::vector<AnfNodePtr> &weight_args = {}); const std::vector<AnfNodePtr> &weight_args = {});
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
void set_grad_position(const std::string &grad_position) { grad_position_ = grad_position; }
bool sens_param() const { return sens_param_; } bool sens_param() const { return sens_param_; }
bool get_all_; bool get_all_;
bool get_by_list_; bool get_by_list_;
bool sens_param_; bool sens_param_;
bool get_by_position_; bool get_by_position_;
std::string grad_position_;
private: private:
void GradByParameter(const FuncGraphPtr &k_child, const AnfNodePtr &f_app, const AnfNodePtr &bprop, void GradByParameter(const FuncGraphPtr &k_child, const AnfNodePtr &f_app, const AnfNodePtr &bprop,

View File

@ -2857,7 +2857,7 @@ py::object GradExecutor::CheckAlreadyRun(const prim::GradOperationPtr &grad, con
bool forward_run = false; bool forward_run = false;
// Get cell id and input args info // Get cell id and input args info
const auto &cell_id = GetCellId(cell, args); const auto &cell_id = GetCellId(cell, args);
grad_operation_ = std::to_string(grad->get_all_) + std::to_string(grad->get_by_list_); grad_operation_ = std::to_string(grad->get_all_) + std::to_string(grad->get_by_list_) + grad->grad_position_;
std::string input_args_id; std::string input_args_id;
for (size_t i = 0; i < args.size(); ++i) { for (size_t i = 0; i < args.size(); ++i) {
@ -3282,6 +3282,10 @@ py::object PynativeExecutor::CheckGraph(const py::object &cell, const py::args &
return grad_executor()->CheckGraph(cell, args); return grad_executor()->CheckGraph(cell, args);
} }
void PynativeExecutor::set_grad_position(const prim::GradOperationPtr &grad, const py::object &grad_position) {
grad->set_grad_position(std::string(py::str(grad_position)));
}
py::object PynativeExecutor::CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &cell, py::object PynativeExecutor::CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &cell,
const py::args &args) { const py::args &args) {
return grad_executor()->CheckAlreadyRun(grad, cell, args); return grad_executor()->CheckAlreadyRun(grad, cell, args);
@ -3433,6 +3437,7 @@ REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) {
.def("__call__", &PynativeExecutor::Run, "pynative executor run grad graph.") .def("__call__", &PynativeExecutor::Run, "pynative executor run grad graph.")
.def("set_graph_phase", &PynativeExecutor::set_graph_phase, "pynative set graph phase") .def("set_graph_phase", &PynativeExecutor::set_graph_phase, "pynative set graph phase")
.def("grad_flag", &PynativeExecutor::grad_flag, "pynative grad flag") .def("grad_flag", &PynativeExecutor::grad_flag, "pynative grad flag")
.def("set_grad_position", &PynativeExecutor::set_grad_position, "set pynative grad position")
.def("set_grad_flag", &PynativeExecutor::set_grad_flag, py::arg("flag") = py::bool_(false), .def("set_grad_flag", &PynativeExecutor::set_grad_flag, py::arg("flag") = py::bool_(false),
"Executor set grad flag.") "Executor set grad flag.")
.def("set_py_exe_path", &PynativeExecutor::set_py_exe_path, .def("set_py_exe_path", &PynativeExecutor::set_py_exe_path,

View File

@ -399,6 +399,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
py::object GradMsFunction(const py::object &out, const py::args &args); py::object GradMsFunction(const py::object &out, const py::args &args);
py::object CheckGraph(const py::object &cell, const py::args &args); py::object CheckGraph(const py::object &cell, const py::args &args);
py::object CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &cell, const py::args &args); py::object CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &cell, const py::args &args);
void set_grad_position(const prim::GradOperationPtr &grad, const py::object &grad_position);
py::object Run(const py::object &cell, const py::tuple &args); py::object Run(const py::object &cell, const py::tuple &args);
// Used by graph clean // Used by graph clean

View File

@ -485,6 +485,9 @@ class _PynativeExecutor:
def check_run(self, grad, obj, *args, **kwargs): def check_run(self, grad, obj, *args, **kwargs):
return self._executor.check_run(grad, obj, *args, *(kwargs.values())) return self._executor.check_run(grad, obj, *args, *(kwargs.values()))
def set_grad_position(self, grad, grad_position):
return self._executor.set_grad_position(grad, grad_position)
def grad(self, grad, obj, weights, grad_position, *args, **kwargs): def grad(self, grad, obj, weights, grad_position, *args, **kwargs):
self._executor.grad_net(grad, obj, weights, grad_position, *args, *(kwargs.values())) self._executor.grad_net(grad, obj, weights, grad_position, *args, *(kwargs.values()))

View File

@ -424,7 +424,7 @@ class _Grad(GradOperation_):
self.pynative_ = False self.pynative_ = False
self.grad_position = None self.grad_position = None
def _pynative_forward_run(self, grad, args, kwargs, fn): def _pynative_forward_run(self, grad, args, kwargs, fn, grad_position):
""" Pynative forward run to build grad graph. """ """ Pynative forward run to build grad graph. """
new_kwargs = kwargs new_kwargs = kwargs
if self.sens_param: if self.sens_param:
@ -469,11 +469,12 @@ class _Grad(GradOperation_):
def after_grad(*args): def after_grad(*args):
return grad_(fn)(*args) return grad_(fn)(*args)
elif self.pynative_: elif self.pynative_:
_pynative_executor.set_grad_position(grad_, grad_position)
@_wrap_func @_wrap_func
def after_grad(*args, **kwargs): def after_grad(*args, **kwargs):
if _pynative_executor.check_graph(fn, *args, **kwargs): if _pynative_executor.check_graph(fn, *args, **kwargs):
print("Another grad step is running") print("Another grad step is running")
self._pynative_forward_run(grad_, args, kwargs, fn) self._pynative_forward_run(grad_, args, kwargs, fn, grad_position)
_pynative_executor.grad(grad_, fn, weights, grad_position, *args, **kwargs) _pynative_executor.grad(grad_, fn, weights, grad_position, *args, **kwargs)
out = _pynative_executor(fn, *args, **kwargs) out = _pynative_executor(fn, *args, **kwargs)
_pynative_executor.clear_grad(fn, *args, **kwargs) _pynative_executor.clear_grad(fn, *args, **kwargs)

View File

@ -266,6 +266,7 @@ def jvp(fn, inputs, v):
[28. 49.]] [28. 49.]]
""" """
jvp_inner = _JvpInner() jvp_inner = _JvpInner()
@ms_function @ms_function
def _wrap_container(*arg): def _wrap_container(*arg):
args = arg[1:] args = arg[1:]
@ -320,6 +321,7 @@ def vjp(fn, inputs, v):
[ 1.00000000e+00, 1.00000000e+00]])) [ 1.00000000e+00, 1.00000000e+00]]))
""" """
vjp_inner = _VjpInner() vjp_inner = _VjpInner()
@ms_function @ms_function
def wrap_container(*arg): def wrap_container(*arg):
args = arg[:-1] args = arg[:-1]

View File

@ -183,3 +183,21 @@ def test_grad_warp_with_msfunction_graph():
expect_grad = Tensor(np.array([[2, 13], [1, 6]]).astype(np.float32)) expect_grad = Tensor(np.array([[2, 13], [1, 6]]).astype(np.float32))
real_grad = grad_warp_with_msfunction(x, y, z) real_grad = grad_warp_with_msfunction(x, y, z)
assert np.allclose(real_grad.asnumpy(), expect_grad.asnumpy()) assert np.allclose(real_grad.asnumpy(), expect_grad.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_with_grad_position_twice_graph():
"""
Features: Function grad.
Description: Test F.grad with function setting grad_position twice in graph mode.
Expectation: No exception.
"""
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
z = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
net = MultipleInputsSingleOutputNet()
out1 = grad(net, grad_position=0)(x, y, z)
out2 = grad(net, grad_position=(0, 1))(x, y, z)
assert isinstance(out1, Tensor)
assert isinstance(out2, tuple)

View File

@ -184,3 +184,21 @@ def test_grad_warp_with_msfunction_pynative():
expect_grad = Tensor(np.array([[2, 13], [1, 6]]).astype(np.float32)) expect_grad = Tensor(np.array([[2, 13], [1, 6]]).astype(np.float32))
real_grad = grad_warp_with_msfunction(x, y, z) real_grad = grad_warp_with_msfunction(x, y, z)
assert np.allclose(real_grad.asnumpy(), expect_grad.asnumpy()) assert np.allclose(real_grad.asnumpy(), expect_grad.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_with_grad_position_twice_pynative():
"""
Features: Function grad.
Description: Test F.grad with function setting grad_position twice in pynative mode.
Expectation: No exception.
"""
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
z = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
net = MultipleInputsSingleOutputNet()
out1 = grad(net, grad_position=0)(x, y, z)
out2 = grad(net, grad_position=(0, 1))(x, y, z)
assert isinstance(out1, Tensor)
assert isinstance(out2, tuple)