!9620 [MS][GPU] UnsortedSegmentMax_int64_Support
From: @danishnxt Reviewed-by: @robingrosman,@tom__chen Signed-off-by:
This commit is contained in:
commit
d68708960e
|
@ -18,60 +18,131 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
UnsortedSegmentMax,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
|
||||
UnsortedSegmentMaxGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
UnsortedSegmentMaxGpuKernel, float, int)
|
||||
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
UnsortedSegmentMax,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
|
||||
UnsortedSegmentMaxGpuKernel, float, int64_t)
|
||||
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
UnsortedSegmentMax,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
|
||||
UnsortedSegmentMaxGpuKernel, half)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
UnsortedSegmentMaxGpuKernel, half, int)
|
||||
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
UnsortedSegmentMax,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
|
||||
UnsortedSegmentMaxGpuKernel, half, int64_t)
|
||||
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
UnsortedSegmentMax,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
UnsortedSegmentMaxGpuKernel, int)
|
||||
UnsortedSegmentMaxGpuKernel, int, int)
|
||||
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
UnsortedSegmentMax,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
|
||||
UnsortedSegmentMaxGpuKernel, int, int64_t)
|
||||
|
||||
// Dynamic Mode - registered for int32/int64 3rd input
|
||||
MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax,
|
||||
MS_REG_GPU_KERNEL_TWO(UnsortedSegmentMax,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
UnsortedSegmentMaxGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax,
|
||||
UnsortedSegmentMaxGpuKernel, float, int)
|
||||
|
||||
MS_REG_GPU_KERNEL_TWO(UnsortedSegmentMax,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
UnsortedSegmentMaxGpuKernel, float, int64_t)
|
||||
|
||||
MS_REG_GPU_KERNEL_TWO(UnsortedSegmentMax,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
UnsortedSegmentMaxGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax,
|
||||
UnsortedSegmentMaxGpuKernel, float, int)
|
||||
|
||||
MS_REG_GPU_KERNEL_TWO(UnsortedSegmentMax,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
UnsortedSegmentMaxGpuKernel, float, int64_t)
|
||||
|
||||
MS_REG_GPU_KERNEL_TWO(UnsortedSegmentMax,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
UnsortedSegmentMaxGpuKernel, half)
|
||||
MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax,
|
||||
UnsortedSegmentMaxGpuKernel, half, int)
|
||||
|
||||
MS_REG_GPU_KERNEL_TWO(UnsortedSegmentMax,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
UnsortedSegmentMaxGpuKernel, half, int64_t)
|
||||
|
||||
MS_REG_GPU_KERNEL_TWO(UnsortedSegmentMax,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
UnsortedSegmentMaxGpuKernel, half)
|
||||
MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax,
|
||||
UnsortedSegmentMaxGpuKernel, half, int)
|
||||
|
||||
MS_REG_GPU_KERNEL_TWO(UnsortedSegmentMax,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
UnsortedSegmentMaxGpuKernel, half, int64_t)
|
||||
|
||||
MS_REG_GPU_KERNEL_TWO(UnsortedSegmentMax,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
UnsortedSegmentMaxGpuKernel, int)
|
||||
MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax,
|
||||
UnsortedSegmentMaxGpuKernel, int, int)
|
||||
|
||||
MS_REG_GPU_KERNEL_TWO(UnsortedSegmentMax,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
UnsortedSegmentMaxGpuKernel, int, int64_t)
|
||||
|
||||
MS_REG_GPU_KERNEL_TWO(UnsortedSegmentMax,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
UnsortedSegmentMaxGpuKernel, int)
|
||||
UnsortedSegmentMaxGpuKernel, int, int)
|
||||
|
||||
MS_REG_GPU_KERNEL_TWO(UnsortedSegmentMax,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
UnsortedSegmentMaxGpuKernel, int, int64_t)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -25,7 +25,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
template <typename T, typename S>
|
||||
class UnsortedSegmentMaxGpuKernel : public GpuKernel {
|
||||
public:
|
||||
UnsortedSegmentMaxGpuKernel() { ResetResource(); }
|
||||
|
@ -41,7 +41,7 @@ class UnsortedSegmentMaxGpuKernel : public GpuKernel {
|
|||
return true;
|
||||
}
|
||||
T *input_addr = GetDeviceAddress<T>(inputs, 0);
|
||||
int *indices_addr = GetDeviceAddress<int>(inputs, 1);
|
||||
S *indices_addr = GetDeviceAddress<S>(inputs, 1);
|
||||
T *output_addr = GetDeviceAddress<T>(outputs, 0);
|
||||
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
|
||||
|
|
|
@ -17,21 +17,21 @@
|
|||
#include "backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_max.cuh"
|
||||
#include <limits>
|
||||
|
||||
template <typename T>
|
||||
__global__ void UnsortedSegmentMax(const T *input, const int *segment_ids, const int64_t num_segments,
|
||||
size_t outer_size, size_t inner_size, bool fp16_flag, T init_K, T *output) {
|
||||
template <typename T, typename S>
|
||||
__global__ void UnsortedSegmentMax(const T *input, const S *segment_ids, const int64_t num_segments, size_t outer_size,
|
||||
size_t inner_size, bool fp16_flag, T init_K, T *output) {
|
||||
if (fp16_flag) {
|
||||
init_K = __int2half_rd(-65504); // min value representable by float16
|
||||
}
|
||||
|
||||
for (int t_idx = blockIdx.x * blockDim.x + threadIdx.x; t_idx < KWARPSIZE * num_segments * inner_size;
|
||||
for (size_t t_idx = blockIdx.x * blockDim.x + threadIdx.x; t_idx < KWARPSIZE * num_segments * inner_size;
|
||||
t_idx += blockDim.x * gridDim.x) {
|
||||
int segment_id = t_idx / KWARPSIZE / inner_size;
|
||||
int inner_id = t_idx / KWARPSIZE % inner_size;
|
||||
int lane_id = threadIdx.x % KWARPSIZE;
|
||||
size_t segment_id = t_idx / KWARPSIZE / inner_size;
|
||||
size_t inner_id = t_idx / KWARPSIZE % inner_size;
|
||||
size_t lane_id = threadIdx.x % KWARPSIZE;
|
||||
T threadK = init_K;
|
||||
|
||||
for (int i = lane_id; i < outer_size; i += KWARPSIZE) {
|
||||
for (size_t i = lane_id; i < outer_size; i += KWARPSIZE) {
|
||||
if (segment_ids[i] != segment_id) continue;
|
||||
T other_K = input[i * inner_size + inner_id];
|
||||
if (threadK < other_K) {
|
||||
|
@ -40,7 +40,7 @@ __global__ void UnsortedSegmentMax(const T *input, const int *segment_ids, const
|
|||
}
|
||||
__syncwarp();
|
||||
|
||||
for (int offset = KWARPSIZE / 2; offset > 0; offset /= 2) {
|
||||
for (size_t offset = KWARPSIZE / 2; offset > 0; offset /= 2) {
|
||||
T other_K = __shfl_down_sync(0xffffffff, threadK, offset);
|
||||
if (threadK < other_K) {
|
||||
threadK = other_K;
|
||||
|
@ -56,10 +56,10 @@ __global__ void UnsortedSegmentMax(const T *input, const int *segment_ids, const
|
|||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CalUnsortedSegmentMax(const T *input, const int *segment_ids, const int64_t num_segments, size_t outer_size,
|
||||
template <typename T, typename S>
|
||||
void CalUnsortedSegmentMax(const T *input, const S *segment_ids, const int64_t num_segments, size_t outer_size,
|
||||
size_t inner_size, T *output, cudaStream_t stream) {
|
||||
int size = (inner_size * KWARPSIZE * num_segments);
|
||||
size_t size = (inner_size * KWARPSIZE * num_segments);
|
||||
bool fp16_flag = false;
|
||||
// handle fp16 min value
|
||||
if (std::is_same<T, half>::value) {
|
||||
|
@ -71,9 +71,19 @@ void CalUnsortedSegmentMax(const T *input, const int *segment_ids, const int64_t
|
|||
return;
|
||||
}
|
||||
|
||||
template void CalUnsortedSegmentMax<float>(const float *input, const int *segment_ids, const int64_t num_segments,
|
||||
size_t outer_size, size_t inner_size, float *output, cudaStream_t stream);
|
||||
template void CalUnsortedSegmentMax<half>(const half *input, const int *segment_ids, const int64_t num_segments,
|
||||
size_t outer_size, size_t inner_size, half *output, cudaStream_t stream);
|
||||
template void CalUnsortedSegmentMax<int>(const int *input, const int *segment_ids, const int64_t num_segments,
|
||||
size_t outer_size, size_t inner_size, int *output, cudaStream_t stream);
|
||||
template void CalUnsortedSegmentMax<float, int>(const float *input, const int *segment_ids, const int64_t num_segments,
|
||||
size_t outer_size, size_t inner_size, float *output,
|
||||
cudaStream_t stream);
|
||||
template void CalUnsortedSegmentMax<float, int64_t>(const float *input, const int64_t *segment_ids,
|
||||
const int64_t num_segments, size_t outer_size, size_t inner_size,
|
||||
float *output, cudaStream_t stream);
|
||||
template void CalUnsortedSegmentMax<half, int>(const half *input, const int *segment_ids, const int64_t num_segments,
|
||||
size_t outer_size, size_t inner_size, half *output, cudaStream_t stream);
|
||||
template void CalUnsortedSegmentMax<half, int64_t>(const half *input, const int64_t *segment_ids,
|
||||
const int64_t num_segments, size_t outer_size, size_t inner_size,
|
||||
half *output, cudaStream_t stream);
|
||||
template void CalUnsortedSegmentMax<int, int>(const int *input, const int *segment_ids, const int64_t num_segments,
|
||||
size_t outer_size, size_t inner_size, int *output, cudaStream_t stream);
|
||||
template void CalUnsortedSegmentMax<int, int64_t>(const int *input, const int64_t *segment_ids,
|
||||
const int64_t num_segments, size_t outer_size, size_t inner_size,
|
||||
int *output, cudaStream_t stream);
|
||||
|
|
|
@ -22,8 +22,8 @@
|
|||
|
||||
// Setting warp size to sync data across threads
|
||||
#define KWARPSIZE 32
|
||||
template <typename T>
|
||||
void CalUnsortedSegmentMax(const T *input, const int *segment_ids, const int64_t num_segments, size_t outer_size,
|
||||
template <typename T, typename S>
|
||||
void CalUnsortedSegmentMax(const T *input, const S *segment_ids, const int64_t num_segments, size_t outer_size,
|
||||
size_t inner_size, T *output, cudaStream_t stream);
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNSORT_SEGMENT_MAX_H_
|
||||
|
|
|
@ -306,7 +306,7 @@ AbstractBasePtr InferImplUnsortedSegmentMax(const AnalysisEnginePtr &, const Pri
|
|||
MS_EXCEPTION_IF_NULL(segment_ids->shape());
|
||||
auto segment_ids_shape = segment_ids->shape()->shape();
|
||||
(void)CheckTensorDType(x, {kFloat16, kFloat32, kInt32}, "Input 0 (x) for UnsortedSegmentMax should be %s");
|
||||
(void)CheckTensorDType(segment_ids, {kInt32}, "Input 1 (segment_ids) for UnsortedSegmentMax should be %s");
|
||||
(void)CheckTensorDType(segment_ids, {kInt32, kInt64}, "Input 1 (segment_ids) for UnsortedSegmentMax should be %s");
|
||||
// check if dynamic shape
|
||||
bool x_is_dyn = (!x->shape()->min_shape().empty() && !x->shape()->max_shape().empty());
|
||||
bool ids_is_dyn = (!segment_ids->shape()->min_shape().empty() && !segment_ids->shape()->max_shape().empty());
|
||||
|
|
|
@ -2001,7 +2001,8 @@ class UnsortedSegmentMax(PrimitiveWithCheck):
|
|||
segment_ids_shape = segment_ids['shape']
|
||||
valid_type = [mstype.float16, mstype.float32, mstype.int32]
|
||||
validator.check_tensor_dtype_valid("x", x['dtype'], valid_type, self.name)
|
||||
validator.check_tensors_dtypes_same_and_valid({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name)
|
||||
validator.check_tensors_dtypes_same_and_valid({"segment_ids": segment_ids['dtype']},
|
||||
[mstype.int32, mstype.int64], self.name)
|
||||
validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name)
|
||||
num_segments_type = num_segments['dtype']
|
||||
validator.check_subclass("num_segments", num_segments_type, [mstype.tensor, mstype.number], self.name)
|
||||
|
|
|
@ -71,12 +71,12 @@ def test_2d_int32():
|
|||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_3d_float16():
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
|
||||
def test_3d_float16_int64():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
|
||||
input_x = Tensor(np.arange(
|
||||
4 * 5 * 3, dtype=np.float16).reshape(4, 5, 3), dtype=mindspore.float16)
|
||||
segment_ids = Tensor([2, 1, 1, -1], mstype.int32)
|
||||
num_segments = 5
|
||||
segment_ids = Tensor([2, 1, 1, -1], mstype.int64)
|
||||
num_segments = Tensor(5, dtype=mstype.int64)
|
||||
net = UnsortedSegmentMaxNet(num_segments)
|
||||
output = net(input_x, segment_ids).asnumpy()
|
||||
expect = np.array([[[-6.55e+04, -6.55e+04, -6.55e+04],
|
||||
|
@ -110,12 +110,12 @@ def test_3d_float16():
|
|||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_3d_float32():
|
||||
def test_3d_float32_int64():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
|
||||
input_x = Tensor(np.arange(
|
||||
4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3), dtype=mindspore.float32)
|
||||
segment_ids = Tensor([2, 1, 1, -1], mstype.int32)
|
||||
num_segments = 3
|
||||
segment_ids = Tensor([2, 1, 1, -1], mstype.int64)
|
||||
num_segments = Tensor(3, dtype=mstype.int64)
|
||||
net = UnsortedSegmentMaxNet(num_segments)
|
||||
output = net(input_x, segment_ids).asnumpy()
|
||||
expect = np.array([[[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
|
||||
|
|
Loading…
Reference in New Issue