!3711 fix topK multi dimension grad func

Merge pull request !3711 from fangzehua/topkgrad
This commit is contained in:
mindspore-ci-bot 2020-07-30 14:50:06 +08:00 committed by Gitee
commit 57ce3e5dfc
1 changed files with 48 additions and 4 deletions

View File

@@ -14,6 +14,7 @@
# ============================================================================
"""Define the grad rules of neural network related operations."""
import math
import numpy as np
from mindspore.ops import _selected_grad_ops as SG
from mindspore.ops.primitive import constexpr
@@ -628,19 +629,62 @@ def get_bprop_onehot(self):
return bprop
@constexpr
def _range_op(start, limit, delta, dtype):
    """Helper for Grad TopK: build a [start, limit) tensor with step `delta`."""
    num_elems = math.ceil((limit - start) / delta)
    seed_tensor = Tensor(list(range(num_elems)), dtype)
    make_range = inner.Range(float(start), float(limit), float(delta))
    return make_range(seed_tensor)
@constexpr
def _get_1d_shape(in_shape):
    """Helper for Grad TopK: flattened 1-D shape with the same element count."""
    total = 1
    for dim in in_shape:
        total *= dim
    return (total,)
@bprop_getters.register(P.TopK)
def get_bprop_top_kv2(self):
    """Grad definition for `TopK` operation.

    Scatters dout[0] back to the positions that TopK selected, for inputs
    of arbitrary rank: the input is flattened to 1-D, each row's indices
    are offset by that row's flat start position, a single ScatterNd
    writes the gradients, and the result is reshaped back.
    """
    scatter = P.ScatterNd()
    expand_dims = P.ExpandDims()
    shape_op = P.Shape()
    reshape_op = P.Reshape()
    dtype = P.DType()

    def bprop(input_x, k, out, dout):
        # input_x shape: (n_1, n_2, ..., n_p); in_lastdim = n_p
        in_shape = shape_op(input_x)
        in_lastdim = in_shape[-1]
        # indices shape: (n_1, ..., n_(p-1), k); ind_lastdim = k
        indices = out[1]
        ind_shape = shape_op(indices)
        ind_lastdim = ind_shape[-1]
        # Collapse leading dims to (outerdim, k), outerdim = n_1*...*n_(p-1)
        ind_2d = reshape_op(indices, (-1, ind_lastdim))
        outerdim = shape_op(ind_2d)[0]
        # Flat start offset of each row: [0, in_lastdim, 2*in_lastdim, ...]
        indices_dtype = dtype(indices)
        range_flatten_index = _range_op(0, outerdim * in_lastdim, in_lastdim, indices_dtype)
        # Broadcast the (outerdim, 1) offsets onto (outerdim, k) indices,
        # then flatten to 1-D flat indices into the flattened input.
        ind = reshape_op(ind_2d + expand_dims(range_flatten_index, -1), (-1,))
        in_shape_1d = _get_1d_shape(in_shape)
        # Scatter flattened gradients and restore the original input shape.
        out_grad = reshape_op(
            scatter(
                expand_dims(ind, -1),
                reshape_op(dout[0], (-1,)),
                in_shape_1d),
            in_shape)
        return out_grad, zeros_like(k)

    return bprop