!27152 [bug]fix functional grad with grad_position twice bug
Merge pull request !27152 from chenzhuo/grad
This commit is contained in:
commit
aeb2cccbe6
|
@ -153,11 +153,14 @@ class GradOperation : public MetaFuncGraph {
|
|||
const std::vector<AnfNodePtr> &weight_args = {});
|
||||
|
||||
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
|
||||
|
||||
void set_grad_position(const std::string &grad_position) { grad_position_ = grad_position; }
|
||||
bool sens_param() const { return sens_param_; }
|
||||
bool get_all_;
|
||||
bool get_by_list_;
|
||||
bool sens_param_;
|
||||
bool get_by_position_;
|
||||
std::string grad_position_;
|
||||
|
||||
private:
|
||||
void GradByParameter(const FuncGraphPtr &k_child, const AnfNodePtr &f_app, const AnfNodePtr &bprop,
|
||||
|
|
|
@ -2857,7 +2857,7 @@ py::object GradExecutor::CheckAlreadyRun(const prim::GradOperationPtr &grad, con
|
|||
bool forward_run = false;
|
||||
// Get cell id and input args info
|
||||
const auto &cell_id = GetCellId(cell, args);
|
||||
grad_operation_ = std::to_string(grad->get_all_) + std::to_string(grad->get_by_list_);
|
||||
grad_operation_ = std::to_string(grad->get_all_) + std::to_string(grad->get_by_list_) + grad->grad_position_;
|
||||
|
||||
std::string input_args_id;
|
||||
for (size_t i = 0; i < args.size(); ++i) {
|
||||
|
@ -3282,6 +3282,10 @@ py::object PynativeExecutor::CheckGraph(const py::object &cell, const py::args &
|
|||
return grad_executor()->CheckGraph(cell, args);
|
||||
}
|
||||
|
||||
void PynativeExecutor::set_grad_position(const prim::GradOperationPtr &grad, const py::object &grad_position) {
|
||||
grad->set_grad_position(std::string(py::str(grad_position)));
|
||||
}
|
||||
|
||||
py::object PynativeExecutor::CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &cell,
|
||||
const py::args &args) {
|
||||
return grad_executor()->CheckAlreadyRun(grad, cell, args);
|
||||
|
@ -3433,6 +3437,7 @@ REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) {
|
|||
.def("__call__", &PynativeExecutor::Run, "pynative executor run grad graph.")
|
||||
.def("set_graph_phase", &PynativeExecutor::set_graph_phase, "pynative set graph phase")
|
||||
.def("grad_flag", &PynativeExecutor::grad_flag, "pynative grad flag")
|
||||
.def("set_grad_position", &PynativeExecutor::set_grad_position, "set pynative grad position")
|
||||
.def("set_grad_flag", &PynativeExecutor::set_grad_flag, py::arg("flag") = py::bool_(false),
|
||||
"Executor set grad flag.")
|
||||
.def("set_py_exe_path", &PynativeExecutor::set_py_exe_path,
|
||||
|
|
|
@ -399,6 +399,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
|
|||
py::object GradMsFunction(const py::object &out, const py::args &args);
|
||||
py::object CheckGraph(const py::object &cell, const py::args &args);
|
||||
py::object CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &cell, const py::args &args);
|
||||
void set_grad_position(const prim::GradOperationPtr &grad, const py::object &grad_position);
|
||||
py::object Run(const py::object &cell, const py::tuple &args);
|
||||
|
||||
// Used by graph clean
|
||||
|
|
|
@ -485,6 +485,9 @@ class _PynativeExecutor:
|
|||
def check_run(self, grad, obj, *args, **kwargs):
|
||||
return self._executor.check_run(grad, obj, *args, *(kwargs.values()))
|
||||
|
||||
def set_grad_position(self, grad, grad_position):
|
||||
return self._executor.set_grad_position(grad, grad_position)
|
||||
|
||||
def grad(self, grad, obj, weights, grad_position, *args, **kwargs):
|
||||
self._executor.grad_net(grad, obj, weights, grad_position, *args, *(kwargs.values()))
|
||||
|
||||
|
|
|
@ -424,7 +424,7 @@ class _Grad(GradOperation_):
|
|||
self.pynative_ = False
|
||||
self.grad_position = None
|
||||
|
||||
def _pynative_forward_run(self, grad, args, kwargs, fn):
|
||||
def _pynative_forward_run(self, grad, args, kwargs, fn, grad_position):
|
||||
""" Pynative forward run to build grad graph. """
|
||||
new_kwargs = kwargs
|
||||
if self.sens_param:
|
||||
|
@ -469,11 +469,12 @@ class _Grad(GradOperation_):
|
|||
def after_grad(*args):
|
||||
return grad_(fn)(*args)
|
||||
elif self.pynative_:
|
||||
_pynative_executor.set_grad_position(grad_, grad_position)
|
||||
@_wrap_func
|
||||
def after_grad(*args, **kwargs):
|
||||
if _pynative_executor.check_graph(fn, *args, **kwargs):
|
||||
print("Another grad step is running")
|
||||
self._pynative_forward_run(grad_, args, kwargs, fn)
|
||||
self._pynative_forward_run(grad_, args, kwargs, fn, grad_position)
|
||||
_pynative_executor.grad(grad_, fn, weights, grad_position, *args, **kwargs)
|
||||
out = _pynative_executor(fn, *args, **kwargs)
|
||||
_pynative_executor.clear_grad(fn, *args, **kwargs)
|
||||
|
|
|
@ -266,6 +266,7 @@ def jvp(fn, inputs, v):
|
|||
[28. 49.]]
|
||||
"""
|
||||
jvp_inner = _JvpInner()
|
||||
|
||||
@ms_function
|
||||
def _wrap_container(*arg):
|
||||
args = arg[1:]
|
||||
|
@ -320,6 +321,7 @@ def vjp(fn, inputs, v):
|
|||
[ 1.00000000e+00, 1.00000000e+00]]))
|
||||
"""
|
||||
vjp_inner = _VjpInner()
|
||||
|
||||
@ms_function
|
||||
def wrap_container(*arg):
|
||||
args = arg[:-1]
|
||||
|
|
|
@ -183,3 +183,21 @@ def test_grad_warp_with_msfunction_graph():
|
|||
expect_grad = Tensor(np.array([[2, 13], [1, 6]]).astype(np.float32))
|
||||
real_grad = grad_warp_with_msfunction(x, y, z)
|
||||
assert np.allclose(real_grad.asnumpy(), expect_grad.asnumpy())
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_grad_with_grad_position_twice_graph():
|
||||
"""
|
||||
Features: Function grad.
|
||||
Description: Test F.grad with function setting grad_position twice in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
z = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
net = MultipleInputsSingleOutputNet()
|
||||
out1 = grad(net, grad_position=0)(x, y, z)
|
||||
out2 = grad(net, grad_position=(0, 1))(x, y, z)
|
||||
assert isinstance(out1, Tensor)
|
||||
assert isinstance(out2, tuple)
|
||||
|
|
|
@ -184,3 +184,21 @@ def test_grad_warp_with_msfunction_pynative():
|
|||
expect_grad = Tensor(np.array([[2, 13], [1, 6]]).astype(np.float32))
|
||||
real_grad = grad_warp_with_msfunction(x, y, z)
|
||||
assert np.allclose(real_grad.asnumpy(), expect_grad.asnumpy())
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_grad_with_grad_position_twice_pynative():
|
||||
"""
|
||||
Features: Function grad.
|
||||
Description: Test F.grad with function setting grad_position twice in pynative mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
z = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
net = MultipleInputsSingleOutputNet()
|
||||
out1 = grad(net, grad_position=0)(x, y, z)
|
||||
out2 = grad(net, grad_position=(0, 1))(x, y, z)
|
||||
assert isinstance(out1, Tensor)
|
||||
assert isinstance(out2, tuple)
|
||||
|
|
Loading…
Reference in New Issue