forked from mindspore-Ecosystem/mindspore
!27152 [bug]fix functional grad with grad_position twice bug
Merge pull request !27152 from chenzhuo/grad
This commit is contained in:
commit
aeb2cccbe6
|
@ -153,11 +153,14 @@ class GradOperation : public MetaFuncGraph {
|
||||||
const std::vector<AnfNodePtr> &weight_args = {});
|
const std::vector<AnfNodePtr> &weight_args = {});
|
||||||
|
|
||||||
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
|
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
|
||||||
|
|
||||||
|
// Setter for grad_position_: stores the given string as-is (no parsing or validation happens here).
void set_grad_position(const std::string &grad_position) { grad_position_ = grad_position; }
|
||||||
bool sens_param() const { return sens_param_; }
|
bool sens_param() const { return sens_param_; }
|
||||||
bool get_all_;
|
bool get_all_;
|
||||||
bool get_by_list_;
|
bool get_by_list_;
|
||||||
bool sens_param_;
|
bool sens_param_;
|
||||||
bool get_by_position_;
|
bool get_by_position_;
|
||||||
|
std::string grad_position_;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void GradByParameter(const FuncGraphPtr &k_child, const AnfNodePtr &f_app, const AnfNodePtr &bprop,
|
void GradByParameter(const FuncGraphPtr &k_child, const AnfNodePtr &f_app, const AnfNodePtr &bprop,
|
||||||
|
|
|
@ -2857,7 +2857,7 @@ py::object GradExecutor::CheckAlreadyRun(const prim::GradOperationPtr &grad, con
|
||||||
bool forward_run = false;
|
bool forward_run = false;
|
||||||
// Get cell id and input args info
|
// Get cell id and input args info
|
||||||
const auto &cell_id = GetCellId(cell, args);
|
const auto &cell_id = GetCellId(cell, args);
|
||||||
grad_operation_ = std::to_string(grad->get_all_) + std::to_string(grad->get_by_list_);
|
grad_operation_ = std::to_string(grad->get_all_) + std::to_string(grad->get_by_list_) + grad->grad_position_;
|
||||||
|
|
||||||
std::string input_args_id;
|
std::string input_args_id;
|
||||||
for (size_t i = 0; i < args.size(); ++i) {
|
for (size_t i = 0; i < args.size(); ++i) {
|
||||||
|
@ -3282,6 +3282,10 @@ py::object PynativeExecutor::CheckGraph(const py::object &cell, const py::args &
|
||||||
return grad_executor()->CheckGraph(cell, args);
|
return grad_executor()->CheckGraph(cell, args);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Converts the Python object grad_position to a std::string via py::str and stores it
// on the given grad operation through its set_grad_position setter.
void PynativeExecutor::set_grad_position(const prim::GradOperationPtr &grad, const py::object &grad_position) {
|
||||||
|
// NOTE(review): grad is dereferenced without a visible null check — confirm callers always pass a valid operation.
grad->set_grad_position(std::string(py::str(grad_position)));
|
||||||
|
}
|
||||||
|
|
||||||
py::object PynativeExecutor::CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &cell,
|
py::object PynativeExecutor::CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &cell,
|
||||||
const py::args &args) {
|
const py::args &args) {
|
||||||
return grad_executor()->CheckAlreadyRun(grad, cell, args);
|
return grad_executor()->CheckAlreadyRun(grad, cell, args);
|
||||||
|
@ -3433,6 +3437,7 @@ REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) {
|
||||||
.def("__call__", &PynativeExecutor::Run, "pynative executor run grad graph.")
|
.def("__call__", &PynativeExecutor::Run, "pynative executor run grad graph.")
|
||||||
.def("set_graph_phase", &PynativeExecutor::set_graph_phase, "pynative set graph phase")
|
.def("set_graph_phase", &PynativeExecutor::set_graph_phase, "pynative set graph phase")
|
||||||
.def("grad_flag", &PynativeExecutor::grad_flag, "pynative grad flag")
|
.def("grad_flag", &PynativeExecutor::grad_flag, "pynative grad flag")
|
||||||
|
.def("set_grad_position", &PynativeExecutor::set_grad_position, "set pynative grad position")
|
||||||
.def("set_grad_flag", &PynativeExecutor::set_grad_flag, py::arg("flag") = py::bool_(false),
|
.def("set_grad_flag", &PynativeExecutor::set_grad_flag, py::arg("flag") = py::bool_(false),
|
||||||
"Executor set grad flag.")
|
"Executor set grad flag.")
|
||||||
.def("set_py_exe_path", &PynativeExecutor::set_py_exe_path,
|
.def("set_py_exe_path", &PynativeExecutor::set_py_exe_path,
|
||||||
|
|
|
@ -399,6 +399,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
|
||||||
py::object GradMsFunction(const py::object &out, const py::args &args);
|
py::object GradMsFunction(const py::object &out, const py::args &args);
|
||||||
py::object CheckGraph(const py::object &cell, const py::args &args);
|
py::object CheckGraph(const py::object &cell, const py::args &args);
|
||||||
py::object CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &cell, const py::args &args);
|
py::object CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &cell, const py::args &args);
|
||||||
|
void set_grad_position(const prim::GradOperationPtr &grad, const py::object &grad_position);
|
||||||
py::object Run(const py::object &cell, const py::tuple &args);
|
py::object Run(const py::object &cell, const py::tuple &args);
|
||||||
|
|
||||||
// Used by graph clean
|
// Used by graph clean
|
||||||
|
|
|
@ -485,6 +485,9 @@ class _PynativeExecutor:
|
||||||
def check_run(self, grad, obj, *args, **kwargs):
|
def check_run(self, grad, obj, *args, **kwargs):
|
||||||
return self._executor.check_run(grad, obj, *args, *(kwargs.values()))
|
return self._executor.check_run(grad, obj, *args, *(kwargs.values()))
|
||||||
|
|
||||||
|
# Thin wrapper: passes grad and grad_position to self._executor.set_grad_position and returns its result.
def set_grad_position(self, grad, grad_position):
|
||||||
|
return self._executor.set_grad_position(grad, grad_position)
|
||||||
|
|
||||||
def grad(self, grad, obj, weights, grad_position, *args, **kwargs):
|
def grad(self, grad, obj, weights, grad_position, *args, **kwargs):
|
||||||
self._executor.grad_net(grad, obj, weights, grad_position, *args, *(kwargs.values()))
|
self._executor.grad_net(grad, obj, weights, grad_position, *args, *(kwargs.values()))
|
||||||
|
|
||||||
|
|
|
@ -424,7 +424,7 @@ class _Grad(GradOperation_):
|
||||||
self.pynative_ = False
|
self.pynative_ = False
|
||||||
self.grad_position = None
|
self.grad_position = None
|
||||||
|
|
||||||
def _pynative_forward_run(self, grad, args, kwargs, fn):
|
def _pynative_forward_run(self, grad, args, kwargs, fn, grad_position):
|
||||||
""" Pynative forward run to build grad graph. """
|
""" Pynative forward run to build grad graph. """
|
||||||
new_kwargs = kwargs
|
new_kwargs = kwargs
|
||||||
if self.sens_param:
|
if self.sens_param:
|
||||||
|
@ -469,11 +469,12 @@ class _Grad(GradOperation_):
|
||||||
def after_grad(*args):
|
def after_grad(*args):
|
||||||
return grad_(fn)(*args)
|
return grad_(fn)(*args)
|
||||||
elif self.pynative_:
|
elif self.pynative_:
|
||||||
|
_pynative_executor.set_grad_position(grad_, grad_position)
|
||||||
@_wrap_func
|
@_wrap_func
|
||||||
def after_grad(*args, **kwargs):
|
def after_grad(*args, **kwargs):
|
||||||
if _pynative_executor.check_graph(fn, *args, **kwargs):
|
if _pynative_executor.check_graph(fn, *args, **kwargs):
|
||||||
print("Another grad step is running")
|
print("Another grad step is running")
|
||||||
self._pynative_forward_run(grad_, args, kwargs, fn)
|
self._pynative_forward_run(grad_, args, kwargs, fn, grad_position)
|
||||||
_pynative_executor.grad(grad_, fn, weights, grad_position, *args, **kwargs)
|
_pynative_executor.grad(grad_, fn, weights, grad_position, *args, **kwargs)
|
||||||
out = _pynative_executor(fn, *args, **kwargs)
|
out = _pynative_executor(fn, *args, **kwargs)
|
||||||
_pynative_executor.clear_grad(fn, *args, **kwargs)
|
_pynative_executor.clear_grad(fn, *args, **kwargs)
|
||||||
|
|
|
@ -266,6 +266,7 @@ def jvp(fn, inputs, v):
|
||||||
[28. 49.]]
|
[28. 49.]]
|
||||||
"""
|
"""
|
||||||
jvp_inner = _JvpInner()
|
jvp_inner = _JvpInner()
|
||||||
|
|
||||||
@ms_function
|
@ms_function
|
||||||
def _wrap_container(*arg):
|
def _wrap_container(*arg):
|
||||||
args = arg[1:]
|
args = arg[1:]
|
||||||
|
@ -320,6 +321,7 @@ def vjp(fn, inputs, v):
|
||||||
[ 1.00000000e+00, 1.00000000e+00]]))
|
[ 1.00000000e+00, 1.00000000e+00]]))
|
||||||
"""
|
"""
|
||||||
vjp_inner = _VjpInner()
|
vjp_inner = _VjpInner()
|
||||||
|
|
||||||
@ms_function
|
@ms_function
|
||||||
def wrap_container(*arg):
|
def wrap_container(*arg):
|
||||||
args = arg[:-1]
|
args = arg[:-1]
|
||||||
|
|
|
@ -183,3 +183,21 @@ def test_grad_warp_with_msfunction_graph():
|
||||||
expect_grad = Tensor(np.array([[2, 13], [1, 6]]).astype(np.float32))
|
expect_grad = Tensor(np.array([[2, 13], [1, 6]]).astype(np.float32))
|
||||||
real_grad = grad_warp_with_msfunction(x, y, z)
|
real_grad = grad_warp_with_msfunction(x, y, z)
|
||||||
assert np.allclose(real_grad.asnumpy(), expect_grad.asnumpy())
|
assert np.allclose(real_grad.asnumpy(), expect_grad.asnumpy())
|
||||||
|
|
||||||
|
# NOTE(review): nothing in this test body selects graph mode; presumably it is configured at module level — confirm.
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_grad_with_grad_position_twice_graph():
|
||||||
|
"""
|
||||||
|
Features: Function grad.
|
||||||
|
Description: Test F.grad with function setting grad_position twice in graph mode.
|
||||||
|
Expectation: No exception.
|
||||||
|
"""
|
||||||
|
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||||
|
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||||
|
z = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||||
|
net = MultipleInputsSingleOutputNet()
|
||||||
|
# Call grad on the same net twice with different grad_position values: 0, then (0, 1).
out1 = grad(net, grad_position=0)(x, y, z)
|
||||||
|
out2 = grad(net, grad_position=(0, 1))(x, y, z)
|
||||||
|
# Only the result types are checked (Tensor for the int position, tuple for the tuple position); values are not compared.
assert isinstance(out1, Tensor)
|
||||||
|
assert isinstance(out2, tuple)
|
||||||
|
|
|
@ -184,3 +184,21 @@ def test_grad_warp_with_msfunction_pynative():
|
||||||
expect_grad = Tensor(np.array([[2, 13], [1, 6]]).astype(np.float32))
|
expect_grad = Tensor(np.array([[2, 13], [1, 6]]).astype(np.float32))
|
||||||
real_grad = grad_warp_with_msfunction(x, y, z)
|
real_grad = grad_warp_with_msfunction(x, y, z)
|
||||||
assert np.allclose(real_grad.asnumpy(), expect_grad.asnumpy())
|
assert np.allclose(real_grad.asnumpy(), expect_grad.asnumpy())
|
||||||
|
|
||||||
|
# NOTE(review): nothing in this test body selects pynative mode; presumably it is configured at module level — confirm.
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_grad_with_grad_position_twice_pynative():
|
||||||
|
"""
|
||||||
|
Features: Function grad.
|
||||||
|
Description: Test F.grad with function setting grad_position twice in pynative mode.
|
||||||
|
Expectation: No exception.
|
||||||
|
"""
|
||||||
|
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||||
|
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||||
|
z = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||||
|
net = MultipleInputsSingleOutputNet()
|
||||||
|
# Call grad on the same net twice with different grad_position values: 0, then (0, 1).
out1 = grad(net, grad_position=0)(x, y, z)
|
||||||
|
out2 = grad(net, grad_position=(0, 1))(x, y, z)
|
||||||
|
# Only the result types are checked (Tensor for the int position, tuple for the tuple position); values are not compared.
assert isinstance(out1, Tensor)
|
||||||
|
assert isinstance(out2, tuple)
|
||||||
|
|
Loading…
Reference in New Issue