!8006 [MS][GPU][CUDA][API] Adding new Ops - UnsortedSegmentMax GPU

Merge pull request !8006 from danishnxt/SegMax
mindspore-ci-bot 2020-10-31 05:48:50 +08:00 committed by Gitee
commit de20d3e488
9 changed files with 526 additions and 2 deletions


@@ -0,0 +1,36 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/arrays/unsorted_segment_max_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
UnsortedSegmentMax,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
UnsortedSegmentMaxGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
UnsortedSegmentMax,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
UnsortedSegmentMaxGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
UnsortedSegmentMax,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
UnsortedSegmentMaxGpuKernel, int)
} // namespace kernel
} // namespace mindspore
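The three registrations above expose the kernel for float32, float16, and int32 data, each paired with int32 segment ids. As a quick orientation, here is a minimal Python sketch (not part of the PR) that exercises each registered combination; it assumes a MindSpore build with GPU support and uses illustrative values:

# Hedged sketch: exercises the three dtype combinations registered above.
import numpy as np
import mindspore.context as context
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
op = P.UnsortedSegmentMax()
segment_ids = Tensor(np.array([0, 1, 1]), mstype.int32)
for np_type, ms_type in ((np.float32, mstype.float32),
                         (np.float16, mstype.float16),
                         (np.int32, mstype.int32)):
    data = Tensor(np.array([[1, 2, 3], [4, 5, 6], [4, 2, 1]]).astype(np_type), ms_type)
    print(op(data, segment_ids, 2))  # segment 0 -> [1, 2, 3], segment 1 -> [4, 5, 6]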


@@ -0,0 +1,123 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_UNSORTED_SEGMENT_MAX_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_UNSORTED_SEGMENT_MAX_H_
#include <vector>
#include <limits>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_max.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class UnsortedSegmentMaxGpuKernel : public GpuKernel {
public:
UnsortedSegmentMaxGpuKernel()
: num_segments_(1),
inner_size_(1),
outer_size_(1),
input_size_(1),
segment_ids_size_(1),
output_size_(1),
is_null_input_(false) {}
~UnsortedSegmentMaxGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (is_null_input_) {
return true;
}
T *input_addr = GetDeviceAddress<T>(inputs, 0);
int *indices_addr = GetDeviceAddress<int>(inputs, 1);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
CHECK_CUDA_RET_WITH_EXCEPT(cudaMemsetAsync(output_addr, std::numeric_limits<T>::min(), outputs[0]->size,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemSet Failed");
CalUnsortedSegmentMax(input_addr, indices_addr, num_segments_, outer_size_, inner_size_, output_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
auto input_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
is_null_input_ = CHECK_NULL_INPUT(input_shapes);
if (is_null_input_) {
MS_LOG(WARNING) << "UnsortedSegmentMax input is null";
InitSizeLists();
return true;
}
auto segment_ids_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
    auto output_shapes = AnfAlgo::GetOutputInferShape(kernel_node, 0);  // inferred output shape; its first dim is num_segments
num_segments_ = output_shapes[0];
input_size_ = 1;
for (size_t i = 0; i < input_shapes.size(); i++) {
input_size_ *= input_shapes[i];
}
segment_ids_size_ = 1;
for (size_t i = 0; i < segment_ids_shapes.size(); i++) {
segment_ids_size_ *= segment_ids_shapes[i];
}
output_size_ = 1;
for (size_t i = 0; i < output_shapes.size(); i++) {
output_size_ *= output_shapes[i];
}
outer_size_ = input_shapes[0];
inner_size_ = 1;
for (size_t i = 1; i < input_shapes.size(); i++) {
inner_size_ *= input_shapes[i];
}
InitSizeLists();
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_ * sizeof(T));
input_size_list_.push_back(segment_ids_size_ * sizeof(int));
output_size_list_.push_back(output_size_ * sizeof(T));
}
private:
int num_segments_;
size_t inner_size_;
size_t outer_size_;
size_t input_size_;
size_t segment_ids_size_;
size_t output_size_;
bool is_null_input_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_UNSORTED_SEGMENT_MAX_H_
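As a reading aid for the shape bookkeeping in Init(), here is a hedged NumPy sketch of what the kernel computes: the input is viewed as (outer_size, inner_size) with outer_size = shape[0] and inner_size the product of the remaining dimensions, each row is routed by its segment id, and segments that receive no rows keep the dtype's lowest value (the kernel's init_K). The helper name is hypothetical and not part of this PR:

# Hedged NumPy reference for the kernel semantics (helper name is hypothetical).
import numpy as np

def unsorted_segment_max_ref(x, segment_ids, num_segments):
    outer_size = x.shape[0]                          # outer_size_ in Init()
    inner = x.reshape(outer_size, -1)                # inner_size_ = prod(x.shape[1:])
    if np.issubdtype(x.dtype, np.floating):
        init_k = np.finfo(x.dtype).min               # lowest value, like init_K in the kernel
    else:
        init_k = np.iinfo(x.dtype).min
    out = np.full((num_segments, inner.shape[1]), init_k, dtype=x.dtype)
    for i in range(outer_size):
        sid = int(segment_ids[i])
        if 0 <= sid < num_segments:                  # out-of-range ids (e.g. -1) never match
            out[sid] = np.maximum(out[sid], inner[i])
    return out.reshape((num_segments,) + x.shape[1:])

With the data used in the tests further down (for example a (4, 5, 3) input and ids [2, 1, 1, -1]), this reproduces the expected arrays, including the lowest-value fill for empty segments.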


@@ -0,0 +1,79 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_max.cuh"
#include <limits>
template <typename T>
__global__ void UnsortedSegmentMax(const T *input, const int *segment_ids, const int num_segments, size_t outer_size,
size_t inner_size, bool fp16_flag, T init_K, T *output) {
if (fp16_flag) {
init_K = __int2half_rd(-65504); // min value representable by float16
}
for (int t_idx = blockIdx.x * blockDim.x + threadIdx.x; t_idx < KWARPSIZE * num_segments * inner_size;
t_idx += blockDim.x * gridDim.x) {
int segment_id = t_idx / KWARPSIZE / inner_size;
int inner_id = t_idx / KWARPSIZE % inner_size;
int lane_id = threadIdx.x % KWARPSIZE;
T threadK = init_K;
for (int i = lane_id; i < outer_size; i += KWARPSIZE) {
if (segment_ids[i] != segment_id) continue;
T other_K = input[i * inner_size + inner_id];
if (threadK < other_K) {
threadK = other_K;
}
}
__syncwarp();
for (int offset = KWARPSIZE / 2; offset > 0; offset /= 2) {
T other_K = __shfl_down_sync(0xffffffff, threadK, offset);
if (threadK < other_K) {
threadK = other_K;
}
}
__syncwarp();
if (lane_id == 0) {
output[segment_id * inner_size + inner_id] = threadK;
}
__syncthreads();
}
}
template <typename T>
void CalUnsortedSegmentMax(const T *input, const int *segment_ids, const int num_segments, size_t outer_size,
size_t inner_size, T *output, cudaStream_t stream) {
int size = (inner_size * KWARPSIZE * num_segments);
bool fp16_flag = false;
  // std::numeric_limits<half>::lowest() does not give the fp16 lowest value, so it is set inside the kernel instead
if (std::is_same<T, half>::value) {
fp16_flag = true;
}
T init_K = std::numeric_limits<T>::lowest();
UnsortedSegmentMax<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, segment_ids, num_segments, outer_size,
inner_size, fp16_flag, init_K, output);
return;
}
template void CalUnsortedSegmentMax<float>(const float *input, const int *segment_ids, const int num_segments,
size_t outer_size, size_t inner_size, float *output, cudaStream_t stream);
template void CalUnsortedSegmentMax<half>(const half *input, const int *segment_ids, const int num_segments,
size_t outer_size, size_t inner_size, half *output, cudaStream_t stream);
template void CalUnsortedSegmentMax<int>(const int *input, const int *segment_ids, const int num_segments,
size_t outer_size, size_t inner_size, int *output, cudaStream_t stream);
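The kernel above assigns one group of KWARPSIZE lanes to each (segment, inner index) pair: every lane scans a strided slice of the outer dimension, and the partial maxima are then folded together with __shfl_down_sync at halving offsets. A hedged plain-Python sketch of that reduction pattern (names are illustrative, not part of the PR):

# Hedged sketch of the per-(segment, inner index) warp reduction, in plain Python.
KWARPSIZE = 32

def warp_segment_max(column, segment_ids, segment_id, init_k):
    # Phase 1: lane k scans outer indices k, k + 32, k + 64, ... (the i loop in the kernel).
    lanes = [init_k] * KWARPSIZE
    for lane in range(KWARPSIZE):
        for i in range(lane, len(column), KWARPSIZE):
            if segment_ids[i] == segment_id and lanes[lane] < column[i]:
                lanes[lane] = column[i]
    # Phase 2: tree reduction with halving offsets, mirroring the __shfl_down_sync loop;
    # only the lanes that feed lane 0 are updated, and lane 0 ends up with the maximum.
    offset = KWARPSIZE // 2
    while offset > 0:
        for lane in range(KWARPSIZE - offset):
            if lanes[lane] < lanes[lane + offset]:
                lanes[lane] = lanes[lane + offset]
        offset //= 2
    return lanes[0]  # lane 0 writes output[segment_id * inner_size + inner_id]

CalUnsortedSegmentMax launches enough threads to cover inner_size * KWARPSIZE * num_segments logical indices, i.e. one such lane group per output element.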


@@ -0,0 +1,30 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNSORTED_SEGMENT_MAX_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNSORTED_SEGMENT_MAX_H_
#include <cuda_runtime.h>
#include "runtime/device/gpu/cuda_common.h"
// Warp size used for the warp-level reduction (one warp per (segment, inner index) pair)
#define KWARPSIZE 32
template <typename T>
void CalUnsortedSegmentMax(const T *input, const int *segment_ids, const int num_segments, size_t outer_size,
size_t inner_size, T *output, cudaStream_t stream);
#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNSORTED_SEGMENT_MAX_H_


@@ -49,6 +49,7 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() {
Register(prim::kPrimReduceAll->name(), {1});
Register(prim::kPrimReduceAny->name(), {1});
Register(prim::kPrimUnsortedSegmentMin->name(), {2});
Register(prim::kPrimUnsortedSegmentMax->name(), {2});
Register(kSparseGatherV2, {2});
Register(kUnsortedSegmentProdOpName, {2});
Register(kSimpleMeanGradOpName, {1});


@@ -92,6 +92,7 @@ inline const PrimitivePtr kPrimEmbeddingLookupCommGrad = std::make_shared<Primit
inline const PrimitivePtr kPrimSize = std::make_shared<Primitive>("Size");
inline const PrimitivePtr kPrimArgMax = std::make_shared<Primitive>("Argmax");
inline const PrimitivePtr kPrimPack = std::make_shared<Primitive>("Pack");
inline const PrimitivePtr kPrimUnsortedSegmentMax = std::make_shared<Primitive>("UnsortedSegmentMax");
inline const PrimitivePtr kPrimUnsortedSegmentSum = std::make_shared<Primitive>("UnsortedSegmentSum");
inline const PrimitivePtr kPrimUnsortedSegmentMin = std::make_shared<Primitive>("UnsortedSegmentMin");
inline const PrimitivePtr kPrimConcatOffset = std::make_shared<Primitive>("ConcatOffset");


@@ -30,8 +30,8 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
Shape, DynamicShape, Size, Slice, Split, TransShape, ParallelConcat, Padding, UniqueWithPad,
ScatterNdAdd, ScatterNdSub, ScatterNonAliasingAdd, ReverseV2, Rint,
Squeeze, StridedSlice, Tile, TensorScatterUpdate, EditDistance, Sort,
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentProd,
UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentMax,
UnsortedSegmentProd, UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup,
Unique, GatherD, Identity, RepeatElements)
from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast,


@@ -1640,6 +1640,56 @@ class UnsortedSegmentMin(PrimitiveWithInfer):
return out
class UnsortedSegmentMax(PrimitiveWithInfer):
"""
Computes the maximum along segments of a tensor.
Inputs:
- **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_R)`.
The data type must be float16, float32 or int32.
        - **segment_ids** (Tensor) - A `1-D` tensor whose shape is :math:`(x_1)`; the values must be >= 0.
The data type must be int32.
        - **num_segments** (int) - The value specifies the number of distinct `segment_ids`.
Outputs:
        Tensor, with shape :math:`(N, x_2, ..., x_R)`, where :math:`N` is the value of `num_segments`.
Examples:
>>> input_x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [4, 2, 1]]).astype(np.float32))
>>> segment_ids = Tensor(np.array([0, 1, 1]).astype(np.int32))
>>> num_segments = 2
>>> unsorted_segment_max = P.UnsortedSegmentMax()
>>> unsorted_segment_max(input_x, segment_ids, num_segments)
[[1., 2., 3.], [4., 5., 6.]]
"""
@prim_attr_register
def __init__(self):
"""Initialize UnsortedSegmentMax"""
self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y'])
def __infer__(self, x, segment_ids, num_segments):
x_type = x['dtype']
x_shape = x['shape']
segment_ids_shape = segment_ids['shape']
valid_type = [mstype.float16, mstype.float32, mstype.int32]
validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name)
validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name)
validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name)
        validator.check('first dimension of input_x shape', x_shape[0],
                        'length of segment_ids', segment_ids_shape[0], Rel.EQ, self.name)
num_segments_v = num_segments['value']
validator.check_value_type('num_segments', num_segments_v, [int], self.name)
validator.check_positive_int(num_segments_v, "num_segments", self.name)
segment_ids_shape_len = len(segment_ids_shape)
out_shape = [num_segments_v]
out_shape += x_shape[segment_ids_shape_len:]
out = {'shape': out_shape,
'dtype': x_type,
'value': None}
return out
class UnsortedSegmentProd(PrimitiveWithInfer):
"""
Computes the product along segments of a tensor.
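Stepping back to UnsortedSegmentMax.__infer__ above, a hedged worked example of its output-shape rule, matching the 3-D float16 test added later in this PR:

# Hedged illustration of the output-shape rule in UnsortedSegmentMax.__infer__.
x_shape = [4, 5, 3]                 # input_x of shape (x_1, x_2, x_3)
segment_ids_shape = [4]             # 1-D, length equal to x_shape[0]
num_segments_v = 5
out_shape = [num_segments_v] + x_shape[len(segment_ids_shape):]
assert out_shape == [5, 5, 3]       # (N, x_2, x_3) with N = num_segments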


@@ -0,0 +1,204 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
class UnsortedSegmentMaxNet(nn.Cell):
def __init__(self, num_segments):
super(UnsortedSegmentMaxNet, self).__init__()
self.unsorted_segment_max = P.UnsortedSegmentMax()
self.num_segments = num_segments
def construct(self, data, ids):
return self.unsorted_segment_max(data, ids, self.num_segments)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_1d_int32():
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
input_x = Tensor([1, 2, 3, 4], mstype.int32)
segment_ids = Tensor([0, 0, 1, 2], mstype.int32)
num_segments = 4
net = UnsortedSegmentMaxNet(num_segments)
output = net(input_x, segment_ids)
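    # segment 3 receives no ids, so it keeps the int32 lowest value used as the kernel's fill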
expect = [2, 3, 4, -2147483648]
assert (output.asnumpy() == expect).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_2d_int32():
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
input_x = Tensor([[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]], mstype.int32)
segment_ids = Tensor([2, 1, 1], mstype.int32)
num_segments = 4
net = UnsortedSegmentMaxNet(num_segments)
output = net(input_x, segment_ids)
    expect = [[-2147483648, -2147483648, -2147483648, -2147483648],
              [9, 10, 11, 12],
              [1, 2, 3, 4],
              [-2147483648, -2147483648, -2147483648, -2147483648]]
assert (output.asnumpy() == expect).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_3d_float16():
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
input_x = Tensor(np.arange(
4 * 5 * 3, dtype=np.float16).reshape(4, 5, 3), dtype=mindspore.float16)
segment_ids = Tensor([2, 1, 1, -1], mstype.int32)
num_segments = 5
net = UnsortedSegmentMaxNet(num_segments)
output = net(input_x, segment_ids).asnumpy()
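    # -6.55e+04 is the float16 lowest value: segments 0, 3 and 4 receive no ids
    # (the id -1 is ignored), so they keep the kernel's fill value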
expect = np.array([[[-6.55e+04, -6.55e+04, -6.55e+04],
[-6.55e+04, -6.55e+04, -6.55e+04],
[-6.55e+04, -6.55e+04, -6.55e+04],
[-6.55e+04, -6.55e+04, -6.55e+04],
[-6.55e+04, -6.55e+04, -6.55e+04]],
[[3.00e+01, 3.10e+01, 3.20e+01],
[3.30e+01, 3.40e+01, 3.50e+01],
[3.60e+01, 3.70e+01, 3.80e+01],
[3.90e+01, 4.00e+01, 4.10e+01],
[4.20e+01, 4.30e+01, 4.40e+01]],
[[0.00e+00, 1.00e+00, 2.00e+00],
[3.00e+00, 4.00e+00, 5.00e+00],
[6.00e+00, 7.00e+00, 8.00e+00],
[9.00e+00, 1.00e+01, 1.10e+01],
[1.20e+01, 1.30e+01, 1.40e+01]],
[[-6.55e+04, -6.55e+04, -6.55e+04],
[-6.55e+04, -6.55e+04, -6.55e+04],
[-6.55e+04, -6.55e+04, -6.55e+04],
[-6.55e+04, -6.55e+04, -6.55e+04],
[-6.55e+04, -6.55e+04, -6.55e+04]],
[[-6.55e+04, -6.55e+04, -6.55e+04],
[-6.55e+04, -6.55e+04, -6.55e+04],
[-6.55e+04, -6.55e+04, -6.55e+04],
[-6.55e+04, -6.55e+04, -6.55e+04],
[-6.55e+04, -6.55e+04, -6.55e+04]]]).astype(np.float16)
np.testing.assert_array_almost_equal(output, expect)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_3d_float32():
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
input_x = Tensor(np.arange(
4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3), dtype=mindspore.float32)
segment_ids = Tensor([2, 1, 1, -1], mstype.int32)
num_segments = 3
net = UnsortedSegmentMaxNet(num_segments)
output = net(input_x, segment_ids).asnumpy()
expect = np.array([[[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38]],
[[3.0000000e+01, 3.1000000e+01, 3.2000000e+01],
[3.3000000e+01, 3.4000000e+01, 3.5000000e+01],
[3.6000000e+01, 3.7000000e+01, 3.8000000e+01],
[3.9000000e+01, 4.0000000e+01, 4.1000000e+01],
[4.2000000e+01, 4.3000000e+01, 4.4000000e+01]],
[[0.0000000e+00, 1.0000000e+00, 2.0000000e+00],
[3.0000000e+00, 4.0000000e+00, 5.0000000e+00],
[6.0000000e+00, 7.0000000e+00, 8.0000000e+00],
[9.0000000e+00, 1.0000000e+01, 1.1000000e+01],
[1.2000000e+01, 1.3000000e+01, 1.4000000e+01]]]).astype(np.float32)
np.testing.assert_array_almost_equal(output, expect)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_3d_single_init():
context.set_context(device_target='GPU')
input_x = Tensor(np.arange(
4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3), dtype=mindspore.float32)
segment_ids = Tensor([3, 0, 1, -1], mstype.int32)
net = P.UnsortedSegmentMax()
num_segments = 4
output = net(input_x, segment_ids, num_segments).asnumpy()
expect = np.array([[[1.5000000e+01, 1.6000000e+01, 1.7000000e+01],
[1.8000000e+01, 1.9000000e+01, 2.0000000e+01],
[2.1000000e+01, 2.2000000e+01, 2.3000000e+01],
[2.4000000e+01, 2.5000000e+01, 2.6000000e+01],
[2.7000000e+01, 2.8000000e+01, 2.9000000e+01]],
[[3.0000000e+01, 3.1000000e+01, 3.2000000e+01],
[3.3000000e+01, 3.4000000e+01, 3.5000000e+01],
[3.6000000e+01, 3.7000000e+01, 3.8000000e+01],
[3.9000000e+01, 4.0000000e+01, 4.1000000e+01],
[4.2000000e+01, 4.3000000e+01, 4.4000000e+01]],
[[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38]],
[[0.0000000e+00, 1.0000000e+00, 2.0000000e+00],
[3.0000000e+00, 4.0000000e+00, 5.0000000e+00],
[6.0000000e+00, 7.0000000e+00, 8.0000000e+00],
[9.0000000e+00, 1.0000000e+01, 1.1000000e+01],
[1.2000000e+01, 1.3000000e+01, 1.4000000e+01]]]).astype(np.float32)
np.testing.assert_array_almost_equal(output, expect)
num_segments = 6
output = net(input_x, segment_ids, num_segments).asnumpy()
expect = np.array([[[1.5000000e+01, 1.6000000e+01, 1.7000000e+01],
[1.8000000e+01, 1.9000000e+01, 2.0000000e+01],
[2.1000000e+01, 2.2000000e+01, 2.3000000e+01],
[2.4000000e+01, 2.5000000e+01, 2.6000000e+01],
[2.7000000e+01, 2.8000000e+01, 2.9000000e+01]],
[[3.0000000e+01, 3.1000000e+01, 3.2000000e+01],
[3.3000000e+01, 3.4000000e+01, 3.5000000e+01],
[3.6000000e+01, 3.7000000e+01, 3.8000000e+01],
[3.9000000e+01, 4.0000000e+01, 4.1000000e+01],
[4.2000000e+01, 4.3000000e+01, 4.4000000e+01]],
[[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38]],
[[0.0000000e+00, 1.0000000e+00, 2.0000000e+00],
[3.0000000e+00, 4.0000000e+00, 5.0000000e+00],
[6.0000000e+00, 7.0000000e+00, 8.0000000e+00],
[9.0000000e+00, 1.0000000e+01, 1.1000000e+01],
[1.2000000e+01, 1.3000000e+01, 1.4000000e+01]],
[[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38]],
[[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38]]]).astype(np.float32)
np.testing.assert_array_almost_equal(output, expect)