From a0f56b1ba61059547ec0d9c2fd0ec66c5975862c Mon Sep 17 00:00:00 2001
From: tangyibo123 <867461411@qq.com>
Date: Fri, 1 Jul 2022 16:31:48 +0800
Subject: [PATCH] add new GPU op ApplyAddSign

---
 .../cuda_impl/cuda_ops/apply_add_sign_impl.cu | 223 +++++++++++++++++
 .../cuda_ops/apply_add_sign_impl.cuh          |  32 +++
 .../kernel/nn/apply_add_sign_gpu_kernel.cc    | 226 ++++++++++++++++++
 .../gpu/kernel/nn/apply_add_sign_gpu_kernel.h |  84 +++++++
 mindspore/core/ops/apply_add_sign.cc          |   2 +-
 mindspore/core/ops/core_ops.h                 |   1 +
 .../python/mindspore/ops/operations/nn_ops.py |   2 +-
 tests/st/ops/gpu/test_apply_add_sign_op.py    |  65 +++++
 8 files changed, 633 insertions(+), 2 deletions(-)
 create mode 100644 mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/apply_add_sign_impl.cu
 create mode 100644 mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/apply_add_sign_impl.cuh
 create mode 100644 mindspore/ccsrc/plugin/device/gpu/kernel/nn/apply_add_sign_gpu_kernel.cc
 create mode 100644 mindspore/ccsrc/plugin/device/gpu/kernel/nn/apply_add_sign_gpu_kernel.h
 create mode 100644 tests/st/ops/gpu/test_apply_add_sign_op.py

diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/apply_add_sign_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/apply_add_sign_impl.cu
new file mode 100644
index 00000000000..0afa2225f4a
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/apply_add_sign_impl.cu
@@ -0,0 +1,223 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/apply_add_sign_impl.cuh"
+#include "include/cuda_fp16.h"
+
+template <typename T>
+__device__ __forceinline__ T Sgn(T x) {
+  return static_cast<T>(x != 0 ? (x > 0 ? 1 : -1) : 0);
+}
+
+// AddSign update rule computed element-wise by every kernel below:
+//   accumulation <- beta * accumulation + (1 - beta) * gradient
+//   update       <- (alpha + sign_decay * Sgn(gradient) * Sgn(accumulation)) * gradient
+//   variable     <- variable - learning_rate * update
+template <typename T, typename S, typename G>
+__global__ void ApplyAddSignKernel(const size_t size,
+                                   T *variable,
+                                   T *accumulation,
+                                   const S learning_rate,
+                                   const S alpha,
+                                   const S sign_decay,
+                                   const S beta,
+                                   const G *gradient) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
+    accumulation[i] = (beta * accumulation[i]) + ((static_cast<T>(1.) - beta) * gradient[i]);
+    T update = (alpha + (sign_decay * Sgn(gradient[i]) * Sgn(accumulation[i]))) * gradient[i];
+    variable[i] = variable[i] - (learning_rate * update);
+  }
+}
+
+template <>
+__global__ void ApplyAddSignKernel(const size_t size,
+                                   half *variable,
+                                   half *accumulation,
+                                   const float learning_rate,
+                                   const float alpha,
+                                   const float sign_decay,
+                                   const float beta,
+                                   const half *gradient) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
+    accumulation[i] = (beta * __half2float(accumulation[i])) +
+                      ((static_cast<float>(1.) - beta) * __half2float(gradient[i]));
+    float update = (alpha + (sign_decay * Sgn(__half2float(gradient[i])) * Sgn(__half2float(accumulation[i])))) *
+                   __half2float(gradient[i]);
+    variable[i] = __half2float(variable[i]) - (learning_rate * update);
+    variable[i] = __float2half(variable[i]);
+    accumulation[i] = __float2half(accumulation[i]);
+  }
+}
+
+template <>
+__global__ void ApplyAddSignKernel(const size_t size,
+                                   float *variable,
+                                   float *accumulation,
+                                   const float learning_rate,
+                                   const float alpha,
+                                   const float sign_decay,
+                                   const float beta,
+                                   const half *gradient) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
+    accumulation[i] = (beta * accumulation[i]) + ((static_cast<float>(1.) - beta) * __half2float(gradient[i]));
+    float update = (alpha + (sign_decay * Sgn(__half2float(gradient[i])) * Sgn(accumulation[i]))) *
+                   __half2float(gradient[i]);
+    variable[i] = variable[i] - (learning_rate * update);
+  }
+}
+
+template <>
+__global__ void ApplyAddSignKernel(const size_t size,
+                                   float *variable,
+                                   float *accumulation,
+                                   const half learning_rate,
+                                   const half alpha,
+                                   const half sign_decay,
+                                   const half beta,
+                                   const float *gradient) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
+    accumulation[i] = (__half2float(beta) * accumulation[i]) +
+                      ((static_cast<float>(1.) - __half2float(beta)) * gradient[i]);
+    float update = (__half2float(alpha) + (__half2float(sign_decay) * Sgn(gradient[i]) * Sgn(accumulation[i]))) *
+                   gradient[i];
+    variable[i] = variable[i] - (__half2float(learning_rate) * update);
+  }
+}
+
+template <>
+__global__ void ApplyAddSignKernel(const size_t size,
+                                   float *variable,
+                                   float *accumulation,
+                                   const half learning_rate,
+                                   const half alpha,
+                                   const half sign_decay,
+                                   const half beta,
+                                   const half *gradient) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
+    accumulation[i] = (__half2float(beta) * accumulation[i]) + ((static_cast<float>(1.) - __half2float(beta)) *
+                                                                __half2float(gradient[i]));
+    float update = (__half2float(alpha) + (__half2float(sign_decay) * Sgn(__half2float(gradient[i])) *
+                                           Sgn(accumulation[i]))) * __half2float(gradient[i]);
+    variable[i] = variable[i] - __half2float(learning_rate) * update;
+  }
+}
+
+template <>
+__global__ void ApplyAddSignKernel(const size_t size,
+                                   half *variable,
+                                   half *accumulation,
+                                   const half learning_rate,
+                                   const half alpha,
+                                   const half sign_decay,
+                                   const half beta,
+                                   const half *gradient) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
+    accumulation[i] = (__half2float(beta) * __half2float(accumulation[i])) +
+                      ((static_cast<float>(1.) - __half2float(beta)) * __half2float(gradient[i]));
+    float update = (__half2float(alpha) + (__half2float(sign_decay) * Sgn(__half2float(gradient[i])) *
+                                           Sgn(__half2float(accumulation[i])))) * __half2float(gradient[i]);
+    variable[i] = __float2half(__half2float(variable[i]) - __half2float(learning_rate) * update);
+  }
+}
+
+template <typename T, typename S, typename G>
+void ApplyAddSign(const size_t size,
+                  T *variable,
+                  T *accumulation,
+                  const S learning_rate,
+                  const S alpha,
+                  const S sign_decay,
+                  const S beta,
+                  const G *gradient,
+                  const uint32_t &device_id,
+                  cudaStream_t cuda_stream) {
+  ApplyAddSignKernel<<<CUDA_BLOCKS(device_id, size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
+    size, variable, accumulation, learning_rate, alpha, sign_decay, beta, gradient);
+}
+
+template CUDA_LIB_EXPORT void ApplyAddSign<double, double, double>(const size_t size,
+                                                                   double *variable,
+                                                                   double *accumulation,
+                                                                   const double learning_rate,
+                                                                   const double alpha,
+                                                                   const double sign_decay,
+                                                                   const double beta,
+                                                                   const double *gradient,
+                                                                   const uint32_t &device_id,
+                                                                   cudaStream_t cuda_stream);
+
+
+template CUDA_LIB_EXPORT void ApplyAddSign<float, float, float>(const size_t size,
+                                                                float *variable,
+                                                                float *accumulation,
+                                                                const float learning_rate,
+                                                                const float alpha,
+                                                                const float sign_decay,
+                                                                const float beta,
+                                                                const float *gradient,
+                                                                const uint32_t &device_id,
+                                                                cudaStream_t cuda_stream);
+
+template CUDA_LIB_EXPORT void ApplyAddSign<float, float, half>(const size_t size,
+                                                               float *variable,
+                                                               float *accumulation,
+                                                               const float learning_rate,
+                                                               const float alpha,
+                                                               const float sign_decay,
+                                                               const float beta,
+                                                               const half *gradient,
+                                                               const uint32_t &device_id,
+                                                               cudaStream_t cuda_stream);
+
+template CUDA_LIB_EXPORT void ApplyAddSign<float, half, float>(const size_t size,
+                                                               float *variable,
+                                                               float *accumulation,
+                                                               const half learning_rate,
+                                                               const half alpha,
+                                                               const half sign_decay,
+                                                               const half beta,
+                                                               const float *gradient,
+                                                               const uint32_t &device_id,
+                                                               cudaStream_t cuda_stream);
+
+template CUDA_LIB_EXPORT void ApplyAddSign<float, half, half>(const size_t size,
+                                                              float *variable,
+                                                              float *accumulation,
+                                                              const half learning_rate,
+                                                              const half alpha,
+                                                              const half sign_decay,
+                                                              const half beta,
+                                                              const half *gradient,
+                                                              const uint32_t &device_id,
+                                                              cudaStream_t cuda_stream);
+
+template CUDA_LIB_EXPORT void ApplyAddSign<half, half, half>(const size_t size,
+                                                             half *variable,
+                                                             half *accumulation,
+                                                             const half learning_rate,
+                                                             const half alpha,
+                                                             const half sign_decay,
+                                                             const half beta,
+                                                             const half *gradient,
+                                                             const uint32_t &device_id,
+                                                             cudaStream_t cuda_stream);
+
+template CUDA_LIB_EXPORT void ApplyAddSign<half, float, half>(const size_t size,
+                                                              half *variable,
+                                                              half *accumulation,
+                                                              const float learning_rate,
+                                                              const float alpha,
+                                                              const float sign_decay,
+                                                              const float beta,
+                                                              const half *gradient,
+                                                              const uint32_t &device_id,
+                                                              cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/apply_add_sign_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/apply_add_sign_impl.cuh
new file mode 100644
index 00000000000..a237f2e9f4a
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/apply_add_sign_impl.cuh
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_APPLY_ADD_SIGN_IMPL_CUH_
+#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_APPLY_ADD_SIGN_IMPL_CUH_
+#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
+template <typename T, typename S, typename G>
+CUDA_LIB_EXPORT void ApplyAddSign(const size_t size,
+                                  T *variable,
+                                  T *accumulation,
+                                  const S learning_rate,
+                                  const S alpha,
+                                  const S sign_decay,
+                                  const S beta,
+                                  const G *gradient,
+                                  const uint32_t &device_id,
+                                  cudaStream_t stream);
+
+#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_APPLY_ADD_SIGN_IMPL_CUH_
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/apply_add_sign_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/apply_add_sign_gpu_kernel.cc
new file mode 100644
index 00000000000..56fc0ac2b6b
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/apply_add_sign_gpu_kernel.cc
@@ -0,0 +1,226 @@
+/**
+ * Copyright 2020-2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "plugin/device/gpu/kernel/nn/apply_add_sign_gpu_kernel.h"
+#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/apply_add_sign_impl.cuh"
+#include "abstract/utils.h"
+#include "kernel/common_utils.h"
+#include "include/curand.h"
+#include "mindspore/core/ops/apply_add_sign.h"
+
+namespace mindspore {
+namespace kernel {
+bool ApplyAddSignGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+                                    const std::vector<KernelTensorPtr> &outputs) {
+  auto kernel_ptr_ = std::dynamic_pointer_cast<ops::ApplyAddSign>(base_operator);
+  kernel_name_ = kernel_ptr_->name();
+  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
+  auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
+  if (!is_match) {
+    MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel type: " << kernel_attr;
+    return false;
+  }
+  kernel_func_ = func_list_[index].second;
+  t_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).first);
+  s_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex2).first);
+  g_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex6).first);
+  return true;
+}
+
+int ApplyAddSignGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+                                     const std::vector<KernelTensorPtr> &outputs,
+                                     const std::map<uint32_t, tensor::TensorPtr> &) {
+  for (const auto &input : inputs) {
+    // If any input shape contains -1, the shape is dynamic and not yet known, so report it and do nothing.
+    auto input_shape = input->GetShapeVector();
+    if (!IsValidShape(input_shape)) {
+      return KRET_UNKNOWN_SHAPE;
+    }
+  }
+  ResetResource();
+  std::vector<int64_t> variable_shape_ = std::vector<int64_t>(inputs.at(kIndex0)->GetDeviceShapeAdaptively().begin(),
+                                                              inputs.at(kIndex0)->GetDeviceShapeAdaptively().end());
+  std::vector<int64_t> learning_rate_shape_ = std::vector<int64_t>(
+    inputs.at(kIndex2)->GetDeviceShapeAdaptively().begin(), inputs.at(kIndex2)->GetDeviceShapeAdaptively().end());
+  std::vector<int64_t> gradient_shape_ = std::vector<int64_t>(inputs.at(kIndex6)->GetDeviceShapeAdaptively().begin(),
+                                                              inputs.at(kIndex6)->GetDeviceShapeAdaptively().end());
+  t_elements_ = std::accumulate(variable_shape_.begin(), variable_shape_.end(), 1, std::multiplies<int64_t>());
+  s_elements_ = std::accumulate(learning_rate_shape_.begin(), learning_rate_shape_.end(), 1,
+                                std::multiplies<int64_t>());
+  g_elements_ = std::accumulate(gradient_shape_.begin(), gradient_shape_.end(), 1, std::multiplies<int64_t>());
+  is_null_input_ = (t_elements_ == 0 || s_elements_ == 0 || g_elements_ == 0);
+  if (is_null_input_) {
+    return 0;
+  }
+  size_t variable_size_ = t_elements_ * t_size_;
+  size_t accumulation_size_ = t_elements_ * t_size_;
+  size_t learning_rate_size_ = s_elements_ * s_size_;
+  size_t alpha_size_ = s_elements_ * s_size_;
+  size_t sign_decay_size_ = s_elements_ * s_size_;
+  size_t beta_size_ = s_elements_ * s_size_;
+  size_t gradient_size_ = g_elements_ * g_size_;
+  input_size_list_.emplace_back(variable_size_);
+  input_size_list_.emplace_back(accumulation_size_);
+  input_size_list_.emplace_back(learning_rate_size_);
+  input_size_list_.emplace_back(alpha_size_);
+  input_size_list_.emplace_back(sign_decay_size_);
+  input_size_list_.emplace_back(beta_size_);
+  input_size_list_.emplace_back(gradient_size_);
+  output_size_list_.emplace_back(variable_size_);
+  output_size_list_.emplace_back(accumulation_size_);
+  return KRET_OK;
+}
+
+void ApplyAddSignGpuKernelMod::ResetResource() noexcept {
+  t_elements_ = 0;
+  s_elements_ = 0;
+  g_elements_ = 0;
+  is_null_input_ = false;
+  input_size_list_.clear();
+  output_size_list_.clear();
+}
+
+template <typename T, typename S, typename G>
+bool ApplyAddSignGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
+                                            const std::vector<AddressPtr> &workspace,
+                                            const std::vector<AddressPtr> &outputs) {
+  T *variable = GetDeviceAddress<T>(inputs, 0);
+  T *accumulation = GetDeviceAddress<T>(inputs, 1);
+  S *learning_rate = GetDeviceAddress<S>(inputs, 2);
+  S *alpha = GetDeviceAddress<S>(inputs, 3);
+  S *sign_decay = GetDeviceAddress<S>(inputs, 4);
+  S *beta = GetDeviceAddress<S>(inputs, 5);
+  G *gradient = GetDeviceAddress<G>(inputs, 6);
+  T *variable_out = GetDeviceAddress<T>(outputs, 0);
+  T *accumulation_out = GetDeviceAddress<T>(outputs, 1);
+  S learning_rate_0 = 0.;
+  S alpha_0 = 0.;
+  S sign_decay_0 = 0.;
+  S beta_0 = 0.;
+  // The four scalar inputs live in device memory; copy them to the host so they can be passed by value.
+  CHECK_CUDA_RET_WITH_ERROR_NOTRACE(
+    cudaMemcpyAsync(&learning_rate_0, learning_rate, s_elements_ * s_size_, cudaMemcpyDeviceToHost,
+                    reinterpret_cast<cudaStream_t>(stream_ptr_)),
+    "cudaMemcpy learning_rate failed");
+  CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpyAsync(&alpha_0, alpha, s_elements_ * s_size_, cudaMemcpyDeviceToHost,
+                                                    reinterpret_cast<cudaStream_t>(stream_ptr_)),
+                                    "cudaMemcpy alpha failed");
+  CHECK_CUDA_RET_WITH_ERROR_NOTRACE(
+    cudaMemcpyAsync(&sign_decay_0, sign_decay, s_elements_ * s_size_, cudaMemcpyDeviceToHost,
+                    reinterpret_cast<cudaStream_t>(stream_ptr_)),
+    "cudaMemcpy sign_decay failed");
+  CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpyAsync(&beta_0, beta, s_elements_ * s_size_, cudaMemcpyDeviceToHost,
+                                                    reinterpret_cast<cudaStream_t>(stream_ptr_)),
+                                    "cudaMemcpy beta failed");
+  ApplyAddSign(t_elements_, variable, accumulation,
+               learning_rate_0, alpha_0, sign_decay_0, beta_0, gradient,
+               device_id_, reinterpret_cast<cudaStream_t>(stream_ptr_));
+  CHECK_CUDA_RET_WITH_ERROR_NOTRACE(
+    cudaMemcpyAsync(variable_out, variable, outputs.at(kIndex0)->size, cudaMemcpyDeviceToDevice,
+                    reinterpret_cast<cudaStream_t>(stream_ptr_)),
+    "cudaMemcpyAsync output failed");
+  CHECK_CUDA_RET_WITH_ERROR_NOTRACE(
+    cudaMemcpyAsync(accumulation_out, accumulation, outputs.at(kIndex1)->size, cudaMemcpyDeviceToDevice,
+                    reinterpret_cast<cudaStream_t>(stream_ptr_)),
+    "cudaMemcpyAsync output failed");
+  return true;
+}
+
+std::vector<std::pair<KernelAttr, ApplyAddSignGpuKernelMod::ApplyAddSignFunc>> ApplyAddSignGpuKernelMod::func_list_ = {
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeFloat64)
+     .AddInputAttr(kNumberTypeFloat64)
+     .AddInputAttr(kNumberTypeFloat64)
+     .AddInputAttr(kNumberTypeFloat64)
+     .AddInputAttr(kNumberTypeFloat64)
+     .AddInputAttr(kNumberTypeFloat64)
+     .AddInputAttr(kNumberTypeFloat64)
+     .AddOutputAttr(kNumberTypeFloat64)
+     .AddOutputAttr(kNumberTypeFloat64),
+   &ApplyAddSignGpuKernelMod::LaunchKernel<double, double, double>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeFloat32)
+     .AddInputAttr(kNumberTypeFloat32)
+     .AddInputAttr(kNumberTypeFloat32)
+     .AddInputAttr(kNumberTypeFloat32)
+     .AddInputAttr(kNumberTypeFloat32)
+     .AddInputAttr(kNumberTypeFloat32)
+     .AddInputAttr(kNumberTypeFloat32)
+     .AddOutputAttr(kNumberTypeFloat32)
+     .AddOutputAttr(kNumberTypeFloat32),
+   &ApplyAddSignGpuKernelMod::LaunchKernel<float, float, float>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeFloat16)
+     .AddInputAttr(kNumberTypeFloat16)
+     .AddInputAttr(kNumberTypeFloat16)
+     .AddInputAttr(kNumberTypeFloat16)
+     .AddInputAttr(kNumberTypeFloat16)
+     .AddInputAttr(kNumberTypeFloat16)
+     .AddInputAttr(kNumberTypeFloat16)
+     .AddOutputAttr(kNumberTypeFloat16)
+     .AddOutputAttr(kNumberTypeFloat16),
+   &ApplyAddSignGpuKernelMod::LaunchKernel<half, half, half>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeFloat16)
+     .AddInputAttr(kNumberTypeFloat16)
+     .AddInputAttr(kNumberTypeFloat32)
+     .AddInputAttr(kNumberTypeFloat32)
+     .AddInputAttr(kNumberTypeFloat32)
+     .AddInputAttr(kNumberTypeFloat32)
+     .AddInputAttr(kNumberTypeFloat16)
+     .AddOutputAttr(kNumberTypeFloat16)
+     .AddOutputAttr(kNumberTypeFloat16),
+   &ApplyAddSignGpuKernelMod::LaunchKernel<half, float, half>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeFloat32)
+     .AddInputAttr(kNumberTypeFloat32)
+     .AddInputAttr(kNumberTypeFloat32)
+     .AddInputAttr(kNumberTypeFloat32)
+     .AddInputAttr(kNumberTypeFloat32)
+     .AddInputAttr(kNumberTypeFloat32)
+     .AddInputAttr(kNumberTypeFloat16)
+     .AddOutputAttr(kNumberTypeFloat32)
+     .AddOutputAttr(kNumberTypeFloat32),
+   &ApplyAddSignGpuKernelMod::LaunchKernel<float, float, half>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeFloat32)
+     .AddInputAttr(kNumberTypeFloat32)
+     .AddInputAttr(kNumberTypeFloat16)
+     .AddInputAttr(kNumberTypeFloat16)
+     .AddInputAttr(kNumberTypeFloat16)
+     .AddInputAttr(kNumberTypeFloat16)
+     .AddInputAttr(kNumberTypeFloat32)
+     .AddOutputAttr(kNumberTypeFloat32)
+     .AddOutputAttr(kNumberTypeFloat32),
+   &ApplyAddSignGpuKernelMod::LaunchKernel<float, half, float>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeFloat32)
+     .AddInputAttr(kNumberTypeFloat32)
+     .AddInputAttr(kNumberTypeFloat16)
+     .AddInputAttr(kNumberTypeFloat16)
+     .AddInputAttr(kNumberTypeFloat16)
+     .AddInputAttr(kNumberTypeFloat16)
+     .AddInputAttr(kNumberTypeFloat16)
+     .AddOutputAttr(kNumberTypeFloat32)
+     .AddOutputAttr(kNumberTypeFloat32),
+   &ApplyAddSignGpuKernelMod::LaunchKernel<float, half, half>}};
+
+std::vector<KernelAttr> ApplyAddSignGpuKernelMod::GetOpSupport() {
+  static std::vector<KernelAttr> support_list;
+  (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
+                       [](const std::pair<KernelAttr, ApplyAddSignFunc> &pair) { return pair.first; });
support_list; +} +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, ApplyAddSign, ApplyAddSignGpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/apply_add_sign_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/apply_add_sign_gpu_kernel.h new file mode 100644 index 00000000000..bbc5956af4c --- /dev/null +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/apply_add_sign_gpu_kernel.h @@ -0,0 +1,84 @@ +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_APPLY_ADD_SIGN_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_APPLY_ADD_SIGN_GPU_KERNEL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "mindspore/core/ops/apply_add_sign.h" +#include "abstract/utils.h" +#include "plugin/device/gpu/kernel/gpu_kernel.h" +#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" +#include "plugin/factory/ms_factory.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/apply_add_sign_impl.cuh" + +namespace mindspore { +namespace kernel { +class ApplyAddSignGpuKernelMod : public NativeGpuKernelMod { + public: + ApplyAddSignGpuKernelMod() { ResetResource(); } + ~ApplyAddSignGpuKernelMod() override = default; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + if (is_null_input_) { + return true; + } + stream_ptr_ = stream_ptr; + return kernel_func_(this, inputs, workspace, outputs); + } + + bool Init(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs) override; + + int Resize(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs, const std::map &) override; + + std::vector GetOpSupport() override; + + void ResetResource() noexcept; + + private: + template + bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs); + using ApplyAddSignFunc = std::function &, + const std::vector &, const std::vector &)>; + + private: + size_t t_size_{1}; + size_t s_size_{1}; + size_t g_size_{1}; + size_t t_elements_; + size_t s_elements_; + size_t g_elements_; + bool is_null_input_{false}; + void *stream_ptr_{nullptr}; + ApplyAddSignFunc kernel_func_{}; + std::optional is_input_dynamic_shape_{}; + static std::vector> func_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_APPLY_ADD_SIGN_GPU_KERNEL_H_ diff --git a/mindspore/core/ops/apply_add_sign.cc b/mindspore/core/ops/apply_add_sign.cc index 7ba59f0ae36..7e5d66c9b29 100644 --- a/mindspore/core/ops/apply_add_sign.cc +++ b/mindspore/core/ops/apply_add_sign.cc @@ -100,7 +100,7 @@ TuplePtr ApplyAddSignInferType(const PrimitivePtr &prim, const std::vectorBuildType(); auto beta_type = input_args[kInputIndex5]->BuildType(); auto grad_type = 
-  const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
+  const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
   std::map<std::string, TypePtr> args;
   (void)args.insert(std::make_pair("var_type", var_type));
   (void)args.insert(std::make_pair("m_type", m_type));
diff --git a/mindspore/core/ops/core_ops.h b/mindspore/core/ops/core_ops.h
index 0679523e592..b57e290e5ca 100644
--- a/mindspore/core/ops/core_ops.h
+++ b/mindspore/core/ops/core_ops.h
@@ -183,6 +183,7 @@ constexpr auto kSlice = "Slice";
 constexpr auto kAffineGrid = "AffineGrid";
 
 // NN
+constexpr auto kApplyAddSign = "ApplyAddSign";
 constexpr auto kAdaptiveMaxPool3D = "AdaptiveMaxPool3D";
 constexpr auto kFractionalMaxPool3DWithFixedKsize = "FractionalMaxPool3DWithFixedKsize";
 constexpr auto kFractionalMaxPool3DGradWithFixedKsize = "FractionalMaxPool3DGradWithFixedKsize";
diff --git a/mindspore/python/mindspore/ops/operations/nn_ops.py b/mindspore/python/mindspore/ops/operations/nn_ops.py
index dbdc4e117d7..beb3fb4b0cb 100644
--- a/mindspore/python/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/python/mindspore/ops/operations/nn_ops.py
@@ -6180,7 +6180,7 @@ class ApplyAddSign(Primitive):
         RuntimeError: If the data type of `var`, `accum` and `grad` conversion of Parameter is not supported.
 
     Supported Platforms:
-        ``Ascend``
+        ``Ascend`` ``GPU``
 
     Examples:
         >>> class Net(nn.Cell):
diff --git a/tests/st/ops/gpu/test_apply_add_sign_op.py b/tests/st/ops/gpu/test_apply_add_sign_op.py
new file mode 100644
index 00000000000..06c12a793ab
--- /dev/null
+++ b/tests/st/ops/gpu/test_apply_add_sign_op.py
@@ -0,0 +1,65 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor, Parameter
+from mindspore.ops import operations as P
+import mindspore.common.dtype as mstype
+
+
+class Net(nn.Cell):
+    def __init__(self, var_np, accum_np):
+        super(Net, self).__init__()
+        self.apply_addsign = P.ApplyAddSign()
+        self.var = Parameter(Tensor(var_np), name="var")
+        self.accum = Parameter(Tensor(accum_np), name="m")
+
+    def construct(self, lr, alpha, sign_decay, beta, grad):
+        z = self.apply_addsign(self.var, self.accum, lr, alpha, sign_decay, beta, grad)
+        return z
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_apply_addsign_graph_float32():
+    """
+    Feature: ApplyAddSign gpu kernel.
+    Description: test the ApplyAddSign GPU kernel with float32 inputs in graph mode.
+    Expectation: results match the numpy benchmark within 1e-6.
+ """ + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + var_np = np.array([[0.6, 0.4], [0.1, 0.5]]).astype(np.float32) + accum_np = np.array([[0.6, 0.5], [0.2, 0.6]]).astype(np.float32) + grident_np = np.array([[0.3, 0.7], [0.1, 0.8]]).astype(np.float32) + expect_accum_np = 0.9 * accum_np + (1.0 - 0.9) * grident_np + expect_update = (1.0 + 0.99 * np.sign(grident_np) * np.sign(expect_accum_np)) * grident_np + expect_var_np = var_np - (0.001 * expect_update) + net = Net(var_np, accum_np) + lr = Tensor(0.001, mstype.float32) + alpha = Tensor(1.0, mstype.float32) + sign_decay = Tensor(0.99, mstype.float32) + beta = Tensor(0.9, mstype.float32) + grad = Tensor(grident_np) + out = net(lr, alpha, sign_decay, beta, grad) + res_var_mindspore = out[0].asnumpy() + res_accum_mindspore = out[1].asnumpy() + eps = np.array([1e-6 for i in range(4)]).reshape(2, 2) + assert np.all(expect_var_np - res_var_mindspore < eps) + assert np.all(expect_accum_np - res_accum_mindspore < eps)