forked from mindspore-Ecosystem/mindspore
!322 Gpu Support RMSProp kernel
Merge pull request !322 from chenweifeng/rmsprop
commit bda4ebd591

@@ -0,0 +1,68 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <iostream>
#include "kernel/gpu/cuda_impl/rmsprop_impl.cuh"
#include "device/gpu/cuda_common.h"

template <typename T>
__global__ void RmsPropKernel(const T* learning_rate, const T* decay, const T* momentum, const T* epsilon, T* variable,
                              T* mean_square, T* moment, T* gradients, const size_t size) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) {
    mean_square[i] = decay[0] * mean_square[i] + (1.0 - decay[0]) * gradients[i] * gradients[i];
    moment[i] = momentum[0] * moment[i] + learning_rate[0] * rsqrt(mean_square[i] + epsilon[0]) * gradients[i];
    variable[i] -= moment[i];
  }
}

template <typename T>
void RmsProp(const T* learning_rate, const T* decay, const T* momentum, const T* epsilon,
             T* variable, T* mean_square, T* moment, T* gradients, const size_t size, cudaStream_t cuda_stream) {
  RmsPropKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(learning_rate, decay, momentum, epsilon,
                                                                   variable, mean_square, moment, gradients, size);
}

template <typename T>
__global__ void RmsPropCenterKernel(const T* learning_rate, const T* decay, const T* momentum, const T* epsilon,
                                    T* variable, T* mean_gradients, T* mean_square, T* moment, T* gradients,
                                    const size_t size) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) {
    mean_gradients[i] = decay[0] * mean_gradients[i] + (1.0 - decay[0]) * gradients[i];
    mean_square[i] = decay[0] * mean_square[i] + (1.0 - decay[0]) * gradients[i] * gradients[i];
    moment[i] = momentum[0] * moment[i] + learning_rate[0] *
                rsqrt(mean_square[i] - mean_gradients[i] * mean_gradients[i] + epsilon[0]) * gradients[i];
    variable[i] -= moment[i];
  }
}

template <typename T>
void RmsPropCenter(const T* learning_rate, const T* decay, const T* momentum, const T* epsilon, T* variable,
                   T* mean_gradients, T* mean_square, T* moment, T* gradients, const size_t size,
                   cudaStream_t cuda_stream) {
  RmsPropCenterKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(learning_rate, decay, momentum, epsilon,
                                                                         variable, mean_gradients, mean_square,
                                                                         moment, gradients, size);
}

template
void RmsProp(const float* learning_rate, const float* decay, const float* momentum, const float* epsilon,
             float* variable, float* mean_square, float* moment, float* gradients, const size_t size,
             cudaStream_t cuda_stream);

template
void RmsPropCenter(const float* learning_rate, const float* decay, const float* momentum, const float* epsilon,
                   float* variable, float* mean_gradients, float* mean_square, float* moment, float* gradients,
                   const size_t size, cudaStream_t cuda_stream);
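For reference, the per-element update computed by the two kernels above can be written as follows, using $w$ for variable, $s$ for mean_square, $m$ for moment, $g$ for gradients, $\bar{g}$ for mean_gradients, and $\alpha, \rho, \mu, \varepsilon$ for learning_rate, decay, momentum, epsilon (the symbols are editorial shorthand, not part of the PR).

RmsPropKernel:
$$s_i \leftarrow \rho\, s_i + (1-\rho)\, g_i^2, \qquad m_i \leftarrow \mu\, m_i + \frac{\alpha\, g_i}{\sqrt{s_i + \varepsilon}}, \qquad w_i \leftarrow w_i - m_i$$

RmsPropCenterKernel additionally maintains a running mean of the gradients and subtracts its square inside the root:
$$\bar{g}_i \leftarrow \rho\, \bar{g}_i + (1-\rho)\, g_i, \qquad m_i \leftarrow \mu\, m_i + \frac{\alpha\, g_i}{\sqrt{s_i - \bar{g}_i^2 + \varepsilon}}, \qquad w_i \leftarrow w_i - m_i$$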
@@ -0,0 +1,30 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RMSPROP_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RMSPROP_H_
#include "device/gpu/cuda_common.h"

template <typename T>
void RmsProp(const T* learning_rate, const T* decay, const T* momentum, const T* epsilon, T* variable, T* mean_square,
             T* moment, T* gradients, const size_t size, cudaStream_t cuda_stream);

template <typename T>
void RmsPropCenter(const T* learning_rate, const T* decay, const T* momentum, const T* epsilon, T* variable,
                   T* mean_gradients, T* mean_square, T* moment, T* gradients, const size_t size,
                   cudaStream_t cuda_stream);

#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RMSPROP_H_
@@ -0,0 +1,49 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "kernel/gpu/nn/rmsprop_gpu_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(ApplyRMSProp,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32),
                      RMSPropGpuKernel, float)

MS_REG_GPU_KERNEL_ONE(ApplyCenteredRMSProp,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32),
                      RMSPropGpuKernel, float)

} // namespace kernel
} // namespace mindspore
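These two registrations are what route P.ApplyRMSProp (eight float32 inputs) and P.ApplyCenteredRMSProp (nine float32 inputs) to RMSPropGpuKernel on GPU. A minimal invocation sketch, assuming the same Cell/graph-mode pattern and argument order as the test added later in this PR; the wrapper class name and the sample values are illustrative only:

import numpy as np
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")


class ApplyRMSPropNet(nn.Cell):
    """Wraps P.ApplyRMSProp; the argument order mirrors the test in this PR."""
    def __init__(self):
        super(ApplyRMSPropNet, self).__init__()
        self.rms_opt = P.ApplyRMSProp()

    def construct(self, var, mean_square, moment, grad, lr, decay, momentum, epsilon):
        return self.rms_opt(var, mean_square, moment, grad, lr, decay, momentum, epsilon)


var = Tensor(np.array([1.0, 2.0], dtype=np.float32))
mean_square = Tensor(np.full(2, 1e-3, dtype=np.float32))
moment = Tensor(np.zeros(2, dtype=np.float32))
grad = Tensor(np.array([0.1, 0.2], dtype=np.float32))

net = ApplyRMSPropNet()
out = net(var, mean_square, moment, grad, 0.5, 0.8, 0.9, 1e-3)  # lr, decay, momentum, epsilon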
@@ -0,0 +1,110 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_RMSPROP_KERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_NN_RMSPROP_KERNEL_H_

#include <vector>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
#include "kernel/gpu/cuda_impl/rmsprop_impl.cuh"

namespace mindspore {
namespace kernel {
template <typename T>
class RMSPropGpuKernel : public GpuKernel {
 public:
  RMSPropGpuKernel() : size_(1), use_center_(false) {}
  ~RMSPropGpuKernel() override = default;

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs, uintptr_t stream) override {
    if (!use_center_) {
      // ApplyRMSProp input order: var, mean_square, moment, gradients, lr, decay, momentum, epsilon.
      T *variable = GetDeviceAddress<T>(inputs, 0);
      T *mean_square = GetDeviceAddress<T>(inputs, 1);
      T *moment = GetDeviceAddress<T>(inputs, 2);
      T *gradients = GetDeviceAddress<T>(inputs, 3);
      T *learning_rate = GetDeviceAddress<T>(inputs, 4);
      T *decay = GetDeviceAddress<T>(inputs, 5);
      T *momentum = GetDeviceAddress<T>(inputs, 6);
      T *epsilon = GetDeviceAddress<T>(inputs, 7);

      RmsProp(learning_rate, decay, momentum, epsilon, variable, mean_square, moment, gradients, size_,
              reinterpret_cast<cudaStream_t>(stream));
    } else {
      // ApplyCenteredRMSProp input order: var, mean_gradients, mean_square, moment, gradients, lr, decay,
      // momentum, epsilon.
      T *variable = GetDeviceAddress<T>(inputs, 0);
      T *mean_gradients = GetDeviceAddress<T>(inputs, 1);
      T *mean_square = GetDeviceAddress<T>(inputs, 2);
      T *moment = GetDeviceAddress<T>(inputs, 3);
      T *gradients = GetDeviceAddress<T>(inputs, 4);
      T *learning_rate = GetDeviceAddress<T>(inputs, 5);
      T *decay = GetDeviceAddress<T>(inputs, 6);
      T *momentum = GetDeviceAddress<T>(inputs, 7);
      T *epsilon = GetDeviceAddress<T>(inputs, 8);

      RmsPropCenter(learning_rate, decay, momentum, epsilon, variable, mean_gradients, mean_square, moment, gradients,
                    size_, reinterpret_cast<cudaStream_t>(stream));
    }
    return true;
  }

  bool Init(const CNodePtr &kernel_node) override {
    auto node_name = AnfAlgo::GetCNodeName(kernel_node);
    if (node_name == "ApplyCenteredRMSProp") {
      use_center_ = true;
    }

    auto input_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
    for (auto &dim : input_shape) {
      size_ *= dim;
    }
    InitSizeLists();
    return true;
  }

 protected:
  void InitSizeLists() override {
    // Tensor-shaped inputs: var (plus mean_gradients when centered), mean_square, moment, gradients;
    // scalar inputs: learning_rate, decay, momentum, epsilon. The update is in place, so no real output.
    size_t input_size = size_ * sizeof(T);
    input_size_list_.push_back(input_size);
    if (use_center_) {
      input_size_list_.push_back(input_size);
    }

    input_size_list_.push_back(input_size);
    input_size_list_.push_back(input_size);
    input_size_list_.push_back(input_size);
    input_size_list_.push_back(sizeof(T));
    input_size_list_.push_back(sizeof(T));
    input_size_list_.push_back(sizeof(T));
    input_size_list_.push_back(sizeof(T));
    output_size_list_.push_back(0);
  }

 private:
  size_t size_;
  bool use_center_;

  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore

#endif  // MINDSPORE_CCSRC_KERNEL_GPU_NN_RMSPROP_KERNEL_H_
@@ -0,0 +1,152 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")


class NetRMSProp(nn.Cell):
    def __init__(self, use_centered):
        super(NetRMSProp, self).__init__()
        self.use_centered = use_centered
        if use_centered:
            self.rms_opt = P.ApplyCenteredRMSProp()
        else:
            self.rms_opt = P.ApplyRMSProp()

    def construct(self, var, g, mg, rms, mom, lr, decay, momentum, epsilon):
        if self.use_centered:
            return self.rms_opt(var, mg, rms, mom, g, lr, decay, momentum, epsilon)
        else:
            return self.rms_opt(var, rms, mom, g, lr, decay, momentum, epsilon)


def rmsprop_numpy(variable, gradients, mean_square, moment,
                  learning_rate, decay, momentum, epsilon):
    mean_square = mean_square * decay + (1.0 - decay) * gradients * gradients
    moment = momentum * moment + learning_rate / np.sqrt(mean_square + epsilon) * gradients
    variable = variable - moment
    return variable, gradients, mean_square, moment


def rmspropcentered_numpy(variable, gradients, mean_gradients, mean_square, moment,
                          learning_rate, decay, momentum, epsilon):
    mean_gradients = mean_gradients * decay + (1.0 - decay) * gradients
    mean_square = mean_square * decay + (1.0 - decay) * gradients * gradients
    moment = momentum * moment + learning_rate / np.sqrt(
        mean_square - mean_gradients * mean_gradients + epsilon) * gradients
    variable = variable - moment
    return variable, gradients, mean_gradients, mean_square, moment


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_rmsprop():
    learning_rate, decay, momentum, epsilon, centered = [0.5, 0.8, 0.9, 1e-3, False]

    variable_np = np.array([1.0, 2.0], dtype=np.float32)
    gradients_np = np.array([0.1, 0.2], dtype=np.float32)
    mean_gradients_np = np.array([0.0, 0.0], dtype=np.float32)
    mean_square_np = np.array([epsilon, epsilon], dtype=np.float32)
    moment_np = np.array([0.0, 0.0], dtype=np.float32)

    variable_ms = Tensor(variable_np)
    gradients_ms = Tensor(gradients_np)
    mean_gradients_ms = Tensor(mean_gradients_np)
    mean_square_ms = Tensor(mean_square_np)
    moment_ms = Tensor(moment_np)

    if centered:
        variable_np, gradients_np, mean_gradients_np, mean_square_np, moment_np = rmspropcentered_numpy(
            variable_np, gradients_np, mean_gradients_np, mean_square_np, moment_np,
            learning_rate, decay, momentum, epsilon)
    else:
        variable_np, gradients_np, mean_square_np, moment_np = rmsprop_numpy(
            variable_np, gradients_np, mean_square_np, moment_np,
            learning_rate, decay, momentum, epsilon)

    net = NetRMSProp(centered)
    _ = net(variable_ms, gradients_ms, mean_gradients_ms, mean_square_ms,
            moment_ms, learning_rate, decay, momentum, epsilon)

    error = np.ones(shape=variable_np.shape) * 10e-6
    diff = variable_ms.asnumpy() - variable_np
    assert np.all(np.abs(diff) < error)

    error = np.ones(shape=gradients_np.shape) * 10e-6
    diff = gradients_ms.asnumpy() - gradients_np
    assert np.all(np.abs(diff) < error)

    error = np.ones(shape=mean_gradients_np.shape) * 10e-6
    diff = mean_gradients_ms.asnumpy() - mean_gradients_np
    assert np.all(np.abs(diff) < error)

    error = np.ones(shape=mean_square_np.shape) * 10e-6
    diff = mean_square_ms.asnumpy() - mean_square_np
    assert np.all(np.abs(diff) < error)

    error = np.ones(shape=moment_np.shape) * 10e-6
    diff = moment_ms.asnumpy() - moment_np
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_rmspropcenter():
    learning_rate, decay, momentum, epsilon, centered = [0.1, 0.3, 0.9, 1.0, True]

    variable_np = np.array([1.0, 2.0], dtype=np.float32)
    gradients_np = np.array([0.1, 0.2], dtype=np.float32)
    mean_gradients_np = np.array([0.0, 0.0], dtype=np.float32)
    mean_square_np = np.array([epsilon, epsilon], dtype=np.float32)
    moment_np = np.array([0.0, 0.0], dtype=np.float32)

    variable_ms = Tensor(variable_np)
    gradients_ms = Tensor(gradients_np)
    mean_gradients_ms = Tensor(mean_gradients_np)
    mean_square_ms = Tensor(mean_square_np)
    moment_ms = Tensor(moment_np)

    if centered:
        variable_np, gradients_np, mean_gradients_np, mean_square_np, moment_np = rmspropcentered_numpy(
            variable_np, gradients_np, mean_gradients_np, mean_square_np, moment_np,
            learning_rate, decay, momentum, epsilon)
    else:
        variable_np, gradients_np, mean_square_np, moment_np = rmsprop_numpy(
            variable_np, gradients_np, mean_square_np, moment_np,
            learning_rate, decay, momentum, epsilon)

    net = NetRMSProp(centered)
    _ = net(variable_ms, gradients_ms, mean_gradients_ms, mean_square_ms, moment_ms,
            learning_rate, decay, momentum, epsilon)

    error = np.ones(shape=variable_np.shape) * 10e-6
    diff = variable_ms.asnumpy() - variable_np
    assert np.all(np.abs(diff) < error)

    error = np.ones(shape=gradients_np.shape) * 10e-6
    diff = gradients_ms.asnumpy() - gradients_np
    assert np.all(np.abs(diff) < error)

    error = np.ones(shape=mean_gradients_np.shape) * 10e-6
    diff = mean_gradients_ms.asnumpy() - mean_gradients_np
    assert np.all(np.abs(diff) < error)

    error = np.ones(shape=mean_square_np.shape) * 10e-6
    diff = mean_square_ms.asnumpy() - mean_square_np
    assert np.all(np.abs(diff) < error)

    error = np.ones(shape=moment_np.shape) * 10e-6
    diff = moment_ms.asnumpy() - moment_np
    assert np.all(np.abs(diff) < error)
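As a hand-computed sanity check (editorial, not part of the PR): with learning_rate=0.5, decay=0.8, momentum=0.9, epsilon=1e-3 and the initial values used in the tests (variable=1.0, mean_square=1e-3, moment=0.0, gradient=0.1), one step of the non-centered recurrence gives

mean_square = 0.8 * 0.001 + 0.2 * 0.01 = 0.0028
moment = 0.9 * 0 + 0.5 * 0.1 / sqrt(0.0028 + 0.001) ≈ 0.8111
variable = 1.0 - 0.8111 ≈ 0.1889

so both the NumPy reference and the GPU kernel should produce roughly 0.1889 for the first element, well within the 1e-5 tolerance used by the assertions.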