add ftrl optimizer

This commit is contained in:
lizhenyu 2020-06-17 16:35:43 +08:00
parent 445122f520
commit c3360a84cd
5 changed files with 367 additions and 0 deletions

View File

@ -0,0 +1,87 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/gpu/cuda_impl/ftrl_impl.cuh"
template <typename T>
__device__ __forceinline__ T PowFunc(T x, T y) {
return pow(x, y);
}
template <>
__device__ __forceinline__ half PowFunc(half x, half y) {
return __float2half(pow(__half2float(x), __half2float(y)));
}
template <typename T>
__device__ __forceinline__ bool CompareFunc(T x, T y) {
return abs(x) > y;
}
template <>
__device__ __forceinline__ bool CompareFunc(half x, half y) {
return abs(__half2float(x)) > __half2float(y);
}
template <typename T>
__device__ __forceinline__ T Sgn(T x) {
return static_cast<T>(x != 0 ? (x > 0 ? 1 : -1) : 0);
}
template <>
__device__ __forceinline__ half Sgn(half x) {
return __float2half(__half2float(x) != 0 ? (__half2float(x) > 0 ? 1 : -1) : 0);
}
template <typename T>
__global__ void ApplyFtrlKernel(const size_t size, const T *gradient, const T *learning_rate,
const T *l1_regularization, const T *l2_regularization, const T *learning_rate_power,
T *variable, T *accumulation, T *linear) {
const T two = static_cast<T>(2.0);
const T learning_rate_power_val = -learning_rate_power[0];
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
const T cur_accumulation = accumulation[i] + gradient[i] * gradient[i];
const T accumulation_power = PowFunc(accumulation[i], learning_rate_power_val);
const T cur_accumulation_power = PowFunc(cur_accumulation, learning_rate_power_val);
const T sigma = (cur_accumulation_power - accumulation_power) / learning_rate[0];
linear[i] += gradient[i] - sigma * variable[i];
variable[i] = CompareFunc(linear[i], l1_regularization[0])
? ((l1_regularization[0] * Sgn(linear[i]) - linear[i]) /
(cur_accumulation_power / learning_rate[0] + two * l2_regularization[0]))
: static_cast<T>(0);
accumulation[i] = cur_accumulation;
}
}
template <typename T>
void ApplyFtrl(const size_t size, const T *gradient, const T *learning_rate, const T *l1_regularization,
const T *l2_regularization, const T *learning_rate_power, T *variable, T *accumulation, T *linear,
cudaStream_t cuda_stream) {
ApplyFtrlKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, gradient, learning_rate, l1_regularization,
l2_regularization, learning_rate_power, variable,
accumulation, linear);
}
template void ApplyFtrl<float>(const size_t size, const float *gradient, const float *learning_rate,
const float *l1_regularization, const float *l2_regularization,
const float *learning_rate_power, float *variable, float *accumulation, float *linear,
cudaStream_t cuda_stream);
template void ApplyFtrl<half>(const size_t size, const half *gradient, const half *learning_rate,
const half *l1_regularization, const half *l2_regularization,
const half *learning_rate_power, half *variable, half *accumulation, half *linear,
cudaStream_t cuda_stream);

View File

@ -0,0 +1,26 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FTRL_IMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FTRL_IMPL_H_
#include "device/gpu/cuda_common.h"
template <typename T>
void ApplyFtrl(const size_t size, const T *gradient, const T *learning_rate, const T *l1_regularization,
const T *l2_regularization, const T *learning_rate_power, T *variable, T *accumulation, T *linear,
cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FTRL_IMPL_H_

View File

@ -0,0 +1,46 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/gpu/nn/ftrl_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(ApplyFtrl,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
FtrlGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(ApplyFtrl,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
FtrlGpuKernel, half)
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,130 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_FTRL_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_NN_FTRL_GPU_KERNEL_H_
#include <vector>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
#include "kernel/gpu/cuda_impl/ftrl_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class FtrlGpuKernel : public GpuKernel {
public:
FtrlGpuKernel()
: variable_size_(0),
accumulation_size_(0),
linear_size_(0),
gradient_size_(0),
learning_rate_size_(0),
l1_regularization_size_(0),
l2_regularization_size_(0),
learning_rate_power_size_(0) {}
~FtrlGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
void *stream_ptr) override {
T *variable = GetDeviceAddress<T>(inputs, 0);
T *accumulation = GetDeviceAddress<T>(inputs, 1);
T *linear = GetDeviceAddress<T>(inputs, 2);
T *gradient = GetDeviceAddress<T>(inputs, 3);
T *learning_rate = GetDeviceAddress<T>(inputs, 4);
T *l1_regularization = GetDeviceAddress<T>(inputs, 5);
T *l2_regularization = GetDeviceAddress<T>(inputs, 6);
T *learning_rate_power = GetDeviceAddress<T>(inputs, 7);
ApplyFtrl(inputs[0]->size / sizeof(T), gradient, learning_rate, l1_regularization, l2_regularization,
learning_rate_power, variable, accumulation, linear, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 8) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but ftrl needs 8 inputs.";
return false;
}
variable_size_ = sizeof(T);
accumulation_size_ = sizeof(T);
linear_size_ = sizeof(T);
gradient_size_ = sizeof(T);
learning_rate_size_ = sizeof(T);
l1_regularization_size_ = sizeof(T);
l2_regularization_size_ = sizeof(T);
learning_rate_power_size_ = sizeof(T);
auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
for (size_t i = 0; i < variable_shape.size(); i++) {
variable_size_ *= variable_shape[i];
}
auto accumulation_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
for (size_t i = 0; i < accumulation_shape.size(); i++) {
accumulation_size_ *= accumulation_shape[i];
}
auto linear_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
for (size_t i = 0; i < linear_shape.size(); i++) {
linear_size_ *= linear_shape[i];
}
auto gradient_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3);
for (size_t i = 0; i < gradient_shape.size(); i++) {
gradient_size_ *= gradient_shape[i];
}
InitSizeLists();
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(variable_size_);
input_size_list_.push_back(accumulation_size_);
input_size_list_.push_back(linear_size_);
input_size_list_.push_back(gradient_size_);
input_size_list_.push_back(learning_rate_size_);
input_size_list_.push_back(l1_regularization_size_);
input_size_list_.push_back(l2_regularization_size_);
input_size_list_.push_back(learning_rate_power_size_);
output_size_list_.push_back(0);
}
private:
size_t variable_size_;
size_t accumulation_size_;
size_t linear_size_;
size_t gradient_size_;
size_t learning_rate_size_;
size_t l1_regularization_size_;
size_t l2_regularization_size_;
size_t learning_rate_power_size_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_FTRL_GPU_KERNEL_H_

View File

@ -0,0 +1,78 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.nn import Dense
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import FTRL
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
class NetFtrl(nn.Cell):
def __init__(self):
super(NetFtrl, self).__init__()
self.batch_size = 1
self.reshape = P.Reshape()
weight = Tensor(np.ones([10, 16]).astype(np.float32) * 0.01)
self.fc1 = Dense(16, 10, weight_init=weight)
def construct(self, input_x):
output = self.reshape(input_x, (self.batch_size, -1))
output = self.fc1(output)
return output
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_ftrl():
epoch = 3
net = NetFtrl()
optimizer = FTRL(filter(lambda x: x.requires_grad,
net.get_parameters()), learning_rate=0.01)
criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
net_with_criterion = WithLossCell(net, criterion)
train_network = TrainOneStepCell(
net_with_criterion, optimizer)
train_network.set_train()
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
losses1 = []
for _ in range(epoch):
data = Tensor(np.arange(0, 16).reshape(
1, 1, 4, 4).astype(np.float32) * 0.01)
label = Tensor(np.array([0]).astype(np.int32))
loss = train_network(data, label)
losses1.append(loss.asnumpy())
assert losses1[0] > losses1[1]
assert losses1[1] > losses1[2]
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
losses2 = []
for _ in range(epoch):
data = Tensor(np.arange(0, 16).reshape(
1, 1, 4, 4).astype(np.float32) * 0.01)
label = Tensor(np.array([0]).astype(np.int32))
loss = train_network(data, label)
losses2.append(loss.asnumpy())
assert losses2[0] > losses2[1]
assert losses2[1] > losses2[2]