!9857 [MS][GPU] SparseApplyFtrl_int64_Support
From: @danishnxt Reviewed-by: @tom__chen,@robingrosman Signed-off-by: @robingrosman
This commit is contained in:
commit
904f61b7fd
|
@@ -96,8 +96,19 @@ template void CalSparseApplyFtrl<float, int>(const float *gradient, const int *i
|
|||
const float l1_regularization, const float l2_regularization,
|
||||
const float learning_rate_power, const bool use_locking, float *variable,
|
||||
float *accumulation, float *linear, cudaStream_t cuda_stream);
|
||||
template void CalSparseApplyFtrl<float, int64_t>(const float *gradient, const int64_t *indices, const int num_index,
|
||||
const size_t n_stride, const float learning_rate,
|
||||
const float l1_regularization, const float l2_regularization,
|
||||
const float learning_rate_power, const bool use_locking, float *variable,
|
||||
float *accumulation, float *linear, cudaStream_t cuda_stream);
|
||||
template void CalSparseApplyFtrl<half, int>(const half *gradient, const int *indices, const int num_index,
|
||||
const size_t n_stride, const float learning_rate,
|
||||
const float l1_regularization, const float l2_regularization,
|
||||
const float learning_rate_power, const bool use_locking, half *variable,
|
||||
half *accumulation, half *linear, cudaStream_t cuda_stream);
|
||||
template void CalSparseApplyFtrl<half, int64_t>(const half *gradient, const int64_t *indices, const int num_index,
|
||||
const size_t n_stride, const float learning_rate,
|
||||
const float l1_regularization, const float l2_regularization,
|
||||
const float learning_rate_power, const bool use_locking, half *variable,
|
||||
half *accumulation, half *linear, cudaStream_t cuda_stream);
|
||||
|
||||
|
|
|
@@ -29,6 +29,17 @@ MS_REG_GPU_KERNEL_TWO(SparseApplyFtrl,
|
|||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
SparseFtrlGpuKernel, float, int)
|
||||
MS_REG_GPU_KERNEL_TWO(SparseApplyFtrl,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
SparseFtrlGpuKernel, float, int64_t)
|
||||
MS_REG_GPU_KERNEL_TWO(SparseApplyFtrl,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
|
@@ -40,5 +51,16 @@ MS_REG_GPU_KERNEL_TWO(SparseApplyFtrl,
|
|||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
SparseFtrlGpuKernel, half, int)
|
||||
MS_REG_GPU_KERNEL_TWO(SparseApplyFtrl,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
SparseFtrlGpuKernel, half, int64_t)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@@ -77,9 +77,9 @@ def test_ftrl():
|
|||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_ftrl_sparse():
|
||||
def test_ftrl_sparse_int64_ind():
|
||||
gradient = Tensor(np.ones([2, 3, 3]).astype(np.float32))
|
||||
indices = Tensor([0, 2], mstype.int32)
|
||||
indices = Tensor([0, 2], mstype.int64)
|
||||
expect_var = np.array([[[0.291479, 0.291479, 0.291479],
|
||||
[0.291479, 0.291479, 0.291479],
|
||||
[0.291479, 0.291479, 0.291479]],
|
||||
|
@@ -127,9 +127,9 @@ def test_ftrl_half():
|
|||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_ftrl_sparse_half():
|
||||
def test_ftrl_sparse_half_int64_ind():
|
||||
gradient = Tensor(np.ones([2, 3, 3]).astype(np.float16))
|
||||
indices = Tensor([0, 2], mstype.int32)
|
||||
indices = Tensor([0, 2], mstype.int64)
|
||||
expect_var = np.array([[[0.291479, 0.291479, 0.291479],
|
||||
[0.291479, 0.291479, 0.291479],
|
||||
[0.291479, 0.291479, 0.291479]],
|
||||
|
|
Loading…
Reference in New Issue