!12767 Add float64 support to gpu gather* grad ops

From: @peilin-wang
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-03-05 02:40:09 +08:00 committed by Gitee
commit 87b71c1831
7 changed files with 75 additions and 8 deletions

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -18,6 +18,14 @@
namespace mindspore {
namespace kernel {
// Registers ScatterNd for inputs (Int32, Float64) producing a Float64 output,
// bound to ScatterNdGpuFwdKernel with type arguments (double, int).
// NOTE(review): input 0 is presumably the indices and input 1 the update values — confirm
// against the kernel header.
MS_REG_GPU_KERNEL_TWO(
ScatterNd,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ScatterNdGpuFwdKernel, double, int)
// Registers ScatterNd for inputs (Int64, Float64) producing a Float64 output,
// bound to ScatterNdGpuFwdKernel with type arguments (double, int64_t).
// NOTE(review): input 0 is presumably the indices and input 1 the update values — confirm
// against the kernel header.
MS_REG_GPU_KERNEL_TWO(
ScatterNd,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ScatterNdGpuFwdKernel, double, int64_t)
MS_REG_GPU_KERNEL_TWO(
ScatterNd,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -17,6 +17,8 @@
#include "backend/kernel_compiler/gpu/arrays/transpose_gpu_kernel.h"
namespace mindspore {
namespace kernel {
// Registers Transpose for a Float64 input and Float64 output, bound to
// TransposeGpuFwdKernel with type argument double.
MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
TransposeGpuFwdKernel, double)
// Registers Transpose for a Float32 input and Float32 output, bound to
// TransposeGpuFwdKernel with type argument float.
MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
TransposeGpuFwdKernel, float)
MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -18,6 +18,16 @@
namespace mindspore {
namespace kernel {
// Two-input registration of UnsortedSegmentSum: inputs (Float64, Int32), Float64 output,
// bound to UnsortedSegmentSumGpuKernel with type arguments (double, int).
// NOTE(review): input 1 is presumably the segment ids — confirm against the kernel header.
MS_REG_GPU_KERNEL_TWO(
UnsortedSegmentSum,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
UnsortedSegmentSumGpuKernel, double, int)
// Two-input registration of UnsortedSegmentSum: inputs (Float64, Int64), Float64 output,
// bound to UnsortedSegmentSumGpuKernel with type arguments (double, int64_t).
// NOTE(review): input 1 is presumably the segment ids — confirm against the kernel header.
MS_REG_GPU_KERNEL_TWO(
UnsortedSegmentSum,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
UnsortedSegmentSumGpuKernel, double, int64_t)
MS_REG_GPU_KERNEL_TWO(
UnsortedSegmentSum,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
@ -39,6 +49,36 @@ MS_REG_GPU_KERNEL_TWO(
UnsortedSegmentSumGpuKernel, int, int64_t)
// Re-registration with 3 inputs for dynamic shape mode: one entry per combination of Int64/Int32 num segments types
// Three-input registration: inputs (Float64, Int32, Int64), Float64 output.
// The kernel type arguments (double, int) follow the first two inputs only; the third
// input's type (Int64 here) does not change which kernel instantiation is used.
MS_REG_GPU_KERNEL_TWO(UnsortedSegmentSum,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat64),
UnsortedSegmentSumGpuKernel, double, int)
// Three-input registration: inputs (Float64, Int32, Int32), Float64 output.
// The kernel type arguments (double, int) follow the first two inputs only; the third
// input's type (Int32 here) does not change which kernel instantiation is used.
MS_REG_GPU_KERNEL_TWO(UnsortedSegmentSum,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat64),
UnsortedSegmentSumGpuKernel, double, int)
// Three-input registration: inputs (Float64, Int64, Int64), Float64 output.
// The kernel type arguments (double, int64_t) follow the first two inputs only; the third
// input's type (Int64 here) does not change which kernel instantiation is used.
MS_REG_GPU_KERNEL_TWO(UnsortedSegmentSum,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat64),
UnsortedSegmentSumGpuKernel, double, int64_t)
// Three-input registration: inputs (Float64, Int64, Int32), Float64 output.
// The kernel type arguments (double, int64_t) follow the first two inputs only; the third
// input's type (Int32 here) does not change which kernel instantiation is used.
MS_REG_GPU_KERNEL_TWO(UnsortedSegmentSum,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat64),
UnsortedSegmentSumGpuKernel, double, int64_t)
MS_REG_GPU_KERNEL_TWO(UnsortedSegmentSum,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -56,6 +56,14 @@ void ScatterNd(S *indices, T *update, T *output, const size_t &block_size, const
return;
}
// Explicit instantiation of the ScatterNd host launcher for double update/output data
// with int indices, indices_stride and work_shape.
template void ScatterNd<double, int>(int *indices, double *update, double *output, const size_t &block_size,
const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
const size_t &indices_dim_1, int *indices_stride, int *work_shape,
cudaStream_t stream);
// Explicit instantiation of the ScatterNd host launcher for double update/output data
// with int64_t indices, indices_stride and work_shape.
template void ScatterNd<double, int64_t>(int64_t *indices, double *update, double *output, const size_t &block_size,
const size_t &input_size, const size_t &output_size,
const size_t &indices_dim_0, const size_t &indices_dim_1,
int64_t *indices_stride, int64_t *work_shape, cudaStream_t stream);
template void ScatterNd<float, int>(int *indices, float *update, float *output, const size_t &block_size,
const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
const size_t &indices_dim_1, int *indices_stride, int *work_shape,

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -61,6 +61,9 @@ void CalTranspose(const size_t size, const T *input, const size_t *input_shape,
return;
}
// Explicit instantiation of CalTranspose for double input and output buffers.
template void CalTranspose<double>(const size_t size, const double *input, const size_t *input_shape,
const size_t *input_axis, const size_t shape_size, double *output,
cudaStream_t cuda_stream);
// Explicit instantiation of CalTranspose for float input and output buffers.
template void CalTranspose<float>(const size_t size, const float *input, const size_t *input_shape,
const size_t *input_axis, const size_t shape_size, float *output,
cudaStream_t cuda_stream);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -15,6 +15,7 @@
*/
#include "backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_sum.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
template<typename T, typename S>
__global__ void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1,
@ -29,7 +30,7 @@ __global__ void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t
continue;
}
size_t output_index = i * output_dim1 + k;
atomicAdd(output_addr + output_index, input_addr[input_index]);
MsAtomicAdd(output_addr + output_index, input_addr[input_index]);
}
}
@ -42,6 +43,11 @@ void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0
return;
}
// Explicit instantiation of the UnsortedSegmentSum host launcher for double input/output
// data with int ids; the template arguments are left to be deduced from the pointer types.
template void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1,
double* input_addr, int* ids_addr, double* output_addr, cudaStream_t stream);
// Explicit instantiation of the UnsortedSegmentSum host launcher for double input/output
// data with int64_t ids; the template arguments are left to be deduced from the pointer types.
template void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1,
double* input_addr, int64_t* ids_addr, double* output_addr, cudaStream_t stream);
// Explicit instantiation of the UnsortedSegmentSum host launcher for float input/output
// data with int ids; the template arguments are left to be deduced from the pointer types.
template void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1,
float* input_addr, int* ids_addr, float* output_addr, cudaStream_t stream);
template void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1,

View File

@ -247,7 +247,7 @@ AbstractBasePtr InferImplUnsortedSegmentSum(const AnalysisEnginePtr &, const Pri
MS_EXCEPTION_IF_NULL(segment_ids);
MS_EXCEPTION_IF_NULL(segment_ids->shape());
auto segment_ids_shape = segment_ids->shape()->shape();
(void)CheckTensorDType(x, {kFloat16, kFloat32, kInt32}, "Input 0 (x) for UnsortedSegmentSum should be %s");
(void)CheckTensorDType(x, {kFloat16, kFloat32, kFloat64, kInt32}, "Input 0 (x) for UnsortedSegmentSum should be %s");
(void)CheckTensorDType(segment_ids, {kInt32, kInt64}, "Input 1 (segment_ids) for UnsortedSegmentSum should be %s");
bool x_is_dyn = (!x->shape()->min_shape().empty() && !x->shape()->max_shape().empty()); // check if dynamic shape
bool ids_is_dyn = (!segment_ids->shape()->min_shape().empty() && !segment_ids->shape()->max_shape().empty());