sparse apply ftrl gpu kernel

This commit is contained in:
tom__chen 2020-10-22 07:19:57 -04:00
parent 8aaa20f918
commit 71b235c302
5 changed files with 467 additions and 0 deletions

View File

@ -0,0 +1,103 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "sparse_ftrl_impl.cuh"
#include "runtime/device/gpu/cuda_common.h"
#include "include/cuda_fp16.h"
template <typename T>
__device__ __forceinline__ T PowFunc(T x, T y) {
return pow(x, y);
}
template <>
__device__ __forceinline__ half PowFunc(half x, half y) {
return __float2half(pow(__half2float(x), __half2float(y)));
}
template <typename T>
__device__ __forceinline__ bool CompareFunc(T x, T y) {
return abs(x) > y;
}
template <>
__device__ __forceinline__ bool CompareFunc(half x, half y) {
return abs(__half2float(x)) > __half2float(y);
}
template <typename T>
__device__ __forceinline__ T Sgn(T x) {
return static_cast<T>(x != 0 ? (x > 0 ? 1 : -1) : 0);
}
template <>
__device__ __forceinline__ half Sgn(half x) {
return __float2half(__half2float(x) != 0 ? (__half2float(x) > 0 ? 1 : -1) : 0);
}
template <typename T, typename S>
__global__ void SparseApplyFtrlKernel(const T *gradient, const S *indices, const int num_index, const size_t n_stride,
const float learning_rate, const float l1_regularization,
const float l2_regularization, const float learning_rate_power,
T *variable, T *accumulation, T *linear) {
const T two = static_cast<T>(2.0);
const T learning_rate_val = static_cast<T>(learning_rate);
const T l1_regularization_val = static_cast<T>(l1_regularization);
const T l2_regularization_val = static_cast<T>(l2_regularization);
const T learning_rate_power_val = static_cast<T>(-learning_rate_power);
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x;
pos < (num_index*n_stride);
pos += gridDim.x * blockDim.x) {
const int posn = pos / n_stride;
const int posi = pos % n_stride;
const int indexed_n = indices[posn];
const int i = indexed_n*n_stride + posi;
const T cur_accumulation = accumulation[i] + gradient[pos] * gradient[pos];
const T accumulation_power = PowFunc(accumulation[i], learning_rate_power_val);
const T cur_accumulation_power = PowFunc(cur_accumulation, learning_rate_power_val);
const T sigma = (cur_accumulation_power - accumulation_power) / learning_rate_val;
linear[i] += gradient[pos] - sigma * variable[i];
variable[i] = CompareFunc(linear[i], l1_regularization_val)
? ((l1_regularization_val * Sgn(linear[i]) - linear[i]) /
(cur_accumulation_power / learning_rate_val + two * l2_regularization_val))
: static_cast<T>(0);
accumulation[i] = cur_accumulation;
}
return;
}
template <typename T, typename S>
void CalSparseApplyFtrl(const T *gradient, const S *indices, const int num_index, const size_t n_stride,
const float learning_rate, const float l1_regularization, const float l2_regularization,
const float learning_rate_power, const bool use_locking, T *variable, T *accumulation,
T *linear, cudaStream_t cuda_stream) {
SparseApplyFtrlKernel<<<GET_BLOCKS(num_index*n_stride), GET_THREADS, 0, cuda_stream>>>(gradient, indices, num_index,
n_stride, learning_rate, l1_regularization, l2_regularization, learning_rate_power, variable, accumulation, linear);
return;
}
template void CalSparseApplyFtrl<float, int>(const float *gradient, const int *indices, const int num_index,
const size_t n_stride, const float learning_rate,
const float l1_regularization, const float l2_regularization,
const float learning_rate_power, const bool use_locking, float *variable,
float *accumulation, float *linear, cudaStream_t cuda_stream);
template void CalSparseApplyFtrl<half, int>(const half *gradient, const int *indices, const int num_index,
const size_t n_stride, const float learning_rate,
const float l1_regularization, const float l2_regularization,
const float learning_rate_power, const bool use_locking, half *variable,
half *accumulation, half *linear, cudaStream_t cuda_stream);

View File

@ -0,0 +1,25 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SPARSE_FTRL_IMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SPARSE_FTRL_IMPL_H_
template <typename T, typename S>
void CalSparseApplyFtrl(const T *gradient, const S *indices, const int num_index, const size_t n_stride,
const float learning_rate, const float l1_regularization, const float l2_regularization,
const float learning_rate_power, const bool use_locking, T *variable, T *accumulation,
T *linear, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SPARSE_FTRL_IMPL_H_

View File

@ -0,0 +1,44 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/nn/sparse_ftrl_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(SparseApplyFtrl,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
SparseFtrlGpuKernel, float, int)
MS_REG_GPU_KERNEL_TWO(SparseApplyFtrl,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
SparseFtrlGpuKernel, half, int)
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,146 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SPARSE_FTRL_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SPARSE_FTRL_GPU_KERNEL_H_
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/sparse_ftrl_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T, typename S>
class SparseFtrlGpuKernel : public GpuKernel {
public:
SparseFtrlGpuKernel()
: variable_size_(0),
accumulation_size_(0),
linear_size_(0),
gradient_size_(0),
indices_size_(0),
lr_(0.0f),
l1_(0.0f),
l2_(0.0f),
lr_power_(0.0f),
use_locking_(false),
num_index_(0),
n_stride_(1) {}
~SparseFtrlGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
void *stream_ptr) override {
T *variable = GetDeviceAddress<T>(inputs, 0);
T *accumulation = GetDeviceAddress<T>(inputs, 1);
T *linear = GetDeviceAddress<T>(inputs, 2);
T *gradient = GetDeviceAddress<T>(inputs, 3);
S *indices = GetDeviceAddress<S>(inputs, 4);
CalSparseApplyFtrl(gradient, indices, num_index_, n_stride_, lr_, l1_, l2_, lr_power_, use_locking_, variable,
accumulation, linear, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 5) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but sparse ftrl needs 5 inputs.";
return false;
}
variable_size_ = sizeof(T);
accumulation_size_ = sizeof(T);
linear_size_ = sizeof(T);
gradient_size_ = sizeof(T);
indices_size_ = sizeof(S);
auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
for (size_t i = 0; i < variable_shape.size(); i++) {
variable_size_ *= variable_shape[i];
if (i > 0) {
n_stride_ *= variable_shape[i];
}
}
auto accumulation_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
for (size_t i = 0; i < accumulation_shape.size(); i++) {
accumulation_size_ *= accumulation_shape[i];
}
auto linear_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
for (size_t i = 0; i < linear_shape.size(); i++) {
linear_size_ *= linear_shape[i];
}
auto gradient_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3);
for (size_t i = 0; i < gradient_shape.size(); i++) {
gradient_size_ *= gradient_shape[i];
}
auto indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 4);
for (size_t i = 0; i < indices_shape.size(); i++) {
indices_size_ *= indices_shape[i];
}
lr_ = GetAttr<float>(kernel_node, "lr");
l1_ = GetAttr<float>(kernel_node, "l1");
l2_ = GetAttr<float>(kernel_node, "l2");
lr_power_ = GetAttr<float>(kernel_node, "lr_power");
use_locking_ = GetAttr<bool>(kernel_node, "use_locking");
num_index_ = indices_shape[0];
InitSizeLists();
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(variable_size_);
input_size_list_.push_back(accumulation_size_);
input_size_list_.push_back(linear_size_);
input_size_list_.push_back(gradient_size_);
input_size_list_.push_back(indices_size_);
output_size_list_.push_back(0);
output_size_list_.push_back(0);
output_size_list_.push_back(0);
}
private:
size_t variable_size_;
size_t accumulation_size_;
size_t linear_size_;
size_t gradient_size_;
size_t indices_size_;
float lr_;
float l1_;
float l2_;
float lr_power_;
bool use_locking_;
int num_index_;
size_t n_stride_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SPARSE_FTRL_GPU_KERNEL_H_

View File

@ -0,0 +1,149 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.parameter import Parameter
from mindspore.ops import operations as P
import mindspore.common.dtype as mstype
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.sparse_apply_ftrl = P.SparseApplyFtrl(lr=0.001, l1=0.0, l2=0.0, lr_power=-0.5, use_locking=False)
self.var = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="var")
self.accum = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="accum")
self.linear = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="linear")
def construct(self, grad, indices):
out = self.sparse_apply_ftrl(self.var, self.accum, self.linear, grad, indices)
return out
class Net_half(nn.Cell):
def __init__(self):
super(Net_half, self).__init__()
self.sparse_apply_ftrl = P.SparseApplyFtrl(lr=0.001, l1=0.0, l2=0.0, lr_power=-0.5, use_locking=False)
self.var = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float16)), name="var")
self.accum = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float16)), name="accum")
self.linear = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float16)), name="linear")
def construct(self, grad, indices):
out = self.sparse_apply_ftrl(self.var, self.accum, self.linear, grad, indices)
return out
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_ftrl():
gradient = Tensor(np.ones([3, 3, 3]).astype(np.float32))
indices = Tensor([0, 1, 2], mstype.int32)
expect_var = np.array([[[0.291479, 0.291479, 0.291479],
[0.291479, 0.291479, 0.291479],
[0.291479, 0.291479, 0.291479]],
[[0.291479, 0.291479, 0.291479],
[0.291479, 0.291479, 0.291479],
[0.291479, 0.291479, 0.291479]],
[[0.291479, 0.291479, 0.291479],
[0.291479, 0.291479, 0.291479],
[0.291479, 0.291479, 0.291479]]]).astype(np.float32)
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
sparse_apply_ftrl = Net()
sparse_apply_ftrl(gradient, indices)
assert np.all(sparse_apply_ftrl.var.data.asnumpy() == expect_var)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
sparse_apply_ftrl = Net()
sparse_apply_ftrl(gradient, indices)
assert np.all(sparse_apply_ftrl.var.data.asnumpy() == expect_var)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_ftrl_sparse():
gradient = Tensor(np.ones([2, 3, 3]).astype(np.float32))
indices = Tensor([0, 2], mstype.int32)
expect_var = np.array([[[0.291479, 0.291479, 0.291479],
[0.291479, 0.291479, 0.291479],
[0.291479, 0.291479, 0.291479]],
[[1, 1, 1],
[1, 1, 1],
[1, 1, 1]],
[[0.291479, 0.291479, 0.291479],
[0.291479, 0.291479, 0.291479],
[0.291479, 0.291479, 0.291479]]]).astype(np.float32)
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
sparse_apply_ftrl = Net()
sparse_apply_ftrl(gradient, indices)
assert np.all(sparse_apply_ftrl.var.data.asnumpy() == expect_var)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
sparse_apply_ftrl = Net()
sparse_apply_ftrl(gradient, indices)
assert np.all(sparse_apply_ftrl.var.data.asnumpy() == expect_var)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_ftrl_half():
gradient = Tensor(np.ones([3, 3, 3]).astype(np.float16))
indices = Tensor([0, 1, 2], mstype.int32)
expect_var = np.array([[[0.291479, 0.291479, 0.291479],
[0.291479, 0.291479, 0.291479],
[0.291479, 0.291479, 0.291479]],
[[0.291479, 0.291479, 0.291479],
[0.291479, 0.291479, 0.291479],
[0.291479, 0.291479, 0.291479]],
[[0.291479, 0.291479, 0.291479],
[0.291479, 0.291479, 0.291479],
[0.291479, 0.291479, 0.291479]]]).astype(np.float16)
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
sparse_apply_ftrl = Net_half()
sparse_apply_ftrl(gradient, indices)
assert np.all(sparse_apply_ftrl.var.data.asnumpy() == expect_var)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
sparse_apply_ftrl = Net_half()
sparse_apply_ftrl(gradient, indices)
assert np.all(sparse_apply_ftrl.var.data.asnumpy() == expect_var)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_ftrl_sparse_half():
gradient = Tensor(np.ones([2, 3, 3]).astype(np.float16))
indices = Tensor([0, 2], mstype.int32)
expect_var = np.array([[[0.291479, 0.291479, 0.291479],
[0.291479, 0.291479, 0.291479],
[0.291479, 0.291479, 0.291479]],
[[1, 1, 1],
[1, 1, 1],
[1, 1, 1]],
[[0.291479, 0.291479, 0.291479],
[0.291479, 0.291479, 0.291479],
[0.291479, 0.291479, 0.291479]]]).astype(np.float16)
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
sparse_apply_ftrl = Net_half()
sparse_apply_ftrl(gradient, indices)
assert np.all(sparse_apply_ftrl.var.data.asnumpy() == expect_var)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
sparse_apply_ftrl = Net_half()
sparse_apply_ftrl(gradient, indices)
assert np.all(sparse_apply_ftrl.var.data.asnumpy() == expect_var)