forked from mindspore-Ecosystem/mindspore
add gpu ApplyGradientDescent
This commit is contained in:
parent
25b30f34a6
commit
593c68e110
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/apply_gradient_descent_impl.cu
@@ -0,0 +1,37 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/gpu/cuda_impl/apply_gradient_descent_impl.cuh"

template <typename T>
__global__ void ApplyGradientDescent(const size_t size, T *var, const T *alpha, const T *delta, T *output) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
    const T alpha_value = alpha[0];
    var[pos] -= alpha_value * delta[pos];
    output[pos] = var[pos];
  }
}

template <typename T>
void CalApplyGradientDescent(const size_t &size, T *var, const T *alpha, const T *delta, T *output,
                             cudaStream_t cuda_stream) {
  ApplyGradientDescent<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, var, alpha, delta, output);
}

template void CalApplyGradientDescent<float>(const size_t &size, float *var, const float *alpha, const float *delta,
                                             float *output, cudaStream_t cuda_stream);
template void CalApplyGradientDescent<half>(const size_t &size, half *var, const half *alpha, const half *delta,
                                            half *output, cudaStream_t cuda_stream);
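The kernel implements the plain gradient-descent step var <- var - alpha * delta elementwise: alpha is read from a one-element device buffer, and a grid-stride loop lets any launch configuration cover all `size` elements; the updated var is also copied to output. A minimal NumPy sketch of the same arithmetic (a reference helper for checking results, not part of the commit):

import numpy as np

def apply_gradient_descent_ref(var, alpha, delta):
    # Mirrors the CUDA kernel: update var in place and return it as the output.
    var -= alpha[0] * delta
    return var

var = np.arange(10, dtype=np.float32).reshape(2, 5) / 10
delta = np.arange(34, 44, dtype=np.float32).reshape(2, 5)
out = apply_gradient_descent_ref(var, np.array([0.0001], np.float32), delta)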
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/apply_gradient_descent_impl.cuh
@@ -0,0 +1,27 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_APPLY_GRADIENT_DESCENT_IMPL_CUH_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_APPLY_GRADIENT_DESCENT_IMPL_CUH_

#include <cuda_runtime.h>
#include "runtime/device/gpu/cuda_common.h"

template <typename T>
void CalApplyGradientDescent(const size_t &size, T *var, const T *alpha, const T *delta, T *output,
                             cudaStream_t cuda_stream);

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_APPLY_GRADIENT_DESCENT_IMPL_CUH_
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/apply_gradient_descent_gpu_kernel.cc
@@ -0,0 +1,36 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/gpu/nn/apply_gradient_descent_gpu_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(ApplyGradientDescent,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32),
                      ApplyGradientDescentKernel, float)
MS_REG_GPU_KERNEL_ONE(ApplyGradientDescent,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddOutputAttr(kNumberTypeFloat16),
                      ApplyGradientDescentKernel, half)
}  // namespace kernel
}  // namespace mindspore
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/apply_gradient_descent_gpu_kernel.h
@@ -0,0 +1,92 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_APPLY_GRADIENT_DESCENT_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_APPLY_GRADIENT_DESCENT_GPU_KERNEL_H_

#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/apply_gradient_descent_impl.cuh"

namespace mindspore {
namespace kernel {
template <typename T>
class ApplyGradientDescentKernel : public GpuKernel {
 public:
  ApplyGradientDescentKernel() { ResetResource(); }
  ~ApplyGradientDescentKernel() override = default;

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    VARIABLE_NOT_USED(workspace);
    T *var = GetDeviceAddress<T>(inputs, 0);
    T *alpha = GetDeviceAddress<T>(inputs, 1);
    T *delta = GetDeviceAddress<T>(inputs, 2);
    T *output = GetDeviceAddress<T>(outputs, 0);
    CalApplyGradientDescent(input_size_, var, alpha, delta, output, reinterpret_cast<cudaStream_t>(stream_ptr));
    return true;
  }

  bool Init(const CNodePtr &kernel_node) override {
    kernel_node_ = kernel_node;
    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
    if (input_num != 3) {
      MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but ApplyGradientDescent needs 3 inputs.";
      return false;
    }
    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
    if (output_num != 1) {
      MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but ApplyGradientDescent has 1 output.";
      return false;
    }
    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    input_size_ = 1;
    for (size_t i = 0; i < input_shape.size(); i++) {
      input_size_ *= input_shape[i];
    }
    InitSizeLists();
    return true;
  }

  void ResetResource() noexcept override {
    input_size_ = 1;
    input_size_list_.clear();
    output_size_list_.clear();
    workspace_size_list_.clear();
  }

 protected:
  void InitSizeLists() override {
    input_size_list_.push_back(input_size_ * sizeof(T));
    input_size_list_.push_back(sizeof(T));
    input_size_list_.push_back(input_size_ * sizeof(T));
    output_size_list_.push_back(input_size_ * sizeof(T));
  }

 private:
  size_t input_size_;
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_APPLY_GRADIENT_DESCENT_GPU_KERNEL_H_
@@ -995,7 +995,7 @@ class BCEWithLogitsLoss(_Loss):
         ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
 
     Supported Platforms:
-        ``Ascend``
+        ``Ascend`` ``GPU``
 
     Examples:
         >>> logits = Tensor(np.array([[-0.8, 1.2, 0.7], [-0.1, -0.4, 0.7]]).astype(np.float32))
@@ -3832,7 +3832,7 @@ class BCEWithLogitsLoss(PrimitiveWithInfer):
         ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
 
     Supported Platforms:
-        ``Ascend``
+        ``Ascend`` ``GPU``
 
     Examples:
         >>> predict = Tensor(np.array([[-0.8, 1.2, 0.7], [-0.1, -0.4, 0.7]]).astype(np.float32))
@@ -6268,7 +6268,7 @@ class ApplyGradientDescent(PrimitiveWithInfer):
         TypeError: If `alpha` is neither a Number nor a Tensor.
 
     Supported Platforms:
-        ``Ascend``
+        ``Ascend`` ``GPU``
 
     Examples:
         >>> import numpy as np
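With the two registrations above, this primitive now dispatches to the GPU kernel for float32 and float16. A minimal usage sketch (the net and values are illustrative, patterned on the test file below rather than on the elided docstring example):

import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

class GradientDescentNet(nn.Cell):
    def __init__(self, var):
        super(GradientDescentNet, self).__init__()
        self.var = Parameter(var, name="var")
        self.op = P.ApplyGradientDescent()

    def construct(self, alpha, delta):
        # Returns var after the in-place update var -= alpha * delta.
        return self.op(self.var, alpha, delta)

net = GradientDescentNet(Tensor(np.ones((2, 2)).astype(np.float32)))
out = net(Tensor(np.array([0.01]).astype(np.float32)),
          Tensor(np.full((2, 2), 0.5).astype(np.float32)))
# Each element: 1.0 - 0.01 * 0.5 = 0.995; net.var is updated to the same values.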
@@ -0,0 +1,86 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P


class Net(nn.Cell):
    def __init__(self, var):
        super(Net, self).__init__()
        self.var = Parameter(var, name="var")
        self.apply_gradient_descent = P.ApplyGradientDescent()

    def construct(self, alpha, delta):
        return self.apply_gradient_descent(self.var, alpha, delta)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_apply_gradient_descent_float32():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    var = Tensor(np.arange(10).reshape(2, 5).astype(np.float32) / 10)
    net = Net(var)
    alpha = Tensor(np.array([0.0001]).astype(np.float32))
    delta = Tensor(np.arange(34, 44).reshape(2, 5).astype(np.float32))
    output = net(alpha, delta)
    expect = np.array([[-0.0034, 0.0965, 0.1964, 0.29630002, 0.3962],
                       [0.4961, 0.596, 0.69589996, 0.79580003, 0.8957]], dtype=np.float32)
    np.testing.assert_almost_equal(output.asnumpy(), expect)
    np.testing.assert_almost_equal(net.var.asnumpy(), expect)

    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    var = Tensor(np.arange(10).reshape(2, 5).astype(np.float32) / 10)
    net = Net(var)
    alpha = Tensor(np.array([0.0001]).astype(np.float32))
    delta = Tensor(np.arange(34, 44).reshape(2, 5).astype(np.float32))
    output = net(alpha, delta)
    expect = np.array([[-0.0034, 0.0965, 0.1964, 0.29630002, 0.3962],
                       [0.4961, 0.596, 0.69589996, 0.79580003, 0.8957]], dtype=np.float32)
    np.testing.assert_almost_equal(output.asnumpy(), expect)
    np.testing.assert_almost_equal(net.var.asnumpy(), expect)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_apply_gradient_descent_float16():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    var = Tensor(np.arange(10).reshape(2, 5).astype(np.float16) / 10)
    net = Net(var)
    alpha = Tensor(np.array([0.0001]).astype(np.float16))
    delta = Tensor(np.arange(34, 44).reshape(2, 5).astype(np.float16))
    output = net(alpha, delta)
    expect = np.array([[-0.0034, 0.0965, 0.1964, 0.29630002, 0.3962],
                       [0.4961, 0.596, 0.69589996, 0.79580003, 0.8957]], dtype=np.float16)
    np.testing.assert_almost_equal(output.asnumpy(), expect, decimal=3)
    np.testing.assert_almost_equal(net.var.asnumpy(), expect, decimal=3)

    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    var = Tensor(np.arange(10).reshape(2, 5).astype(np.float16) / 10)
    net = Net(var)
    alpha = Tensor(np.array([0.0001]).astype(np.float16))
    delta = Tensor(np.arange(34, 44).reshape(2, 5).astype(np.float16))
    output = net(alpha, delta)
    expect = np.array([[-0.0034, 0.0965, 0.1964, 0.2964, 0.396],
                       [0.496, 0.596, 0.6963, 0.7954, 0.8955]], dtype=np.float16)
    np.testing.assert_almost_equal(output.asnumpy(), expect, decimal=3)
    np.testing.assert_almost_equal(net.var.asnumpy(), expect, decimal=3)
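The expected arrays follow directly from the update rule: var starts at arange(10)/10 and each element loses alpha * delta, so the first entry is 0.0 - 0.0001 * 34 = -0.0034 and the last is 0.9 - 0.0001 * 43 = 0.8957. The two float16 blocks list slightly different constants, presumably reflecting half-precision rounding in the two execution modes. A quick standalone check that regenerates the float32 expectation (not part of the test):

import numpy as np

var = np.arange(10, dtype=np.float32).reshape(2, 5) / 10
delta = np.arange(34, 44, dtype=np.float32).reshape(2, 5)
expect = var - np.float32(0.0001) * delta  # expect[0, 0] == -0.0034, expect[1, 4] ~= 0.8957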