!35610 [kernel]fix apply proximal adagrad vmap bug

Merge pull request !35610 from 张学同/applyproximaladagrad_gpu
This commit is contained in:
i-robot 2022-06-09 06:36:08 +00:00 committed by Gitee
commit ccc681eec8
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 4 additions and 4 deletions

View File

@ -49,8 +49,8 @@ def get_apply_proximal_adagrad_rule(prim, axis_size):
ValueError("The source axis of `var` is None, but the source "
"axis of `accum/lr/l1/l2/grad` is not None. The execution order of "
"operator `{}` cannot be guaranteed.".format(prim_name))
out = prim(var, accum, lr, l1, l2, grad, u_monad)
return (out, None)
var, accum = prim(var, accum, lr, l1, l2, grad, u_monad)
return (var, None), (accum, None)
if var_dim != 0 or accum_dim != var_dim:
raise ValueError("For `{}`, the source axis of `var` must be equal to `accum`, and not equal to 0, "
@ -61,8 +61,8 @@ def get_apply_proximal_adagrad_rule(prim, axis_size):
l2 = _bdim_at_front(l2, l2_dim, axis_size)
grad = _bdim_at_front(grad, grad_dim, axis_size)
out = prim(var, accum, lr, l1, l2, grad, u_monad)
return (out, 0)
var, accum = prim(var, accum, lr, l1, l2, grad, u_monad)
return (var, 0), (accum, 0)
return vmap_rule