forked from mindspore-Ecosystem/mindspore
!1228 Adapt tbe op UnsortedSegmentMin for GE.
Merge pull request !1228 from liuxiao/UnsortedSegmentMin
This commit is contained in:
commit
18c9495000
|
@ -138,6 +138,7 @@ const PrimitivePtr kPrimSize = std::make_shared<Primitive>("Size");
|
|||
const PrimitivePtr kPrimArgMax = std::make_shared<Primitive>("Argmax");
|
||||
const PrimitivePtr kPrimPack = std::make_shared<Primitive>("Pack");
|
||||
const PrimitivePtr kPrimUnsortedSegmentSum = std::make_shared<Primitive>("UnsortedSegmentSum");
|
||||
const PrimitivePtr kPrimUnsortedSegmentMin = std::make_shared<Primitive>("UnsortedSegmentMin");
|
||||
const PrimitivePtr kPrimConcatOffset = std::make_shared<Primitive>("ConcatOffset");
|
||||
const PrimitivePtr kPrimReshape = std::make_shared<Primitive>("Reshape");
|
||||
const PrimitivePtr kPrimTile = std::make_shared<Primitive>("Tile");
|
||||
|
|
|
@ -143,6 +143,7 @@ extern const PrimitivePtr kPrimSize;
|
|||
extern const PrimitivePtr kPrimArgMax;
|
||||
extern const PrimitivePtr kPrimPack;
|
||||
extern const PrimitivePtr kPrimUnpack;
|
||||
extern const PrimitivePtr kPrimUnsortedSegmentMin;
|
||||
extern const PrimitivePtr kPrimUnsortedSegmentSum;
|
||||
extern const PrimitivePtr kPrimConcatOffset;
|
||||
extern const PrimitivePtr kPrimReshape;
|
||||
|
|
|
@ -341,6 +341,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
|
|||
{prim::kPrimGelu->name(), ADPT_DESC(Gelu)},
|
||||
{prim::kPrimGeluGrad->name(), ADPT_DESC(GeluGrad)},
|
||||
{string(kNameStridedSlice), ADPT_DESC(StridedSlice)},
|
||||
{prim::kPrimUnsortedSegmentMin->name(), ADPT_DESC(UnsortedSegmentMinD)},
|
||||
{prim::kPrimUnsortedSegmentSum->name(), ADPT_DESC(UnsortedSegmentSumD)},
|
||||
{string(kNameExpandDims), ADPT_DESC(ExpandDims)},
|
||||
{prim::kPrimSqueeze->name(), ADPT_DESC(Squeeze)},
|
||||
|
|
|
@ -1053,6 +1053,12 @@ INPUT_ATTR_MAP(UnsortedSegmentSumD) = {{3, ATTR_DESC(num_segments, AnyTraits<int
|
|||
ATTR_MAP(UnsortedSegmentSumD) = EMPTY_ATTR_MAP;
|
||||
OUTPUT_MAP(UnsortedSegmentSumD) = {{0, OUTPUT_DESC(y)}};
|
||||
|
||||
// UnsortedSegmentMin
|
||||
INPUT_MAP(UnsortedSegmentMinD) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(segment_ids)}};
|
||||
INPUT_ATTR_MAP(UnsortedSegmentMinD) = {{3, ATTR_DESC(num_segments, AnyTraits<int64_t>())}};
|
||||
ATTR_MAP(UnsortedSegmentMinD) = EMPTY_ATTR_MAP;
|
||||
OUTPUT_MAP(UnsortedSegmentMinD) = {{0, OUTPUT_DESC(y)}};
|
||||
|
||||
// ExpandDims
|
||||
INPUT_MAP(ExpandDims) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(axis)}};
|
||||
ATTR_MAP(ExpandDims) = EMPTY_ATTR_MAP;
|
||||
|
|
|
@ -283,6 +283,9 @@ DECLARE_OP_USE_OUTPUT(StridedSlice)
|
|||
DECLARE_OP_ADAPTER(UnsortedSegmentSumD)
|
||||
DECLARE_OP_USE_INPUT_ATTR(UnsortedSegmentSumD)
|
||||
DECLARE_OP_USE_OUTPUT(UnsortedSegmentSumD)
|
||||
DECLARE_OP_ADAPTER(UnsortedSegmentMinD)
|
||||
DECLARE_OP_USE_INPUT_ATTR(UnsortedSegmentMinD)
|
||||
DECLARE_OP_USE_OUTPUT(UnsortedSegmentMinD)
|
||||
DECLARE_OP_ADAPTER(ExpandDims)
|
||||
DECLARE_OP_USE_OUTPUT(ExpandDims)
|
||||
DECLARE_OP_ADAPTER(Squeeze)
|
||||
|
|
|
@ -22,6 +22,7 @@ from .. import functional as F
|
|||
from .grad_base import bprop_getters
|
||||
from ..primitive import constexpr
|
||||
from ... import context
|
||||
from ...common import dtype as mstype
|
||||
|
||||
reduce_sum = P.ReduceSum()
|
||||
unsorted_segment_sum = P.UnsortedSegmentSum()
|
||||
|
@ -29,6 +30,7 @@ transpose = P.Transpose()
|
|||
shape_op = P.Shape()
|
||||
reshape = P.Reshape()
|
||||
invert_permutation = P.InvertPermutation()
|
||||
logical_and = P.LogicalAnd()
|
||||
|
||||
|
||||
@bprop_getters.register(P.Fill)
|
||||
|
@ -456,6 +458,57 @@ def get_bprop_diag_part(self):
|
|||
return bprop
|
||||
|
||||
|
||||
def _GatherDropNegatives(params,
|
||||
ids,
|
||||
zero_clipped_indices=None,
|
||||
is_positive=None):
|
||||
"""Helper function for unsorted segment ops."""
|
||||
maximum = P.Maximum()
|
||||
gather = P.GatherV2()
|
||||
greater_equal = P.GreaterEqual()
|
||||
rank = P.Rank()
|
||||
fill = P.Fill()
|
||||
select = P.Select()
|
||||
|
||||
if zero_clipped_indices is None:
|
||||
zero_clipped_indices = maximum(ids, zeros_like(ids))
|
||||
gathered = gather(params, zero_clipped_indices, 0)
|
||||
if is_positive is None:
|
||||
is_positive = greater_equal(ids, 0)
|
||||
is_positive_shape = shape_op(is_positive)
|
||||
broadcastable_shape = is_positive_shape
|
||||
for _ in range(rank(gathered) - rank(is_positive)):
|
||||
broadcastable_shape += (1,)
|
||||
is_positive = reshape(is_positive, broadcastable_shape)
|
||||
gathered_shape = shape_op(gathered)
|
||||
is_positive = logical_and(is_positive, fill(mstype.bool_, gathered_shape, 1))
|
||||
zero_slice = zeros_like(gathered)
|
||||
return (select(is_positive, gathered, zero_slice), zero_clipped_indices, is_positive)
|
||||
|
||||
|
||||
@bprop_getters.register(P.UnsortedSegmentMin)
|
||||
def get_bprop_unsorted_segment_min(self):
|
||||
"""Generate bprop for UnsortedSegmentMin"""
|
||||
equal = P.Equal()
|
||||
cast = P.Cast()
|
||||
divide = P.RealDiv()
|
||||
get_dtype = P.DType()
|
||||
select = P.Select()
|
||||
|
||||
def bprop(x, segment_ids, num_segments, out, dout):
|
||||
gathered_outputs, zero_clipped_indices, is_positive = _GatherDropNegatives(out, segment_ids)
|
||||
is_selected = equal(x, gathered_outputs)
|
||||
is_selected = logical_and(is_selected, is_positive)
|
||||
num_selected = unsorted_segment_sum(cast(is_selected, get_dtype(dout)),
|
||||
segment_ids, num_segments)
|
||||
weighted_grads = divide(dout, num_selected)
|
||||
gathered_grads, _, _ = _GatherDropNegatives(weighted_grads, None,
|
||||
zero_clipped_indices, is_positive)
|
||||
zeros = zeros_like(gathered_grads)
|
||||
return select(is_selected, gathered_grads, zeros), zeros_like(segment_ids), zeros_like(num_segments)
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(P.SpaceToBatch)
|
||||
def get_bprop_space_to_batch(self):
|
||||
"""Generate bprop for SpaceToBatch"""
|
||||
|
|
|
@ -28,7 +28,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
|
|||
ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
|
||||
Shape, Size, Slice, Split,
|
||||
Squeeze, StridedSlice, Tile,
|
||||
Transpose, TruncatedNormal, TupleToArray,
|
||||
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin,
|
||||
UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace)
|
||||
from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast,
|
||||
_MirrorOperator, ReduceOp, _VirtualDataset,
|
||||
|
@ -96,6 +96,7 @@ __all__ = [
|
|||
'MaxPool',
|
||||
'TopK',
|
||||
'Adam',
|
||||
'Softplus',
|
||||
'Softmax',
|
||||
'LogSoftmax',
|
||||
'SoftmaxCrossEntropyWithLogits',
|
||||
|
@ -210,6 +211,7 @@ __all__ = [
|
|||
'Size',
|
||||
'DepthwiseConv2dNative',
|
||||
'UnsortedSegmentSum',
|
||||
'UnsortedSegmentMin',
|
||||
"AllGather",
|
||||
"AllReduce",
|
||||
"ReduceScatter",
|
||||
|
|
|
@ -1253,6 +1253,54 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
|
|||
return out
|
||||
|
||||
|
||||
class UnsortedSegmentMin(PrimitiveWithInfer):
|
||||
"""
|
||||
Computes the minimum along segments of a tensor.
|
||||
|
||||
If the given segment_ids is negative, the value will be ignored.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_R)`.
|
||||
- **segment_ids** (Tensor) - A `1-D` tensor whose shape is a prefix of `x_shape`.
|
||||
- **num_segments** (int) - The value spcifies the number of distinct `segment_ids`.
|
||||
|
||||
Outputs:
|
||||
Tensor, Set the number of `num_segments` as `N`, the shape is :math:`(N, x_2, ..., x_R)`.
|
||||
|
||||
Examples:
|
||||
>>> input_x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [4, 2, 1]]).astype(np.float32))
|
||||
>>> segment_ids = Tensor(np.array([0, 1, 1]).np.int32)
|
||||
>>> num_segments = 2
|
||||
>>> unsorted_segment_min = P.UnsortedSegmentMin()
|
||||
>>> unsorted_segment_min(input_x, segment_ids, num_segments)
|
||||
[[1., 2., 3.], [4., 2., 1.]]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""init UnsortedSegmentMin"""
|
||||
self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y'])
|
||||
|
||||
def __infer__(self, x, segment_ids, num_segments):
|
||||
x_type = x['dtype']
|
||||
x_shape = x['shape']
|
||||
segment_ids_shape = segment_ids['shape']
|
||||
valid_type = [mstype.float16, mstype.float32, mstype.int32]
|
||||
validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name)
|
||||
validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name)
|
||||
validator.check_integer("rank of segment_ids_shape", len(segment_ids_shape), 1, Rel.EQ, self.name)
|
||||
num_segments_v = num_segments['value']
|
||||
validator.check_value_type('num_segments', num_segments_v, [int], self.name)
|
||||
validator.check_integer("num_segments", num_segments_v, 0, Rel.GT, self.name)
|
||||
segment_ids_shape_len = len(segment_ids_shape)
|
||||
out_shape = [num_segments_v]
|
||||
out_shape += x_shape[segment_ids_shape_len:]
|
||||
out = {'shape': out_shape,
|
||||
'dtype': x_type,
|
||||
'value': None}
|
||||
return out
|
||||
|
||||
|
||||
class Concat(PrimitiveWithInfer):
|
||||
r"""
|
||||
Concat tensor in specified axis.
|
||||
|
|
|
@ -778,6 +778,11 @@ test_case_nn_ops = [
|
|||
'desc_inputs': [[3, 2, 1, 3], Tensor(np.array([[0, 1], [0, 1], [0, 1]]).astype(np.int32))],
|
||||
'desc_bprop': [[4, 1, 3]],
|
||||
'skip': ['backward']}),
|
||||
('UnsortedSegmentMin', {
|
||||
'block': P.UnsortedSegmentMin(),
|
||||
'desc_const': [4],
|
||||
'desc_inputs': [[3, 2, 1, 3], Tensor(np.array([1, 2, 3]).astype(np.int32))],
|
||||
'desc_bprop': [[4, 2, 1, 3]]}),
|
||||
('DropoutGenMask', {
|
||||
'block': P.DropoutGenMask(),
|
||||
'desc_const': [(2, 2), Tensor(0.5, mstype.float32)],
|
||||
|
|
Loading…
Reference in New Issue