Add float64 support to the ApplyProximalGradientDescent GPU kernel and clean up its formatting
This commit is contained in:
parent
b613a4a5ed
commit
bfd3539586
|
@ -14,7 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/apply_proximal_adagrad_impl.cuh"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/apply_proximal_gradient_descent_impl.cuh"
|
||||
#include <algorithm>
|
||||
#include "include/cuda_fp16.h"
|
||||
|
||||
|
@ -58,10 +58,9 @@ __device__ __forceinline__ half SgnFunc(half x) {
|
|||
return __float2half(__half2float(x) != 0 ? (__half2float(x) > 0 ? 1 : -1) : 0);
|
||||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
__global__ void CalApplyProximalGradientDescentKernel(const size_t input_elements,T *var, const T *alpha, const T *l1, const T *l2,
|
||||
const T *delta, T *output) {
|
||||
__global__ void CalApplyProximalGradientDescentKernel(const size_t input_elements, T *var, const T *alpha,
|
||||
const T *l1, const T *l2, const T *delta, T *output) {
|
||||
if (l1[0] > static_cast<T>(0.0)) {
|
||||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < static_cast<int>(input_elements);
|
||||
pos += gridDim.x * blockDim.x) {
|
||||
|
@ -81,16 +80,25 @@ __global__ void CalApplyProximalGradientDescentKernel(const size_t input_element
|
|||
}
|
||||
|
||||
/**
 * Host-side launcher for CalApplyProximalGradientDescentKernel.
 *
 * Enqueues the kernel on `cuda_stream` using the grid and block sizes that CUDA_BLOCKS and
 * CUDA_THREADS report for `device_id`, with no dynamic shared memory. All pointer arguments are
 * forwarded to the kernel unchanged. The launch is neither synchronized nor error-checked here.
 */
template <typename T>
void CalApplyProximalGradientDescent(const size_t input_elements, T *var, const T *alpha, const T *l1, const T *l2,
                                     const T *delta, T *output, const uint32_t &device_id,
                                     cudaStream_t cuda_stream) {
  const auto grid_size = CUDA_BLOCKS(device_id, input_elements);
  const auto block_size = CUDA_THREADS(device_id);
  CalApplyProximalGradientDescentKernel<<<grid_size, block_size, 0, cuda_stream>>>(input_elements, var, alpha, l1, l2,
                                                                                    delta, output);
}
|
||||
|
||||
// Explicit instantiations exported from this translation unit: float, half and double.
template CUDA_LIB_EXPORT void CalApplyProximalGradientDescent<float>(const size_t size, float *var, const float *alpha,
                                                                     const float *l1, const float *l2,
                                                                     const float *delta, float *output,
                                                                     const uint32_t &device_id,
                                                                     cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalApplyProximalGradientDescent<half>(const size_t size, half *var, const half *alpha,
                                                                    const half *l1, const half *l2, const half *delta,
                                                                    half *output, const uint32_t &device_id,
                                                                    cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalApplyProximalGradientDescent<double>(const size_t size, double *var,
                                                                      const double *alpha, const double *l1,
                                                                      const double *l2, const double *delta,
                                                                      double *output, const uint32_t &device_id,
                                                                      cudaStream_t cuda_stream);
|
||||
|
|
|
@ -13,12 +13,12 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_APPLY_PROXIMAL_GRADIENT_DESCENT_IMPL_CUH_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_APPLY_PROXIMAL_GRADIENT_DESCENT_IMPL_CUH_
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
|
||||
// Declaration of the templated host entry point for the ApplyProximalGradientDescent CUDA implementation.
// `size` is the element count handed to the kernel; `device_id` and `cuda_stream` select the launch
// configuration and the stream the work is enqueued on.
template <typename T>
CUDA_LIB_EXPORT void CalApplyProximalGradientDescent(const size_t size, T *var, const T *alpha, const T *l1,
                                                     const T *l2, const T *delta, T *output,
                                                     const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_APPLY_PROXIMAL_GRADIENT_DESCENT_IMPL_CUH_
|
||||
|
|
|
@ -30,12 +30,11 @@ constexpr size_t kL1Index = 2;
|
|||
constexpr size_t kL2Index = 3;
|
||||
constexpr size_t kDeltaIndex = 4;
|
||||
constexpr size_t kOutputIndex = 0;
|
||||
|
||||
} // namespace
|
||||
|
||||
bool ApplyProximalGradientDescentGpuKernelMod::Init(const BaseOperatorPtr &base_operator,
|
||||
const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
kernel_name_ = base_operator->name();
|
||||
|
||||
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
|
@ -50,14 +49,13 @@ bool ApplyProximalGradientDescentGpuKernelMod::Init(const BaseOperatorPtr &base_
|
|||
MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid.";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int ApplyProximalGradientDescentGpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
|
||||
const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
|
||||
const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
|
||||
int ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost);
|
||||
if (ret != 0) {
|
||||
return ret;
|
||||
|
@ -108,38 +106,49 @@ int ApplyProximalGradientDescentGpuKernelMod::Resize(const BaseOperatorPtr &base
|
|||
return ret;
|
||||
}
|
||||
|
||||
// bool ApplyProximalGradientDescentGpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
// const std::vector<kernel::AddressPtr> &workspace,
|
||||
// const std::vector<kernel::AddressPtr> &outputs,
|
||||
// void *stream_ptr) {
|
||||
// kernel_func_(this, inputs,workspace, outputs);
|
||||
// return true;
|
||||
// }
|
||||
|
||||
template <typename T>
|
||||
bool ApplyProximalGradientDescentGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs) {
|
||||
bool ApplyProximalGradientDescentGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
auto var = reinterpret_cast<T *>(inputs[kVarIndex]->addr);
|
||||
auto alpha = reinterpret_cast<T *>(inputs[kAlphaIndex]->addr);
|
||||
auto l1 = reinterpret_cast<T *>(inputs[kL1Index]->addr);
|
||||
auto l2 = reinterpret_cast<T *>(inputs[kL2Index]->addr);
|
||||
auto delta = reinterpret_cast<T *>(inputs[kDeltaIndex]->addr);
|
||||
auto output = reinterpret_cast<T *>(outputs[kOutputIndex]->addr);
|
||||
// T *output = GetDeviceAddress<T>(outputs, 0);
|
||||
|
||||
CalApplyProximalGradientDescent(input_elements_, var, alpha, l1, l2, delta, output, device_id_,
|
||||
reinterpret_cast<cudaStream_t>(cuda_stream_));
|
||||
reinterpret_cast<cudaStream_t>(cuda_stream_));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Dispatch table: each entry pairs a KernelAttr (five inputs and one output, all of the same dtype)
// with the LaunchKernel instantiation for that dtype.
std::vector<std::pair<KernelAttr, ApplyProximalGradientDescentGpuKernelMod::KernelFunc>>
  ApplyProximalGradientDescentGpuKernelMod::func_list_ = {
    // float32
    {KernelAttr()
       .AddInputAttr(kNumberTypeFloat32)
       .AddInputAttr(kNumberTypeFloat32)
       .AddInputAttr(kNumberTypeFloat32)
       .AddInputAttr(kNumberTypeFloat32)
       .AddInputAttr(kNumberTypeFloat32)
       .AddOutputAttr(kNumberTypeFloat32),
     &ApplyProximalGradientDescentGpuKernelMod::LaunchKernel<float>},
    // float16
    {KernelAttr()
       .AddInputAttr(kNumberTypeFloat16)
       .AddInputAttr(kNumberTypeFloat16)
       .AddInputAttr(kNumberTypeFloat16)
       .AddInputAttr(kNumberTypeFloat16)
       .AddInputAttr(kNumberTypeFloat16)
       .AddOutputAttr(kNumberTypeFloat16),
     &ApplyProximalGradientDescentGpuKernelMod::LaunchKernel<half>},
    // float64
    {KernelAttr()
       .AddInputAttr(kNumberTypeFloat64)
       .AddInputAttr(kNumberTypeFloat64)
       .AddInputAttr(kNumberTypeFloat64)
       .AddInputAttr(kNumberTypeFloat64)
       .AddInputAttr(kNumberTypeFloat64)
       .AddOutputAttr(kNumberTypeFloat64),
     &ApplyProximalGradientDescentGpuKernelMod::LaunchKernel<double>}};
|
||||
|
||||
std::vector<KernelAttr> ApplyProximalGradientDescentGpuKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
|
|
|
@ -42,9 +42,9 @@ class ApplyProximalGradientDescentGpuKernelMod : public NativeGpuKernelMod {
|
|||
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override{
|
||||
cuda_stream_ = stream_ptr;
|
||||
return kernel_func_(this, inputs, workspace, outputs);
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
cuda_stream_ = stream_ptr;
|
||||
return kernel_func_(this, inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
|
@ -52,9 +52,11 @@ class ApplyProximalGradientDescentGpuKernelMod : public NativeGpuKernelMod {
|
|||
|
||||
private:
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs);
|
||||
using KernelFunc = std::function<bool(ApplyProximalGradientDescentGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &,const std::vector<kernel::AddressPtr> &)>;
|
||||
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs);
|
||||
using KernelFunc =
|
||||
std::function<bool(ApplyProximalGradientDescentGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
|
||||
KernelFunc kernel_func_{};
|
||||
static std::vector<std::pair<KernelAttr, KernelFunc>> func_list_;
|
||||
|
||||
|
|
|
@ -80,7 +80,7 @@ TypePtr ApplyProximalGradientDescentInferType(const PrimitivePtr &prim,
|
|||
auto l1_type = input_args[kInputIndex2]->BuildType();
|
||||
auto l2_type = input_args[kInputIndex3]->BuildType();
|
||||
auto delta_type = input_args[kInputIndex4]->BuildType();
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
|
||||
// var and delta must have the same type
|
||||
std::map<std::string, TypePtr> args;
|
||||
(void)args.insert(std::make_pair("var_type", var_type));
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -12,15 +12,12 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import numpy as np
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor, Parameter
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
|
@ -37,6 +34,11 @@ class Net(nn.Cell):
|
|||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_apply_proximal_gradient_descent_float32():
|
||||
"""
|
||||
Feature: ApplyProximalGradientDescent gpu kernel
|
||||
Description: test the ApplyProximalGradientDescent.
|
||||
Expectation: match to np benchmark.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
var = Tensor(np.ones([2, 2]).astype(np.float32))
|
||||
net = Net(var)
|
||||
|
@ -45,7 +47,7 @@ def test_apply_proximal_gradient_descent_float32():
|
|||
l2 = 0.1
|
||||
delta = Tensor(np.array([[0.1, 0.1], [0.1, 0.1]]).astype(np.float32))
|
||||
output = net(alpha, l1, l2, delta)
|
||||
expect = np.array([[0.99969995,0.99969995],[0.99969995,0.99969995]], dtype=np.float32)
|
||||
expect = np.array([[0.99969995, 0.99969995], [0.99969995, 0.99969995]], dtype=np.float32)
|
||||
np.testing.assert_almost_equal(output.asnumpy(), expect)
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
|
@ -56,6 +58,6 @@ def test_apply_proximal_gradient_descent_float32():
|
|||
l2 = 0.1
|
||||
delta = Tensor(np.array([[0.1, 0.1], [0.1, 0.1]]).astype(np.float32))
|
||||
output = net(alpha, l1, l2, delta)
|
||||
expect = np.array([[0.99969995,0.99969995],[0.99969995,0.99969995]], dtype=np.float32)
|
||||
expect = np.array([[0.99969995, 0.99969995], [0.99969995, 0.99969995]], dtype=np.float32)
|
||||
np.testing.assert_almost_equal(output.asnumpy(), expect)
|
||||
|
Loading…
Reference in New Issue