forked from mindspore-Ecosystem/mindspore
!35610 [kernel]fix apply proximal adagrad vmap bug
Merge pull request !35610 from 张学同/applyproximaladagrad_gpu
This commit is contained in:
commit
ccc681eec8
|
@ -49,8 +49,8 @@ def get_apply_proximal_adagrad_rule(prim, axis_size):
|
|||
ValueError("The source axis of `var` is None, but the source "
|
||||
"axis of `accum/lr/l1/l2/grad` is not None. The execution order of "
|
||||
"operator `{}` cannot be guaranteed.".format(prim_name))
|
||||
out = prim(var, accum, lr, l1, l2, grad, u_monad)
|
||||
return (out, None)
|
||||
var, accum = prim(var, accum, lr, l1, l2, grad, u_monad)
|
||||
return (var, None), (accum, None)
|
||||
|
||||
if var_dim != 0 or accum_dim != var_dim:
|
||||
raise ValueError("For `{}`, the source axis of `var` must be equal to `accum`, and not equal to 0, "
|
||||
|
@ -61,8 +61,8 @@ def get_apply_proximal_adagrad_rule(prim, axis_size):
|
|||
l2 = _bdim_at_front(l2, l2_dim, axis_size)
|
||||
grad = _bdim_at_front(grad, grad_dim, axis_size)
|
||||
|
||||
out = prim(var, accum, lr, l1, l2, grad, u_monad)
|
||||
return (out, 0)
|
||||
var, accum = prim(var, accum, lr, l1, l2, grad, u_monad)
|
||||
return (var, 0), (accum, 0)
|
||||
|
||||
return vmap_rule
|
||||
|
||||
|
|
Loading…
Reference in New Issue