forked from mindspore-Ecosystem/mindspore
!8006 [MS][GPU][CUDA][API] Adding new Ops - UnsortedSegmentMax GPU
Merge pull request !8006 from danishnxt/SegMax
This commit is contained in:
commit
de20d3e488
|
@ -0,0 +1,36 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/arrays/unsorted_segment_max_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
UnsortedSegmentMax,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
|
||||
UnsortedSegmentMaxGpuKernel, float)
|
||||
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
UnsortedSegmentMax,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
|
||||
UnsortedSegmentMaxGpuKernel, half)
|
||||
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
UnsortedSegmentMax,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
UnsortedSegmentMaxGpuKernel, int)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,123 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_UNSORTED_SEGMENT_MAX_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_UNSORTED_SEGMENT_MAX_H_
|
||||
|
||||
#include <vector>
|
||||
#include <limits>
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_max.cuh"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class UnsortedSegmentMaxGpuKernel : public GpuKernel {
|
||||
public:
|
||||
UnsortedSegmentMaxGpuKernel()
|
||||
: num_segments_(1),
|
||||
inner_size_(1),
|
||||
outer_size_(1),
|
||||
input_size_(1),
|
||||
segment_ids_size_(1),
|
||||
output_size_(1),
|
||||
is_null_input_(false) {}
|
||||
~UnsortedSegmentMaxGpuKernel() override = default;
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
if (is_null_input_) {
|
||||
return true;
|
||||
}
|
||||
T *input_addr = GetDeviceAddress<T>(inputs, 0);
|
||||
int *indices_addr = GetDeviceAddress<int>(inputs, 1);
|
||||
T *output_addr = GetDeviceAddress<T>(outputs, 0);
|
||||
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMemsetAsync(output_addr, std::numeric_limits<T>::min(), outputs[0]->size,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemSet Failed");
|
||||
CalUnsortedSegmentMax(input_addr, indices_addr, num_segments_, outer_size_, inner_size_, output_addr,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
auto input_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
is_null_input_ = CHECK_NULL_INPUT(input_shapes);
|
||||
if (is_null_input_) {
|
||||
MS_LOG(WARNING) << "UnsortedSegmentMax input is null";
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
auto segment_ids_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
||||
auto output_shapes = AnfAlgo::GetOutputInferShape(kernel_node, 0); // we get that from computation
|
||||
|
||||
num_segments_ = output_shapes[0];
|
||||
|
||||
input_size_ = 1;
|
||||
for (size_t i = 0; i < input_shapes.size(); i++) {
|
||||
input_size_ *= input_shapes[i];
|
||||
}
|
||||
|
||||
segment_ids_size_ = 1;
|
||||
for (size_t i = 0; i < segment_ids_shapes.size(); i++) {
|
||||
segment_ids_size_ *= segment_ids_shapes[i];
|
||||
}
|
||||
|
||||
output_size_ = 1;
|
||||
for (size_t i = 0; i < output_shapes.size(); i++) {
|
||||
output_size_ *= output_shapes[i];
|
||||
}
|
||||
|
||||
outer_size_ = input_shapes[0];
|
||||
inner_size_ = 1;
|
||||
for (size_t i = 1; i < input_shapes.size(); i++) {
|
||||
inner_size_ *= input_shapes[i];
|
||||
}
|
||||
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(input_size_ * sizeof(T));
|
||||
input_size_list_.push_back(segment_ids_size_ * sizeof(int));
|
||||
output_size_list_.push_back(output_size_ * sizeof(T));
|
||||
}
|
||||
|
||||
private:
|
||||
int num_segments_;
|
||||
size_t inner_size_;
|
||||
size_t outer_size_;
|
||||
size_t input_size_;
|
||||
size_t segment_ids_size_;
|
||||
size_t output_size_;
|
||||
bool is_null_input_;
|
||||
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_UNSORTED_SEGMENT_MAX_H_
|
|
@ -0,0 +1,79 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_max.cuh"
|
||||
#include <limits>
|
||||
|
||||
template <typename T>
|
||||
__global__ void UnsortedSegmentMax(const T *input, const int *segment_ids, const int num_segments, size_t outer_size,
|
||||
size_t inner_size, bool fp16_flag, T init_K, T *output) {
|
||||
if (fp16_flag) {
|
||||
init_K = __int2half_rd(-65504); // min value representable by float16
|
||||
}
|
||||
|
||||
for (int t_idx = blockIdx.x * blockDim.x + threadIdx.x; t_idx < KWARPSIZE * num_segments * inner_size;
|
||||
t_idx += blockDim.x * gridDim.x) {
|
||||
int segment_id = t_idx / KWARPSIZE / inner_size;
|
||||
int inner_id = t_idx / KWARPSIZE % inner_size;
|
||||
int lane_id = threadIdx.x % KWARPSIZE;
|
||||
T threadK = init_K;
|
||||
|
||||
for (int i = lane_id; i < outer_size; i += KWARPSIZE) {
|
||||
if (segment_ids[i] != segment_id) continue;
|
||||
T other_K = input[i * inner_size + inner_id];
|
||||
if (threadK < other_K) {
|
||||
threadK = other_K;
|
||||
}
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
for (int offset = KWARPSIZE / 2; offset > 0; offset /= 2) {
|
||||
T other_K = __shfl_down_sync(0xffffffff, threadK, offset);
|
||||
if (threadK < other_K) {
|
||||
threadK = other_K;
|
||||
}
|
||||
}
|
||||
|
||||
__syncwarp();
|
||||
|
||||
if (lane_id == 0) {
|
||||
output[segment_id * inner_size + inner_id] = threadK;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CalUnsortedSegmentMax(const T *input, const int *segment_ids, const int num_segments, size_t outer_size,
|
||||
size_t inner_size, T *output, cudaStream_t stream) {
|
||||
int size = (inner_size * KWARPSIZE * num_segments);
|
||||
bool fp16_flag = false;
|
||||
// handle fp16 min value
|
||||
if (std::is_same<T, half>::value) {
|
||||
fp16_flag = true;
|
||||
}
|
||||
T init_K = std::numeric_limits<T>::lowest();
|
||||
UnsortedSegmentMax<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, segment_ids, num_segments, outer_size,
|
||||
inner_size, fp16_flag, init_K, output);
|
||||
return;
|
||||
}
|
||||
|
||||
template void CalUnsortedSegmentMax<float>(const float *input, const int *segment_ids, const int num_segments,
|
||||
size_t outer_size, size_t inner_size, float *output, cudaStream_t stream);
|
||||
template void CalUnsortedSegmentMax<half>(const half *input, const int *segment_ids, const int num_segments,
|
||||
size_t outer_size, size_t inner_size, half *output, cudaStream_t stream);
|
||||
template void CalUnsortedSegmentMax<int>(const int *input, const int *segment_ids, const int num_segments,
|
||||
size_t outer_size, size_t inner_size, int *output, cudaStream_t stream);
|
|
@ -0,0 +1,30 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNSORTED_SEGMENT_MAX_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNSORTED_SEGMENT_MAX_H_
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
|
||||
// Setting warp size to sync data across threads
|
||||
#define KWARPSIZE 32
|
||||
|
||||
template <typename T>
|
||||
void CalUnsortedSegmentMax(const T *input, const int *segment_ids, const int num_segments, size_t outer_size,
|
||||
size_t inner_size, T *output, cudaStream_t stream);
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNSORT_SEGMENT_MAX_H_
|
|
@ -49,6 +49,7 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() {
|
|||
Register(prim::kPrimReduceAll->name(), {1});
|
||||
Register(prim::kPrimReduceAny->name(), {1});
|
||||
Register(prim::kPrimUnsortedSegmentMin->name(), {2});
|
||||
Register(prim::kPrimUnsortedSegmentMax->name(), {2});
|
||||
Register(kSparseGatherV2, {2});
|
||||
Register(kUnsortedSegmentProdOpName, {2});
|
||||
Register(kSimpleMeanGradOpName, {1});
|
||||
|
|
|
@ -92,6 +92,7 @@ inline const PrimitivePtr kPrimEmbeddingLookupCommGrad = std::make_shared<Primit
|
|||
inline const PrimitivePtr kPrimSize = std::make_shared<Primitive>("Size");
|
||||
inline const PrimitivePtr kPrimArgMax = std::make_shared<Primitive>("Argmax");
|
||||
inline const PrimitivePtr kPrimPack = std::make_shared<Primitive>("Pack");
|
||||
inline const PrimitivePtr kPrimUnsortedSegmentMax = std::make_shared<Primitive>("UnsortedSegmentMax");
|
||||
inline const PrimitivePtr kPrimUnsortedSegmentSum = std::make_shared<Primitive>("UnsortedSegmentSum");
|
||||
inline const PrimitivePtr kPrimUnsortedSegmentMin = std::make_shared<Primitive>("UnsortedSegmentMin");
|
||||
inline const PrimitivePtr kPrimConcatOffset = std::make_shared<Primitive>("ConcatOffset");
|
||||
|
|
|
@ -30,8 +30,8 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
|
|||
Shape, DynamicShape, Size, Slice, Split, TransShape, ParallelConcat, Padding, UniqueWithPad,
|
||||
ScatterNdAdd, ScatterNdSub, ScatterNonAliasingAdd, ReverseV2, Rint,
|
||||
Squeeze, StridedSlice, Tile, TensorScatterUpdate, EditDistance, Sort,
|
||||
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentProd,
|
||||
UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
|
||||
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentMax,
|
||||
UnsortedSegmentProd, UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
|
||||
SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup,
|
||||
Unique, GatherD, Identity, RepeatElements)
|
||||
from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast,
|
||||
|
|
|
@ -1640,6 +1640,56 @@ class UnsortedSegmentMin(PrimitiveWithInfer):
|
|||
return out
|
||||
|
||||
|
||||
class UnsortedSegmentMax(PrimitiveWithInfer):
|
||||
"""
|
||||
Computes the maximum along segments of a tensor.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_R)`.
|
||||
The data type must be float16, float32 or int32.
|
||||
- **segment_ids** (Tensor) - A `1-D` tensor whose shape is :math:`(x_1)`, the value must be >= 0.
|
||||
The data type must be int32.
|
||||
- **num_segments** (int) - The value spcifies the number of distinct `segment_ids`.
|
||||
|
||||
Outputs:
|
||||
Tensor, set the number of `num_segments` as `N`, the shape is :math:`(N, x_2, ..., x_R)`.
|
||||
|
||||
Examples:
|
||||
>>> input_x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [4, 2, 1]]).astype(np.float32))
|
||||
>>> segment_ids = Tensor(np.array([0, 1, 1]).astype(np.int32))
|
||||
>>> num_segments = 2
|
||||
>>> unsorted_segment_max = P.UnsortedSegmentMax()
|
||||
>>> unsorted_segment_max(input_x, segment_ids, num_segments)
|
||||
[[1., 2., 3.], [4., 5., 6.]]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize UnsortedSegmentMax"""
|
||||
self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y'])
|
||||
|
||||
def __infer__(self, x, segment_ids, num_segments):
|
||||
x_type = x['dtype']
|
||||
x_shape = x['shape']
|
||||
segment_ids_shape = segment_ids['shape']
|
||||
valid_type = [mstype.float16, mstype.float32, mstype.int32]
|
||||
validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name)
|
||||
validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name)
|
||||
validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name)
|
||||
validator.check(f'first shape of input_x', x_shape[0],
|
||||
'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
|
||||
num_segments_v = num_segments['value']
|
||||
validator.check_value_type('num_segments', num_segments_v, [int], self.name)
|
||||
validator.check_positive_int(num_segments_v, "num_segments", self.name)
|
||||
segment_ids_shape_len = len(segment_ids_shape)
|
||||
out_shape = [num_segments_v]
|
||||
out_shape += x_shape[segment_ids_shape_len:]
|
||||
out = {'shape': out_shape,
|
||||
'dtype': x_type,
|
||||
'value': None}
|
||||
return out
|
||||
|
||||
|
||||
class UnsortedSegmentProd(PrimitiveWithInfer):
|
||||
"""
|
||||
Computes the product along segments of a tensor.
|
||||
|
|
|
@ -0,0 +1,204 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class UnsortedSegmentMaxNet(nn.Cell):
|
||||
def __init__(self, num_segments):
|
||||
super(UnsortedSegmentMaxNet, self).__init__()
|
||||
self.unsorted_segment_max = P.UnsortedSegmentMax()
|
||||
self.num_segments = num_segments
|
||||
|
||||
def construct(self, data, ids):
|
||||
return self.unsorted_segment_max(data, ids, self.num_segments)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_1d_int32():
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
|
||||
input_x = Tensor([1, 2, 3, 4], mstype.int32)
|
||||
segment_ids = Tensor([0, 0, 1, 2], mstype.int32)
|
||||
num_segments = 4
|
||||
net = UnsortedSegmentMaxNet(num_segments)
|
||||
output = net(input_x, segment_ids)
|
||||
expect = [2, 3, 4, -2147483648]
|
||||
assert (output.asnumpy() == expect).all()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_2d_int32():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
|
||||
input_x = Tensor([[1, 2, 3, 4],
|
||||
[5, 6, 7, 8],
|
||||
[9, 10, 11, 12]], mstype.int32)
|
||||
segment_ids = Tensor([2, 1, 1], mstype.int32)
|
||||
num_segments = 4
|
||||
net = UnsortedSegmentMaxNet(num_segments)
|
||||
output = net(input_x, segment_ids)
|
||||
expect = [[[-2147483648, -2147483648, -2147483648, -2147483648],
|
||||
[9, 10, 11, 12],
|
||||
[1, 2, 3, 4],
|
||||
[-2147483648, -2147483648, -2147483648, -2147483648]]]
|
||||
assert (output.asnumpy() == expect).all()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_3d_float16():
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
|
||||
input_x = Tensor(np.arange(
|
||||
4 * 5 * 3, dtype=np.float16).reshape(4, 5, 3), dtype=mindspore.float16)
|
||||
segment_ids = Tensor([2, 1, 1, -1], mstype.int32)
|
||||
num_segments = 5
|
||||
net = UnsortedSegmentMaxNet(num_segments)
|
||||
output = net(input_x, segment_ids).asnumpy()
|
||||
expect = np.array([[[-6.55e+04, -6.55e+04, -6.55e+04],
|
||||
[-6.55e+04, -6.55e+04, -6.55e+04],
|
||||
[-6.55e+04, -6.55e+04, -6.55e+04],
|
||||
[-6.55e+04, -6.55e+04, -6.55e+04],
|
||||
[-6.55e+04, -6.55e+04, -6.55e+04]],
|
||||
[[3.00e+01, 3.10e+01, 3.20e+01],
|
||||
[3.30e+01, 3.40e+01, 3.50e+01],
|
||||
[3.60e+01, 3.70e+01, 3.80e+01],
|
||||
[3.90e+01, 4.00e+01, 4.10e+01],
|
||||
[4.20e+01, 4.30e+01, 4.40e+01]],
|
||||
[[0.00e+00, 1.00e+00, 2.00e+00],
|
||||
[3.00e+00, 4.00e+00, 5.00e+00],
|
||||
[6.00e+00, 7.00e+00, 8.00e+00],
|
||||
[9.00e+00, 1.00e+01, 1.10e+01],
|
||||
[1.20e+01, 1.30e+01, 1.40e+01]],
|
||||
[[-6.55e+04, -6.55e+04, -6.55e+04],
|
||||
[-6.55e+04, -6.55e+04, -6.55e+04],
|
||||
[-6.55e+04, -6.55e+04, -6.55e+04],
|
||||
[-6.55e+04, -6.55e+04, -6.55e+04],
|
||||
[-6.55e+04, -6.55e+04, -6.55e+04]],
|
||||
[[-6.55e+04, -6.55e+04, -6.55e+04],
|
||||
[-6.55e+04, -6.55e+04, -6.55e+04],
|
||||
[-6.55e+04, -6.55e+04, -6.55e+04],
|
||||
[-6.55e+04, -6.55e+04, -6.55e+04],
|
||||
[-6.55e+04, -6.55e+04, -6.55e+04]]]).astype(np.float16)
|
||||
np.testing.assert_array_almost_equal(output, expect)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_3d_float32():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
|
||||
input_x = Tensor(np.arange(
|
||||
4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3), dtype=mindspore.float32)
|
||||
segment_ids = Tensor([2, 1, 1, -1], mstype.int32)
|
||||
num_segments = 3
|
||||
net = UnsortedSegmentMaxNet(num_segments)
|
||||
output = net(input_x, segment_ids).asnumpy()
|
||||
expect = np.array([[[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
|
||||
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
|
||||
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
|
||||
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
|
||||
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38]],
|
||||
[[3.0000000e+01, 3.1000000e+01, 3.2000000e+01],
|
||||
[3.3000000e+01, 3.4000000e+01, 3.5000000e+01],
|
||||
[3.6000000e+01, 3.7000000e+01, 3.8000000e+01],
|
||||
[3.9000000e+01, 4.0000000e+01, 4.1000000e+01],
|
||||
[4.2000000e+01, 4.3000000e+01, 4.4000000e+01]],
|
||||
[[0.0000000e+00, 1.0000000e+00, 2.0000000e+00],
|
||||
[3.0000000e+00, 4.0000000e+00, 5.0000000e+00],
|
||||
[6.0000000e+00, 7.0000000e+00, 8.0000000e+00],
|
||||
[9.0000000e+00, 1.0000000e+01, 1.1000000e+01],
|
||||
[1.2000000e+01, 1.3000000e+01, 1.4000000e+01]]]).astype(np.float32)
|
||||
np.testing.assert_array_almost_equal(output, expect)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_3d_single_init():
|
||||
context.set_context(device_target='GPU')
|
||||
input_x = Tensor(np.arange(
|
||||
4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3), dtype=mindspore.float32)
|
||||
segment_ids = Tensor([3, 0, 1, -1], mstype.int32)
|
||||
net = P.UnsortedSegmentMax()
|
||||
|
||||
num_segments = 4
|
||||
output = net(input_x, segment_ids, num_segments).asnumpy()
|
||||
expect = np.array([[[1.5000000e+01, 1.6000000e+01, 1.7000000e+01],
|
||||
[1.8000000e+01, 1.9000000e+01, 2.0000000e+01],
|
||||
[2.1000000e+01, 2.2000000e+01, 2.3000000e+01],
|
||||
[2.4000000e+01, 2.5000000e+01, 2.6000000e+01],
|
||||
[2.7000000e+01, 2.8000000e+01, 2.9000000e+01]],
|
||||
[[3.0000000e+01, 3.1000000e+01, 3.2000000e+01],
|
||||
[3.3000000e+01, 3.4000000e+01, 3.5000000e+01],
|
||||
[3.6000000e+01, 3.7000000e+01, 3.8000000e+01],
|
||||
[3.9000000e+01, 4.0000000e+01, 4.1000000e+01],
|
||||
[4.2000000e+01, 4.3000000e+01, 4.4000000e+01]],
|
||||
[[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
|
||||
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
|
||||
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
|
||||
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
|
||||
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38]],
|
||||
[[0.0000000e+00, 1.0000000e+00, 2.0000000e+00],
|
||||
[3.0000000e+00, 4.0000000e+00, 5.0000000e+00],
|
||||
[6.0000000e+00, 7.0000000e+00, 8.0000000e+00],
|
||||
[9.0000000e+00, 1.0000000e+01, 1.1000000e+01],
|
||||
[1.2000000e+01, 1.3000000e+01, 1.4000000e+01]]]).astype(np.float32)
|
||||
np.testing.assert_array_almost_equal(output, expect)
|
||||
|
||||
num_segments = 6
|
||||
output = net(input_x, segment_ids, num_segments).asnumpy()
|
||||
expect = np.array([[[1.5000000e+01, 1.6000000e+01, 1.7000000e+01],
|
||||
[1.8000000e+01, 1.9000000e+01, 2.0000000e+01],
|
||||
[2.1000000e+01, 2.2000000e+01, 2.3000000e+01],
|
||||
[2.4000000e+01, 2.5000000e+01, 2.6000000e+01],
|
||||
[2.7000000e+01, 2.8000000e+01, 2.9000000e+01]],
|
||||
[[3.0000000e+01, 3.1000000e+01, 3.2000000e+01],
|
||||
[3.3000000e+01, 3.4000000e+01, 3.5000000e+01],
|
||||
[3.6000000e+01, 3.7000000e+01, 3.8000000e+01],
|
||||
[3.9000000e+01, 4.0000000e+01, 4.1000000e+01],
|
||||
[4.2000000e+01, 4.3000000e+01, 4.4000000e+01]],
|
||||
[[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
|
||||
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
|
||||
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
|
||||
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
|
||||
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38]],
|
||||
[[0.0000000e+00, 1.0000000e+00, 2.0000000e+00],
|
||||
[3.0000000e+00, 4.0000000e+00, 5.0000000e+00],
|
||||
[6.0000000e+00, 7.0000000e+00, 8.0000000e+00],
|
||||
[9.0000000e+00, 1.0000000e+01, 1.1000000e+01],
|
||||
[1.2000000e+01, 1.3000000e+01, 1.4000000e+01]],
|
||||
[[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
|
||||
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
|
||||
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
|
||||
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
|
||||
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38]],
|
||||
[[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
|
||||
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
|
||||
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
|
||||
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
|
||||
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38]]]).astype(np.float32)
|
||||
np.testing.assert_array_almost_equal(output, expect)
|
Loading…
Reference in New Issue