UnsortedSegmentMin for GE

This commit is contained in:
liuxiao 2020-05-14 21:36:35 +08:00
parent 699d0c1082
commit cc024bb3a1
9 changed files with 121 additions and 1 deletions

View File

@ -138,6 +138,7 @@ const PrimitivePtr kPrimSize = std::make_shared<Primitive>("Size");
const PrimitivePtr kPrimArgMax = std::make_shared<Primitive>("Argmax");
const PrimitivePtr kPrimPack = std::make_shared<Primitive>("Pack");
const PrimitivePtr kPrimUnsortedSegmentSum = std::make_shared<Primitive>("UnsortedSegmentSum");
const PrimitivePtr kPrimUnsortedSegmentMin = std::make_shared<Primitive>("UnsortedSegmentMin");
const PrimitivePtr kPrimConcatOffset = std::make_shared<Primitive>("ConcatOffset");
const PrimitivePtr kPrimReshape = std::make_shared<Primitive>("Reshape");
const PrimitivePtr kPrimTile = std::make_shared<Primitive>("Tile");

View File

@ -143,6 +143,7 @@ extern const PrimitivePtr kPrimSize;
extern const PrimitivePtr kPrimArgMax;
extern const PrimitivePtr kPrimPack;
extern const PrimitivePtr kPrimUnpack;
extern const PrimitivePtr kPrimUnsortedSegmentMin;
extern const PrimitivePtr kPrimUnsortedSegmentSum;
extern const PrimitivePtr kPrimConcatOffset;
extern const PrimitivePtr kPrimReshape;

View File

@ -340,6 +340,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{prim::kPrimGelu->name(), ADPT_DESC(Gelu)},
{prim::kPrimGeluGrad->name(), ADPT_DESC(GeluGrad)},
{string(kNameStridedSlice), ADPT_DESC(StridedSlice)},
{prim::kPrimUnsortedSegmentMin->name(), ADPT_DESC(UnsortedSegmentMinD)},
{prim::kPrimUnsortedSegmentSum->name(), ADPT_DESC(UnsortedSegmentSumD)},
{string(kNameExpandDims), ADPT_DESC(ExpandDims)},
{prim::kPrimSqueeze->name(), ADPT_DESC(Squeeze)},

View File

@ -1048,6 +1048,12 @@ INPUT_ATTR_MAP(UnsortedSegmentSumD) = {{3, ATTR_DESC(num_segments, AnyTraits<int
ATTR_MAP(UnsortedSegmentSumD) = EMPTY_ATTR_MAP;
OUTPUT_MAP(UnsortedSegmentSumD) = {{0, OUTPUT_DESC(y)}};
// UnsortedSegmentMin
INPUT_MAP(UnsortedSegmentMinD) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(segment_ids)}};
INPUT_ATTR_MAP(UnsortedSegmentMinD) = {{3, ATTR_DESC(num_segments, AnyTraits<int64_t>())}};
ATTR_MAP(UnsortedSegmentMinD) = EMPTY_ATTR_MAP;
OUTPUT_MAP(UnsortedSegmentMinD) = {{0, OUTPUT_DESC(y)}};
// ExpandDims
INPUT_MAP(ExpandDims) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(axis)}};
ATTR_MAP(ExpandDims) = EMPTY_ATTR_MAP;

View File

@ -281,6 +281,9 @@ DECLARE_OP_USE_OUTPUT(StridedSlice)
DECLARE_OP_ADAPTER(UnsortedSegmentSumD)
DECLARE_OP_USE_INPUT_ATTR(UnsortedSegmentSumD)
DECLARE_OP_USE_OUTPUT(UnsortedSegmentSumD)
DECLARE_OP_ADAPTER(UnsortedSegmentMinD)
DECLARE_OP_USE_INPUT_ATTR(UnsortedSegmentMinD)
DECLARE_OP_USE_OUTPUT(UnsortedSegmentMinD)
DECLARE_OP_ADAPTER(ExpandDims)
DECLARE_OP_USE_OUTPUT(ExpandDims)
DECLARE_OP_ADAPTER(Squeeze)

View File

@ -22,6 +22,7 @@ from .. import functional as F
from .grad_base import bprop_getters
from ..primitive import constexpr
from ... import context
from ...common import dtype as mstype
reduce_sum = P.ReduceSum()
unsorted_segment_sum = P.UnsortedSegmentSum()
@ -29,6 +30,7 @@ transpose = P.Transpose()
shape_op = P.Shape()
reshape = P.Reshape()
invert_permutation = P.InvertPermutation()
logical_and = P.LogicalAnd()
@bprop_getters.register(P.Fill)
@ -456,6 +458,57 @@ def get_bprop_diag_part(self):
return bprop
def _GatherDropNegatives(params,
ids,
zero_clipped_indices=None,
is_positive=None):
"""Helper function for unsorted segment ops."""
maximum = P.Maximum()
gather = P.GatherV2()
greater_equal = P.GreaterEqual()
rank = P.Rank()
fill = P.Fill()
select = P.Select()
if zero_clipped_indices is None:
zero_clipped_indices = maximum(ids, zeros_like(ids))
gathered = gather(params, zero_clipped_indices, 0)
if is_positive is None:
is_positive = greater_equal(ids, 0)
is_positive_shape = shape_op(is_positive)
broadcastable_shape = is_positive_shape
for _ in range(rank(gathered) - rank(is_positive)):
broadcastable_shape += (1,)
is_positive = reshape(is_positive, broadcastable_shape)
gathered_shape = shape_op(gathered)
is_positive = logical_and(is_positive, fill(mstype.bool_, gathered_shape, 1))
zero_slice = zeros_like(gathered)
return (select(is_positive, gathered, zero_slice), zero_clipped_indices, is_positive)
@bprop_getters.register(P.UnsortedSegmentMin)
def get_bprop_unsorted_segment_min(self):
"""Generate bprop for UnsortedSegmentMin"""
equal = P.Equal()
cast = P.Cast()
divide = P.RealDiv()
get_dtype = P.DType()
select = P.Select()
def bprop(x, segment_ids, num_segments, out, dout):
gathered_outputs, zero_clipped_indices, is_positive = _GatherDropNegatives(out, segment_ids)
is_selected = equal(x, gathered_outputs)
is_selected = logical_and(is_selected, is_positive)
num_selected = unsorted_segment_sum(cast(is_selected, get_dtype(dout)),
segment_ids, num_segments)
weighted_grads = divide(dout, num_selected)
gathered_grads, _, _ = _GatherDropNegatives(weighted_grads, None,
zero_clipped_indices, is_positive)
zeros = zeros_like(gathered_grads)
return select(is_selected, gathered_grads, zeros), zeros_like(segment_ids), zeros_like(num_segments)
return bprop
@bprop_getters.register(P.SpaceToBatch)
def get_bprop_space_to_batch(self):
"""Generate bprop for SpaceToBatch"""

View File

@ -28,7 +28,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
Shape, Size, Slice, Split,
Squeeze, StridedSlice, Tile,
Transpose, TruncatedNormal, TupleToArray,
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin,
UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace)
from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast,
_MirrorOperator, ReduceOp, _VirtualDataset,
@ -96,6 +96,7 @@ __all__ = [
'MaxPool',
'TopK',
'Adam',
'Softplus',
'Softmax',
'LogSoftmax',
'SoftmaxCrossEntropyWithLogits',
@ -210,6 +211,7 @@ __all__ = [
'Size',
'DepthwiseConv2dNative',
'UnsortedSegmentSum',
'UnsortedSegmentMin',
"AllGather",
"AllReduce",
"ReduceScatter",

View File

@ -1253,6 +1253,54 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
return out
class UnsortedSegmentMin(PrimitiveWithInfer):
"""
Computes the minimum along segments of a tensor.
If the given segment_ids is negative, the value will be ignored.
Inputs:
- **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_R)`.
- **segment_ids** (Tensor) - A `1-D` tensor whose shape is a prefix of `x_shape`.
- **num_segments** (int) - The value spcifies the number of distinct `segment_ids`.
Outputs:
Tensor, Set the number of `num_segments` as `N`, the shape is :math:`(N, x_2, ..., x_R)`.
Examples:
>>> input_x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [4, 2, 1]]).astype(np.float32))
>>> segment_ids = Tensor(np.array([0, 1, 1]).np.int32)
>>> num_segments = 2
>>> unsorted_segment_min = P.UnsortedSegmentMin()
>>> unsorted_segment_min(input_x, segment_ids, num_segments)
[[1., 2., 3.], [4., 2., 1.]]
"""
@prim_attr_register
def __init__(self):
"""init UnsortedSegmentMin"""
self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y'])
def __infer__(self, x, segment_ids, num_segments):
x_type = x['dtype']
x_shape = x['shape']
segment_ids_shape = segment_ids['shape']
valid_type = [mstype.float16, mstype.float32, mstype.int32]
validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name)
validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name)
validator.check_integer("rank of segment_ids_shape", len(segment_ids_shape), 1, Rel.EQ, self.name)
num_segments_v = num_segments['value']
validator.check_value_type('num_segments', num_segments_v, [int], self.name)
validator.check_integer("num_segments", num_segments_v, 0, Rel.GT, self.name)
segment_ids_shape_len = len(segment_ids_shape)
out_shape = [num_segments_v]
out_shape += x_shape[segment_ids_shape_len:]
out = {'shape': out_shape,
'dtype': x_type,
'value': None}
return out
class Concat(PrimitiveWithInfer):
r"""
Concat tensor in specified axis.

View File

@ -773,6 +773,11 @@ test_case_nn_ops = [
'desc_inputs': [[3, 2, 1, 3], Tensor(np.array([[0, 1], [0, 1], [0, 1]]).astype(np.int32))],
'desc_bprop': [[4, 1, 3]],
'skip': ['backward']}),
('UnsortedSegmentMin', {
'block': P.UnsortedSegmentMin(),
'desc_const': [4],
'desc_inputs': [[3, 2, 1, 3], Tensor(np.array([1, 2, 3]).astype(np.int32))],
'desc_bprop': [[4, 2, 1, 3]]}),
('DropoutGenMask', {
'block': P.DropoutGenMask(),
'desc_const': [(2, 2), Tensor(0.5, mstype.float32)],