From 6f64fffdb4302f2c365308289744f65e27472212 Mon Sep 17 00:00:00 2001
From: danishnxt
Date: Thu, 26 Nov 2020 18:50:36 -0500
Subject: [PATCH] Adding API fix for handling dynamic_shape correctly plus STs
 for SegSum dyn_shape

Updating InferImpl function

adding check to inferImpl function

lint

lint2
---
 mindspore/core/abstract/prim_arrays.cc        |  69 ++++++---
 mindspore/ops/operations/array_ops.py         |   4 +-
 tests/st/ops/gpu/test_unsorted_segment_sum.py | 138 +++++++++++++++++-
 3 files changed, 185 insertions(+), 26 deletions(-)

diff --git a/mindspore/core/abstract/prim_arrays.cc b/mindspore/core/abstract/prim_arrays.cc
index fa90104f296..75612552253 100644
--- a/mindspore/core/abstract/prim_arrays.cc
+++ b/mindspore/core/abstract/prim_arrays.cc
@@ -205,24 +205,23 @@ AbstractBasePtr InferImplUnsortedSegmentSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                             const AbstractBasePtrList &args_spec_list) {
   const std::string op_name = primitive->name();
   CheckArgsSize(op_name, args_spec_list, 3);
-  // input x
   auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
   MS_EXCEPTION_IF_NULL(x);
   MS_EXCEPTION_IF_NULL(x->shape());
-  auto x_shape = x->shape()->shape();
-  // segment_ids
   auto segment_ids = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
   MS_EXCEPTION_IF_NULL(segment_ids);
   MS_EXCEPTION_IF_NULL(segment_ids->shape());
   auto segment_ids_shape = segment_ids->shape()->shape();
-  // checks on Tensors 0 and 1 types
-  (void)CheckTensorDType(x, {kFloat32, kInt32}, "Input 0 (x) for SequenceMask should be %s");
-  (void)CheckTensorDType(segment_ids, {kInt32, kInt64}, "Input 1 (segment_ids) for SequenceMask should be %s");
+  (void)CheckTensorDType(x, {kFloat16, kFloat32, kInt32}, "Input 0 (x) for UnsortedSegmentSum should be %s");
+  (void)CheckTensorDType(segment_ids, {kInt32, kInt64}, "Input 1 (segment_ids) for UnsortedSegmentSum should be %s");
+  // check if dynamic shape
+  bool x_is_dyn = (!x->shape()->min_shape().empty() && !x->shape()->max_shape().empty());
+  bool ids_is_dyn = (!segment_ids->shape()->min_shape().empty() && !segment_ids->shape()->max_shape().empty());
+  bool op_is_dynamic = x_is_dyn && ids_is_dyn;
+  auto x_shape = x->shape()->shape();
   ShapeVector shape;
-  ShapeVector max_shape;
-  ShapeVector min_shape;
-  int64_t num_segments_value;
-  if (args_spec_list[2]->isa<AbstractTensor>()) {  // Num segments is Tensor
+  int64_t num_segments_value = 0;
+  if (args_spec_list[2]->isa<AbstractTensor>()) {  // num_segments is Tensor
     auto num_segments = args_spec_list[2]->cast<AbstractTensorPtr>();
     MS_EXCEPTION_IF_NULL(num_segments);
     auto num_segments_value_ptr = num_segments->BuildValue();
@@ -230,26 +229,48 @@ AbstractBasePtr InferImplUnsortedSegmentSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
     auto num_segments_tensor = num_segments_value_ptr->cast<tensor::TensorPtr>();
     MS_EXCEPTION_IF_NULL(num_segments_tensor);
     num_segments_value = *static_cast<int64_t *>(num_segments_tensor->data_c());
-    shape.emplace_back(num_segments_value);
-  } else if (args_spec_list[2]->isa<AbstractScalar>()) {  // Num segments is Scalar
+  } else if (args_spec_list[2]->isa<AbstractScalar>()) {  // num_segments is Scalar
     auto num_segments = CheckArg<AbstractScalar>(op_name, args_spec_list, 2);
     num_segments_value = GetValue<int64_t>(num_segments->BuildValue());
-    shape.emplace_back(num_segments_value);
   } else {
     MS_LOG(EXCEPTION) << "num_segments incorrect type in UnsortedSegmentSum";
   }
-  shape.insert(shape.end(), x_shape.begin() + segment_ids_shape.size(), x_shape.end());
-  // calc max shape
-  if (!x->shape()->max_shape().empty()) {  // copy max shape from x if present
-    std::copy(x->shape()->max_shape().begin(), x->shape()->max_shape().end(), std::back_inserter(max_shape));
-  } else {  // copy x shape directly if not present
-    std::copy(x->shape()->shape().begin(), x->shape()->shape().end(), std::back_inserter(max_shape));
+  if (num_segments_value <= 0) {
+    MS_LOG(EXCEPTION) << "num_segments must be > 0 in UnsortedSegmentSum";
   }
-  // calc min shape
-  min_shape.push_back(segment_ids_shape.size());
-  std::copy(x->shape()->shape().begin() + segment_ids_shape.size(), x->shape()->shape().end(),
-            std::back_inserter(min_shape));
-  // return shape, min shape, max shape
+  shape.emplace_back(num_segments_value);
+  shape.insert(shape.end(), x_shape.begin() + segment_ids_shape.size(), x_shape.end());
+  // dims check
+  if (!op_is_dynamic) {
+    for (size_t i = 0; i < segment_ids_shape.size(); i++) {
+      if (x_shape[i] != segment_ids_shape[i]) {
+        MS_LOG(EXCEPTION) << "Shape values of segments_ids must match with corresponding x shape values";
+      }
+    }
+    return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape));
+  }
+  // is dynamic
+  ShapeVector min_shape;
+  ShapeVector max_shape;
+  min_shape.emplace_back(num_segments_value);
+  max_shape.emplace_back(num_segments_value);
+  // only run validation if shape values are known
+  bool x_any_shape = std::any_of(x_shape.begin(), x_shape.end(), [](int64_t dim) { return dim == Shape::SHP_ANY; });
+  bool ids_any_shape =
+    std::any_of(segment_ids_shape.begin(), segment_ids_shape.end(), [](int64_t dim) { return dim == Shape::SHP_ANY; });
+  if (!x_any_shape && !ids_any_shape) {
+    for (size_t i = 0; i < segment_ids_shape.size(); i++) {
+      if (x_shape[i] != segment_ids_shape[i]) {
+        MS_LOG(EXCEPTION) << "Shape values of segments_ids must match with corresponding x shape values";
+      }
+    }
+  }
+  ShapeVector x_shape_min;
+  ShapeVector x_shape_max;
+  x_shape_min = (x_is_dyn) ? x->shape()->min_shape() : x->shape()->shape();
+  x_shape_max = (x_is_dyn) ? x->shape()->max_shape() : x->shape()->shape();
+  min_shape.insert(min_shape.end(), x_shape_min.begin() + segment_ids_shape.size(), x_shape_min.end());
+  max_shape.insert(max_shape.end(), x_shape_max.begin() + segment_ids_shape.size(), x_shape_max.end());
   return std::make_shared<AbstractTensor>(x->element(),
                                           std::make_shared<Shape>(shape, min_shape, max_shape));
 }
diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py
index 927d12aea4e..73edf36a64f 100644
--- a/mindspore/ops/operations/array_ops.py
+++ b/mindspore/ops/operations/array_ops.py
@@ -1904,7 +1904,9 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
         shp += x_shp[segment_ids_shp_len:]
 
         if 'max_shape' in x:
-            output_max_shape = x['max_shape']
+            output_incoming = x['max_shape']
+            output_max_shape = [num_segments_v]
+            output_max_shape += output_incoming[segment_ids_shp_len:]
         else:
             output_max_shape = x_shp
         out = {'shape': shp,
diff --git a/tests/st/ops/gpu/test_unsorted_segment_sum.py b/tests/st/ops/gpu/test_unsorted_segment_sum.py
index fc5662ebc0e..7d57e77748b 100644
--- a/tests/st/ops/gpu/test_unsorted_segment_sum.py
+++ b/tests/st/ops/gpu/test_unsorted_segment_sum.py
@@ -20,11 +20,11 @@ import mindspore.context as context
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore.common import dtype as mstype
+from mindspore.ops.operations import _inner_ops as inner
 from mindspore.ops import operations as P
 
 context.set_context(device_target='GPU')
 
-
 class UnsortedSegmentSumNet(nn.Cell):
     def __init__(self, num_segments):
         super(UnsortedSegmentSumNet, self).__init__()
@@ -108,3 +108,139 @@ def test_3D():
                 [0., 0., 0.],
                 [0., 0., 0.]]]
     assert (output.asnumpy() == expect).all()
+
+
+# Testing Dynamic Shape
+class UnsortedSegmentSumDynNet(nn.Cell):
+    def __init__(self, num_segments):
+        super(UnsortedSegmentSumDynNet, self).__init__()
+        self.unsorted_segment_sum = P.UnsortedSegmentSum()
+        self.to_dyn_op = inner.GpuConvertToDynamicShape()
+        self.num_segments = num_segments
+
+    def construct(self, data, ids):
+        data_dyn = self.to_dyn_op(data)
+        ids_dyn = self.to_dyn_op(ids)
+        return self.unsorted_segment_sum(data_dyn, ids_dyn, self.num_segments)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_dyn():
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+    num_segments = 4
+    net = UnsortedSegmentSumDynNet(num_segments)
+
+    input_x = Tensor([1, 2, 3, 4], mstype.float32)
+    segment_ids = Tensor([0, 0, 1, 2], mstype.int32)
+    output = net(input_x, segment_ids)
+    expect = [3, 3, 4, 0]
+    assert (output.asnumpy() == expect).all()
+
+    input_x = Tensor([[1, 2, 3, 4],
+                      [5, 6, 7, 8],
+                      [9, 10, 11, 12]], mstype.float32)
+    segment_ids = Tensor([2, 1, 1], mstype.int32)
+    output = net(input_x, segment_ids)
+    expect = [[0, 0, 0, 0],
+              [14, 16, 18, 20],
+              [1, 2, 3, 4],
+              [0, 0, 0, 0]]
+    assert (output.asnumpy() == expect).all()
+
+    input_x = Tensor(np.arange(4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3))
+    segment_ids = Tensor([2, 1, 1, -1], mstype.int32)
+    output = net(input_x, segment_ids)
+    expect = [[[0., 0., 0.],
+               [0., 0., 0.],
+               [0., 0., 0.],
+               [0., 0., 0.],
+               [0., 0., 0.]],
+
+              [[45., 47., 49.],
+               [51., 53., 55.],
+               [57., 59., 61.],
+               [63., 65., 67.],
+               [69., 71., 73.]],
+
+              [[0., 1., 2.],
+               [3., 4., 5.],
+               [6., 7., 8.],
+               [9., 10., 11.],
+               [12., 13., 14.]],
+
+              [[0., 0., 0.],
+               [0., 0., 0.],
+               [0., 0., 0.],
+               [0., 0., 0.],
+               [0., 0., 0.]]]
+    assert (output.asnumpy() == expect).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_dyn_1():
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+    num_segments = 6
+    net = UnsortedSegmentSumDynNet(num_segments)
+
+    input_x = Tensor([1, 2, 3, 4], mstype.float32)
+    segment_ids = Tensor([0, 0, 1, 2], mstype.int32)
+    output = net(input_x, segment_ids)
+    expect = [3, 3, 4, 0, 0, 0]
+    assert (output.asnumpy() == expect).all()
+
+    input_x = Tensor([[1, 2, 3, 4],
+                      [5, 6, 7, 8],
+                      [9, 10, 11, 12]], mstype.float32)
+    segment_ids = Tensor([2, 1, 1], mstype.int32)
+    output = net(input_x, segment_ids)
+    expect = [[0, 0, 0, 0],
+              [14, 16, 18, 20],
+              [1, 2, 3, 4],
+              [0, 0, 0, 0],
+              [0, 0, 0, 0],
+              [0, 0, 0, 0]]
+    assert (output.asnumpy() == expect).all()
+
+    input_x = Tensor(np.arange(4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3))
+    segment_ids = Tensor([2, 1, 1, -1], mstype.int32)
+    output = net(input_x, segment_ids)
+    expect = [[[0., 0., 0.],
+               [0., 0., 0.],
+               [0., 0., 0.],
+               [0., 0., 0.],
+               [0., 0., 0.]],
+
+              [[45., 47., 49.],
+               [51., 53., 55.],
+               [57., 59., 61.],
+               [63., 65., 67.],
+               [69., 71., 73.]],
+
+              [[0., 1., 2.],
+               [3., 4., 5.],
+               [6., 7., 8.],
+               [9., 10., 11.],
+               [12., 13., 14.]],
+
+              [[0., 0., 0.],
+               [0., 0., 0.],
+               [0., 0., 0.],
+               [0., 0., 0.],
+               [0., 0., 0.]],
+
+              [[0., 0., 0.],
+               [0., 0., 0.],
+               [0., 0., 0.],
+               [0., 0., 0.],
+               [0., 0., 0.]],
+
+              [[0., 0., 0.],
+               [0., 0., 0.],
+               [0., 0., 0.],
+               [0., 0., 0.],
+               [0., 0., 0.]]]
+    assert (output.asnumpy() == expect).all()