forked from OSSInnovation/mindspore
!3711 fix topK multi dimention grad func
Merge pull request !3711 from fangzehua/topkgrad
This commit is contained in:
commit
57ce3e5dfc
|
@ -14,6 +14,7 @@
|
|||
# ============================================================================
|
||||
|
||||
"""Define the grad rules of neural network related operations."""
|
||||
import math
|
||||
import numpy as np
|
||||
from mindspore.ops import _selected_grad_ops as SG
|
||||
from mindspore.ops.primitive import constexpr
|
||||
|
@ -628,19 +629,62 @@ def get_bprop_onehot(self):
|
|||
return bprop
|
||||
|
||||
|
||||
@constexpr
|
||||
def _range_op(start, limit, delta, dtype):
|
||||
"""helper function for Grad TopK"""
|
||||
range_op = inner.Range(float(start), float(limit), float(delta))
|
||||
length_input = math.ceil((limit - start) / delta)
|
||||
input_tensor = Tensor(list(range(length_input)), dtype)
|
||||
range_out = range_op(input_tensor)
|
||||
return range_out
|
||||
|
||||
@constexpr
|
||||
def _get_1d_shape(in_shape):
|
||||
"""helper function for Grad TopK"""
|
||||
out_shape = 1
|
||||
for i in in_shape:
|
||||
out_shape *= i
|
||||
return (out_shape,)
|
||||
|
||||
@bprop_getters.register(P.TopK)
|
||||
def get_bprop_top_kv2(self):
|
||||
"""Grad definition for `TopK` operation."""
|
||||
scatter = P.ScatterNd()
|
||||
expand_dims = P.ExpandDims()
|
||||
shape_op = P.Shape()
|
||||
reshape_op = P.Reshape()
|
||||
dtype = P.DType()
|
||||
|
||||
def bprop(input_x, k, out, dout):
|
||||
|
||||
# (n1, n2, ...., n_p), in_lastdim = n_p
|
||||
in_shape = shape_op(input_x)
|
||||
in_lastdim = in_shape[-1]
|
||||
|
||||
# (n_1, ... n_(p-1), k), ind_lastdim = k
|
||||
indices = out[1]
|
||||
indices = expand_dims(indices, -1)
|
||||
updates = dout[0]
|
||||
shapes = shape_op(input_x)
|
||||
return scatter(indices, updates, shapes), zeros_like(k)
|
||||
ind_shape = shape_op(indices)
|
||||
ind_lastdim = ind_shape[-1]
|
||||
|
||||
# (n_1*n_2..*n_(p-1), k), outerdim = n_1*n_2..*n_(p-1)
|
||||
ind_2d = reshape_op(indices, (-1, ind_lastdim))
|
||||
outerdim = shape_op(ind_2d)[0]
|
||||
|
||||
# [0, outterdim, 2*outerdim, ..., (k-1)*outerdim]
|
||||
indices_dtype = dtype(indices)
|
||||
range_flatten_index = _range_op(0, outerdim * in_lastdim, in_lastdim, indices_dtype)
|
||||
|
||||
# expand_dims to (k, 1), then broadcast
|
||||
ind = reshape_op(ind_2d + expand_dims(range_flatten_index, -1), (-1,))
|
||||
in_shape_1d = _get_1d_shape(in_shape)
|
||||
|
||||
out_grad = reshape_op(
|
||||
scatter(
|
||||
expand_dims(ind, -1),
|
||||
reshape_op(dout[0], (-1,)),
|
||||
in_shape_1d),
|
||||
in_shape)
|
||||
return out_grad, zeros_like(k)
|
||||
|
||||
return bprop
|
||||
|
||||
|
|
Loading…
Reference in New Issue