forked from mindspore-Ecosystem/mindspore
!2268 add bprop for ScatterMax
Merge pull request !2268 from yanzhenxiang2020/open_ScatterMax_bprop
This commit is contained in:
commit
002029ff12
|
@ -496,6 +496,17 @@ def get_bprop_tensor_scatter_update(self):
|
|||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(P.ScatterMax)
|
||||
def get_bprop_scatter_max(self):
|
||||
"""Generate bprop for ScatterMax"""
|
||||
gather = P.GatherV2()
|
||||
|
||||
def bprop(x, indices, update, out, dout):
|
||||
return dout, zeros_like(indices), gather(dout, indices, 0)
|
||||
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(P.Argmax)
|
||||
def get_bprop_argmax(self):
|
||||
"""Generate bprop for Argmax"""
|
||||
|
|
Loading…
Reference in New Issue