forked from mindspore-Ecosystem/mindspore
!1553 change hook function grad input to tuple
Merge pull request !1553 from wangqiuliang/r0.3
This commit is contained in:
commit
431bc8bf4b
|
@ -624,8 +624,8 @@ BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) {
|
|||
if (_hook_grad.find(cell_id) != _hook_grad.end()) {
|
||||
py::tuple hook_args = py::tuple(3);
|
||||
hook_args[0] = cell_id;
|
||||
hook_args[1] = _hook_grad[cell_id];
|
||||
hook_args[2] = py_args[2];
|
||||
hook_args[1] = py::make_tuple(_hook_grad[cell_id]);
|
||||
hook_args[2] = py::make_tuple(py_args[2]);
|
||||
py::function fn_hook = prim->hook();
|
||||
obj = fn_hook(*hook_args);
|
||||
if (py::isinstance<py::none>(obj)) {
|
||||
|
@ -638,7 +638,7 @@ BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) {
|
|||
}
|
||||
} else {
|
||||
py::function fn_hook = prim->hook();
|
||||
obj = fn_hook(py_args[2]);
|
||||
obj = fn_hook(py::make_tuple(py_args[2]));
|
||||
if (py::isinstance<py::none>(obj)) {
|
||||
obj = py_args[2];
|
||||
}
|
||||
|
|
|
@ -30,13 +30,13 @@ def weight_variable():
|
|||
|
||||
def cell_hook_function(cell_id, grad_input, grad_output):
|
||||
print(cell_id)
|
||||
assert(grad_output.asnumpy().shape == (32, 6, 14, 14))
|
||||
assert(grad_input.asnumpy().shape == (32, 16, 10, 10))
|
||||
assert(grad_output[0].asnumpy().shape == (32, 6, 14, 14))
|
||||
assert(grad_input[0].asnumpy().shape == (32, 16, 10, 10))
|
||||
|
||||
|
||||
def var_hook_function(grad_out):
|
||||
print("grad:", grad_out)
|
||||
assert(grad_out.asnumpy().shape == (32, 120))
|
||||
assert(grad_out[0].asnumpy().shape == (32, 120))
|
||||
|
||||
|
||||
class LeNet5(nn.Cell):
|
||||
|
|
Loading…
Reference in New Issue