From 34cc178bd095310ce9765275ae44ce1cf998c5e6 Mon Sep 17 00:00:00 2001 From: danishnxt Date: Thu, 29 Oct 2020 16:33:03 -0400 Subject: [PATCH] New New UnsortedSegmentMax for GPU [API][CUDA_KERNEL] PyLintFix header fix --- .../arrays/unsorted_segment_max_gpu_kernel.cc | 36 ++++ .../arrays/unsorted_segment_max_gpu_kernel.h | 123 +++++++++++ .../gpu/cuda_impl/unsorted_segment_max.cu | 79 +++++++ .../gpu/cuda_impl/unsorted_segment_max.cuh | 30 +++ .../pass/const_input_to_attr_registry.cc | 1 + mindspore/core/base/core_ops.h | 1 + mindspore/ops/operations/__init__.py | 4 +- mindspore/ops/operations/array_ops.py | 50 +++++ tests/st/ops/gpu/test_unsorted_segment_max.py | 204 ++++++++++++++++++ 9 files changed, 526 insertions(+), 2 deletions(-) create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_max_gpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_max_gpu_kernel.h create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_max.cu create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_max.cuh create mode 100644 tests/st/ops/gpu/test_unsorted_segment_max.py diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_max_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_max_gpu_kernel.cc new file mode 100644 index 00000000000..1707e1d03d4 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_max_gpu_kernel.cc @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/arrays/unsorted_segment_max_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE( + UnsortedSegmentMax, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), + UnsortedSegmentMaxGpuKernel, float) + +MS_REG_GPU_KERNEL_ONE( + UnsortedSegmentMax, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), + UnsortedSegmentMaxGpuKernel, half) + +MS_REG_GPU_KERNEL_ONE( + UnsortedSegmentMax, + KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + UnsortedSegmentMaxGpuKernel, int) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_max_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_max_gpu_kernel.h new file mode 100644 index 00000000000..0b2fd088272 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_max_gpu_kernel.h @@ -0,0 +1,123 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_UNSORTED_SEGMENT_MAX_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_UNSORTED_SEGMENT_MAX_H_ + +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_max.cuh" + +namespace mindspore { +namespace kernel { +template +class UnsortedSegmentMaxGpuKernel : public GpuKernel { + public: + UnsortedSegmentMaxGpuKernel() + : num_segments_(1), + inner_size_(1), + outer_size_(1), + input_size_(1), + segment_ids_size_(1), + output_size_(1), + is_null_input_(false) {} + ~UnsortedSegmentMaxGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + if (is_null_input_) { + return true; + } + T *input_addr = GetDeviceAddress(inputs, 0); + int *indices_addr = GetDeviceAddress(inputs, 1); + T *output_addr = GetDeviceAddress(outputs, 0); + + CHECK_CUDA_RET_WITH_EXCEPT(cudaMemsetAsync(output_addr, std::numeric_limits::min(), outputs[0]->size, + reinterpret_cast(stream_ptr)), + "cudaMemSet Failed"); + CalUnsortedSegmentMax(input_addr, indices_addr, num_segments_, outer_size_, inner_size_, output_addr, + reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + auto input_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + is_null_input_ = CHECK_NULL_INPUT(input_shapes); + if (is_null_input_) { + MS_LOG(WARNING) << "UnsortedSegmentMax input is null"; + InitSizeLists(); + return true; + } + auto segment_ids_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + auto output_shapes = AnfAlgo::GetOutputInferShape(kernel_node, 0); // we get that from computation + + num_segments_ = output_shapes[0]; + + input_size_ = 1; + for (size_t i = 0; i < input_shapes.size(); i++) { + input_size_ *= input_shapes[i]; + } + + segment_ids_size_ = 1; + for (size_t i = 0; i < segment_ids_shapes.size(); i++) { + segment_ids_size_ *= segment_ids_shapes[i]; + } + + output_size_ = 1; + for (size_t i = 0; i < output_shapes.size(); i++) { + output_size_ *= output_shapes[i]; + } + + outer_size_ = input_shapes[0]; + inner_size_ = 1; + for (size_t i = 1; i < input_shapes.size(); i++) { + inner_size_ *= input_shapes[i]; + } + + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_ * sizeof(T)); + input_size_list_.push_back(segment_ids_size_ * sizeof(int)); + output_size_list_.push_back(output_size_ * sizeof(T)); + } + + private: + int num_segments_; + size_t inner_size_; + size_t outer_size_; + size_t input_size_; + size_t segment_ids_size_; + size_t output_size_; + bool is_null_input_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_UNSORTED_SEGMENT_MAX_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_max.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_max.cu new file mode 100644 index 00000000000..954611ffe77 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_max.cu @@ -0,0 +1,79 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_max.cuh" +#include + +template +__global__ void UnsortedSegmentMax(const T *input, const int *segment_ids, const int num_segments, size_t outer_size, + size_t inner_size, bool fp16_flag, T init_K, T *output) { + if (fp16_flag) { + init_K = __int2half_rd(-65504); // min value representable by float16 + } + + for (int t_idx = blockIdx.x * blockDim.x + threadIdx.x; t_idx < KWARPSIZE * num_segments * inner_size; + t_idx += blockDim.x * gridDim.x) { + int segment_id = t_idx / KWARPSIZE / inner_size; + int inner_id = t_idx / KWARPSIZE % inner_size; + int lane_id = threadIdx.x % KWARPSIZE; + T threadK = init_K; + + for (int i = lane_id; i < outer_size; i += KWARPSIZE) { + if (segment_ids[i] != segment_id) continue; + T other_K = input[i * inner_size + inner_id]; + if (threadK < other_K) { + threadK = other_K; + } + } + __syncwarp(); + + for (int offset = KWARPSIZE / 2; offset > 0; offset /= 2) { + T other_K = __shfl_down_sync(0xffffffff, threadK, offset); + if (threadK < other_K) { + threadK = other_K; + } + } + + __syncwarp(); + + if (lane_id == 0) { + output[segment_id * inner_size + inner_id] = threadK; + } + __syncthreads(); + } +} + +template +void CalUnsortedSegmentMax(const T *input, const int *segment_ids, const int num_segments, size_t outer_size, + size_t inner_size, T *output, cudaStream_t stream) { + int size = (inner_size * KWARPSIZE * num_segments); + bool fp16_flag = false; + // handle fp16 min value + if (std::is_same::value) { + fp16_flag = true; + } + T init_K = std::numeric_limits::lowest(); + UnsortedSegmentMax<<>>(input, segment_ids, num_segments, outer_size, + inner_size, fp16_flag, init_K, output); + return; +} + +template void CalUnsortedSegmentMax(const float *input, const int *segment_ids, const int num_segments, + size_t outer_size, size_t inner_size, float *output, cudaStream_t stream); +template void CalUnsortedSegmentMax(const half *input, const int *segment_ids, const int num_segments, + size_t outer_size, size_t inner_size, half *output, cudaStream_t stream); +template void CalUnsortedSegmentMax(const int *input, const int *segment_ids, const int num_segments, + size_t outer_size, size_t inner_size, int *output, cudaStream_t stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_max.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_max.cuh new file mode 100644 index 00000000000..de80b74007f --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_max.cuh @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNSORTED_SEGMENT_MAX_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNSORTED_SEGMENT_MAX_H_ + +#include +#include "runtime/device/gpu/cuda_common.h" + +// Setting warp size to sync data across threads +#define KWARPSIZE 32 + +template +void CalUnsortedSegmentMax(const T *input, const int *segment_ids, const int num_segments, size_t outer_size, + size_t inner_size, T *output, cudaStream_t stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNSORT_SEGMENT_MAX_H_ diff --git a/mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.cc b/mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.cc index 5ef4991e4c2..a64741034a9 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.cc @@ -49,6 +49,7 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() { Register(prim::kPrimReduceAll->name(), {1}); Register(prim::kPrimReduceAny->name(), {1}); Register(prim::kPrimUnsortedSegmentMin->name(), {2}); + Register(prim::kPrimUnsortedSegmentMax->name(), {2}); Register(kSparseGatherV2, {2}); Register(kUnsortedSegmentProdOpName, {2}); Register(kSimpleMeanGradOpName, {1}); diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index d02996c1b2b..d94ea28e18c 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -92,6 +92,7 @@ inline const PrimitivePtr kPrimEmbeddingLookupCommGrad = std::make_shared("Size"); inline const PrimitivePtr kPrimArgMax = std::make_shared("Argmax"); inline const PrimitivePtr kPrimPack = std::make_shared("Pack"); +inline const PrimitivePtr kPrimUnsortedSegmentMax = std::make_shared("UnsortedSegmentMax"); inline const PrimitivePtr kPrimUnsortedSegmentSum = std::make_shared("UnsortedSegmentSum"); inline const PrimitivePtr kPrimUnsortedSegmentMin = std::make_shared("UnsortedSegmentMin"); inline const PrimitivePtr kPrimConcatOffset = std::make_shared("ConcatOffset"); diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index 1e9a0b5713d..69e6e3a63f4 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -30,8 +30,8 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, Shape, DynamicShape, Size, Slice, Split, TransShape, ParallelConcat, Padding, ScatterNdAdd, ScatterNdSub, ScatterNonAliasingAdd, ReverseV2, Rint, Squeeze, StridedSlice, Tile, TensorScatterUpdate, EditDistance, Sort, - Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentProd, - UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace, + Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentMax, + UnsortedSegmentProd, UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace, SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup, Unique, GatherD, Identity) from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast, diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 24c3784afa3..a1df0dfb509 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -1601,6 +1601,56 @@ class UnsortedSegmentMin(PrimitiveWithInfer): return out +class UnsortedSegmentMax(PrimitiveWithInfer): + """ + Computes the maximum along segments of a tensor. + + Inputs: + - **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_R)`. + The data type must be float16, float32 or int32. + - **segment_ids** (Tensor) - A `1-D` tensor whose shape is :math:`(x_1)`, the value must be >= 0. + The data type must be int32. + - **num_segments** (int) - The value spcifies the number of distinct `segment_ids`. + + Outputs: + Tensor, set the number of `num_segments` as `N`, the shape is :math:`(N, x_2, ..., x_R)`. + + Examples: + >>> input_x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [4, 2, 1]]).astype(np.float32)) + >>> segment_ids = Tensor(np.array([0, 1, 1]).astype(np.int32)) + >>> num_segments = 2 + >>> unsorted_segment_max = P.UnsortedSegmentMax() + >>> unsorted_segment_max(input_x, segment_ids, num_segments) + [[1., 2., 3.], [4., 5., 6.]] + """ + + @prim_attr_register + def __init__(self): + """Initialize UnsortedSegmentMax""" + self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y']) + + def __infer__(self, x, segment_ids, num_segments): + x_type = x['dtype'] + x_shape = x['shape'] + segment_ids_shape = segment_ids['shape'] + valid_type = [mstype.float16, mstype.float32, mstype.int32] + validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name) + validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name) + validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name) + validator.check(f'first shape of input_x', x_shape[0], + 'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name) + num_segments_v = num_segments['value'] + validator.check_value_type('num_segments', num_segments_v, [int], self.name) + validator.check_positive_int(num_segments_v, "num_segments", self.name) + segment_ids_shape_len = len(segment_ids_shape) + out_shape = [num_segments_v] + out_shape += x_shape[segment_ids_shape_len:] + out = {'shape': out_shape, + 'dtype': x_type, + 'value': None} + return out + + class UnsortedSegmentProd(PrimitiveWithInfer): """ Computes the product along segments of a tensor. diff --git a/tests/st/ops/gpu/test_unsorted_segment_max.py b/tests/st/ops/gpu/test_unsorted_segment_max.py new file mode 100644 index 00000000000..90fb7bcc713 --- /dev/null +++ b/tests/st/ops/gpu/test_unsorted_segment_max.py @@ -0,0 +1,204 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common import dtype as mstype +from mindspore.ops import operations as P + + +class UnsortedSegmentMaxNet(nn.Cell): + def __init__(self, num_segments): + super(UnsortedSegmentMaxNet, self).__init__() + self.unsorted_segment_max = P.UnsortedSegmentMax() + self.num_segments = num_segments + + def construct(self, data, ids): + return self.unsorted_segment_max(data, ids, self.num_segments) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_1d_int32(): + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + input_x = Tensor([1, 2, 3, 4], mstype.int32) + segment_ids = Tensor([0, 0, 1, 2], mstype.int32) + num_segments = 4 + net = UnsortedSegmentMaxNet(num_segments) + output = net(input_x, segment_ids) + expect = [2, 3, 4, -2147483648] + assert (output.asnumpy() == expect).all() + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_2d_int32(): + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + input_x = Tensor([[1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12]], mstype.int32) + segment_ids = Tensor([2, 1, 1], mstype.int32) + num_segments = 4 + net = UnsortedSegmentMaxNet(num_segments) + output = net(input_x, segment_ids) + expect = [[[-2147483648, -2147483648, -2147483648, -2147483648], + [9, 10, 11, 12], + [1, 2, 3, 4], + [-2147483648, -2147483648, -2147483648, -2147483648]]] + assert (output.asnumpy() == expect).all() + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_3d_float16(): + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + input_x = Tensor(np.arange( + 4 * 5 * 3, dtype=np.float16).reshape(4, 5, 3), dtype=mindspore.float16) + segment_ids = Tensor([2, 1, 1, -1], mstype.int32) + num_segments = 5 + net = UnsortedSegmentMaxNet(num_segments) + output = net(input_x, segment_ids).asnumpy() + expect = np.array([[[-6.55e+04, -6.55e+04, -6.55e+04], + [-6.55e+04, -6.55e+04, -6.55e+04], + [-6.55e+04, -6.55e+04, -6.55e+04], + [-6.55e+04, -6.55e+04, -6.55e+04], + [-6.55e+04, -6.55e+04, -6.55e+04]], + [[3.00e+01, 3.10e+01, 3.20e+01], + [3.30e+01, 3.40e+01, 3.50e+01], + [3.60e+01, 3.70e+01, 3.80e+01], + [3.90e+01, 4.00e+01, 4.10e+01], + [4.20e+01, 4.30e+01, 4.40e+01]], + [[0.00e+00, 1.00e+00, 2.00e+00], + [3.00e+00, 4.00e+00, 5.00e+00], + [6.00e+00, 7.00e+00, 8.00e+00], + [9.00e+00, 1.00e+01, 1.10e+01], + [1.20e+01, 1.30e+01, 1.40e+01]], + [[-6.55e+04, -6.55e+04, -6.55e+04], + [-6.55e+04, -6.55e+04, -6.55e+04], + [-6.55e+04, -6.55e+04, -6.55e+04], + [-6.55e+04, -6.55e+04, -6.55e+04], + [-6.55e+04, -6.55e+04, -6.55e+04]], + [[-6.55e+04, -6.55e+04, -6.55e+04], + [-6.55e+04, -6.55e+04, -6.55e+04], + [-6.55e+04, -6.55e+04, -6.55e+04], + [-6.55e+04, -6.55e+04, -6.55e+04], + [-6.55e+04, -6.55e+04, -6.55e+04]]]).astype(np.float16) + np.testing.assert_array_almost_equal(output, expect) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_3d_float32(): + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + input_x = Tensor(np.arange( + 4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3), dtype=mindspore.float32) + segment_ids = Tensor([2, 1, 1, -1], mstype.int32) + num_segments = 3 + net = UnsortedSegmentMaxNet(num_segments) + output = net(input_x, segment_ids).asnumpy() + expect = np.array([[[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38], + [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38], + [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38], + [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38], + [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38]], + [[3.0000000e+01, 3.1000000e+01, 3.2000000e+01], + [3.3000000e+01, 3.4000000e+01, 3.5000000e+01], + [3.6000000e+01, 3.7000000e+01, 3.8000000e+01], + [3.9000000e+01, 4.0000000e+01, 4.1000000e+01], + [4.2000000e+01, 4.3000000e+01, 4.4000000e+01]], + [[0.0000000e+00, 1.0000000e+00, 2.0000000e+00], + [3.0000000e+00, 4.0000000e+00, 5.0000000e+00], + [6.0000000e+00, 7.0000000e+00, 8.0000000e+00], + [9.0000000e+00, 1.0000000e+01, 1.1000000e+01], + [1.2000000e+01, 1.3000000e+01, 1.4000000e+01]]]).astype(np.float32) + np.testing.assert_array_almost_equal(output, expect) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_3d_single_init(): + context.set_context(device_target='GPU') + input_x = Tensor(np.arange( + 4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3), dtype=mindspore.float32) + segment_ids = Tensor([3, 0, 1, -1], mstype.int32) + net = P.UnsortedSegmentMax() + + num_segments = 4 + output = net(input_x, segment_ids, num_segments).asnumpy() + expect = np.array([[[1.5000000e+01, 1.6000000e+01, 1.7000000e+01], + [1.8000000e+01, 1.9000000e+01, 2.0000000e+01], + [2.1000000e+01, 2.2000000e+01, 2.3000000e+01], + [2.4000000e+01, 2.5000000e+01, 2.6000000e+01], + [2.7000000e+01, 2.8000000e+01, 2.9000000e+01]], + [[3.0000000e+01, 3.1000000e+01, 3.2000000e+01], + [3.3000000e+01, 3.4000000e+01, 3.5000000e+01], + [3.6000000e+01, 3.7000000e+01, 3.8000000e+01], + [3.9000000e+01, 4.0000000e+01, 4.1000000e+01], + [4.2000000e+01, 4.3000000e+01, 4.4000000e+01]], + [[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38], + [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38], + [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38], + [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38], + [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38]], + [[0.0000000e+00, 1.0000000e+00, 2.0000000e+00], + [3.0000000e+00, 4.0000000e+00, 5.0000000e+00], + [6.0000000e+00, 7.0000000e+00, 8.0000000e+00], + [9.0000000e+00, 1.0000000e+01, 1.1000000e+01], + [1.2000000e+01, 1.3000000e+01, 1.4000000e+01]]]).astype(np.float32) + np.testing.assert_array_almost_equal(output, expect) + + num_segments = 6 + output = net(input_x, segment_ids, num_segments).asnumpy() + expect = np.array([[[1.5000000e+01, 1.6000000e+01, 1.7000000e+01], + [1.8000000e+01, 1.9000000e+01, 2.0000000e+01], + [2.1000000e+01, 2.2000000e+01, 2.3000000e+01], + [2.4000000e+01, 2.5000000e+01, 2.6000000e+01], + [2.7000000e+01, 2.8000000e+01, 2.9000000e+01]], + [[3.0000000e+01, 3.1000000e+01, 3.2000000e+01], + [3.3000000e+01, 3.4000000e+01, 3.5000000e+01], + [3.6000000e+01, 3.7000000e+01, 3.8000000e+01], + [3.9000000e+01, 4.0000000e+01, 4.1000000e+01], + [4.2000000e+01, 4.3000000e+01, 4.4000000e+01]], + [[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38], + [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38], + [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38], + [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38], + [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38]], + [[0.0000000e+00, 1.0000000e+00, 2.0000000e+00], + [3.0000000e+00, 4.0000000e+00, 5.0000000e+00], + [6.0000000e+00, 7.0000000e+00, 8.0000000e+00], + [9.0000000e+00, 1.0000000e+01, 1.1000000e+01], + [1.2000000e+01, 1.3000000e+01, 1.4000000e+01]], + [[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38], + [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38], + [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38], + [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38], + [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38]], + [[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38], + [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38], + [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38], + [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38], + [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38]]]).astype(np.float32) + np.testing.assert_array_almost_equal(output, expect)