diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/adagrad_v2_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/adagrad_v2_impl.cu
new file mode 100644
index 00000000000..6e954fabab6
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/adagrad_v2_impl.cu
@@ -0,0 +1,242 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/adagrad_v2_impl.cuh"
+#include "include/cuda_fp16.h"
+
+template <typename T>
+__device__ __forceinline__ T SqrtFunc(T input) {
+  return sqrt(input);
+}
+
+template <>
+__device__ __forceinline__ half SqrtFunc(half input) {
+  return hsqrt(input);
+}
+
+template <typename T, typename S>
+__global__ void ApplyAdagradV2Kernel(const size_t size, const float epsilon, T *variable, T *accumulation,
+                                     const S *learning_rate, const T *gradient) {
+  T grad = static_cast<T>(0);
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
+    grad = gradient[i];
+    accumulation[i] += grad * grad;
+    variable[i] -= learning_rate[0] * grad / (SqrtFunc(accumulation[i] + epsilon));
+  }
+}
+
+template <>
+__global__ void ApplyAdagradV2Kernel(const size_t size, const float epsilon, half *variable, half *accumulation,
+                                     const half *learning_rate, const half *gradient) {
+  half grad = static_cast<half>(0);
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
+    grad = gradient[i];
+    accumulation[i] += grad * grad;
+    variable[i] -= learning_rate[0] * grad / (SqrtFunc(accumulation[i] + __float2half(epsilon)));
+  }
+}
+
+template <>
+__global__ void ApplyAdagradV2Kernel(const size_t size, const float epsilon, half *variable, half *accumulation,
+                                     const float *learning_rate, const half *gradient) {
+  half grad = static_cast<half>(0);
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
+    grad = gradient[i];
+    accumulation[i] += grad * grad;
+    variable[i] -= __float2half(learning_rate[0]) * grad / (SqrtFunc(accumulation[i] + __float2half(epsilon)));
+  }
+}
+
+template <>
+__global__ void ApplyAdagradV2Kernel(const size_t size, const float epsilon, half *variable, half *accumulation,
+                                     const double *learning_rate, const half *gradient) {
+  half grad = static_cast<half>(0);
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
+    grad = gradient[i];
+    accumulation[i] += grad * grad;
+    variable[i] -= __float2half(learning_rate[0]) * grad / (SqrtFunc(accumulation[i] + __float2half(epsilon)));
+  }
+}
+
+template <>
+__global__ void ApplyAdagradV2Kernel(const size_t size, const float epsilon, double *variable, double *accumulation,
+                                     const half *learning_rate, const double *gradient) {
+  double grad = static_cast<double>(0);
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
+    grad = gradient[i];
+    accumulation[i] += grad * grad;
+    variable[i] -= __half2float(learning_rate[0]) * grad / (SqrtFunc(accumulation[i] + epsilon));
+  }
+}
+
+template <>
+__global__ void ApplyAdagradV2Kernel(const size_t size, const float epsilon, double *variable, double *accumulation,
+                                     const float *learning_rate, const double *gradient) {
+  double grad = static_cast<double>(0);
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
+    grad = gradient[i];
+    accumulation[i] += grad * grad;
+    variable[i] -= learning_rate[0] * grad / (SqrtFunc(accumulation[i] + epsilon));
+  }
+}
+
+template <>
+__global__ void ApplyAdagradV2Kernel(const size_t size, const float epsilon, float *variable, float *accumulation,
+                                     const half *learning_rate, const float *gradient) {
+  float grad = static_cast<float>(0);
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
+    grad = gradient[i];
+    accumulation[i] += grad * grad;
+    variable[i] -= __half2float(learning_rate[0]) * grad / (SqrtFunc(accumulation[i] + epsilon));
+  }
+}
+
+template <>
+__global__ void ApplyAdagradV2Kernel(const size_t size, const float epsilon, float *variable, float *accumulation,
+                                     const double *learning_rate, const float *gradient) {
+  float grad = static_cast<float>(0);
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
+    grad = gradient[i];
+    accumulation[i] += grad * grad;
+    variable[i] -= learning_rate[0] * grad / (SqrtFunc(accumulation[i] + epsilon));
+  }
+}
+
+template <typename T, typename S>
+__global__ void ApplyAdagradV2Kernel_(const size_t size, const float epsilon, T *variable, T *accumulation,
+                                      const S *learning_rate, const T *gradient) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
+    variable[i] -= learning_rate[0] * gradient[i] / (SqrtFunc(accumulation[i] + epsilon));
+  }
+}
+
+template <>
+__global__ void ApplyAdagradV2Kernel_(const size_t size, const float epsilon, half *variable, half *accumulation,
+                                      const half *learning_rate, const half *gradient) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
+    variable[i] -= learning_rate[0] * gradient[i] / (SqrtFunc(accumulation[i] + __float2half(epsilon)));
+  }
+}
+
+template <>
+__global__ void ApplyAdagradV2Kernel_(const size_t size, const float epsilon, half *variable, half *accumulation,
+                                      const float *learning_rate, const half *gradient) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
+    variable[i] -= __float2half(learning_rate[0]) * gradient[i] / (SqrtFunc(accumulation[i] + __float2half(epsilon)));
+  }
+}
+
+template <>
+__global__ void ApplyAdagradV2Kernel_(const size_t size, const float epsilon, half *variable, half *accumulation,
+                                      const double *learning_rate, const half *gradient) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
+    variable[i] -= __float2half(learning_rate[0]) * gradient[i] / (SqrtFunc(accumulation[i] + __float2half(epsilon)));
+  }
+}
+
+template <>
+__global__ void ApplyAdagradV2Kernel_(const size_t size, const float epsilon, double *variable, double *accumulation,
+                                      const half *learning_rate, const double *gradient) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
+    variable[i] -= __half2float(learning_rate[0]) * gradient[i] / (SqrtFunc(accumulation[i] + epsilon));
+  }
+}
+
+template <>
+__global__ void ApplyAdagradV2Kernel_(const size_t size, const float epsilon, double *variable, double *accumulation,
+                                      const float *learning_rate, const double *gradient) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
+    variable[i] -= learning_rate[0] * gradient[i] / (SqrtFunc(accumulation[i] + epsilon));
+  }
+}
+
+template <>
+__global__ void ApplyAdagradV2Kernel_(const size_t size, const float epsilon, float *variable, float *accumulation,
+                                      const half *learning_rate, const float *gradient) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
+    variable[i] -= __half2float(learning_rate[0]) * gradient[i] / (SqrtFunc(accumulation[i] + epsilon));
+  }
+}
+
+template <>
+__global__ void ApplyAdagradV2Kernel_(const size_t size, const float epsilon, float *variable, float *accumulation,
+                                      const double *learning_rate, const float *gradient) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
+    variable[i] -= learning_rate[0] * gradient[i] / (SqrtFunc(accumulation[i] + epsilon));
+  }
+}
+
+template <typename T, typename S>
+void ApplyAdagradV2(const size_t size, const float epsilon, const bool update_slots, T *variable, T *accumulation,
+                    const S *learning_rate, const T *gradient, const uint32_t &device_id, cudaStream_t cuda_stream) {
+  if (update_slots) {
+    ApplyAdagradV2Kernel<<<CUDA_BLOCKS(device_id, size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
+      size, epsilon, variable, accumulation, learning_rate, gradient);
+  } else {
+    ApplyAdagradV2Kernel_<<<CUDA_BLOCKS(device_id, size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
+      size, epsilon, variable, accumulation, learning_rate, gradient);
+  }
+}
+
+template CUDA_LIB_EXPORT void ApplyAdagradV2<double, double>(const size_t size, const float epsilon,
+                                                             const bool update_slots, double *variable,
+                                                             double *accumulation, const double *learning_rate,
+                                                             const double *gradient, const uint32_t &device_id,
+                                                             cudaStream_t cuda_stream);
+
+template CUDA_LIB_EXPORT void ApplyAdagradV2<float, float>(const size_t size, const float epsilon,
+                                                           const bool update_slots, float *variable,
+                                                           float *accumulation, const float *learning_rate,
+                                                           const float *gradient, const uint32_t &device_id,
+                                                           cudaStream_t cuda_stream);
+
+template CUDA_LIB_EXPORT void ApplyAdagradV2<half, half>(const size_t size, const float epsilon,
+                                                          const bool update_slots, half *variable, half *accumulation,
+                                                          const half *learning_rate, const half *gradient,
+                                                          const uint32_t &device_id, cudaStream_t cuda_stream);
+
+template CUDA_LIB_EXPORT void ApplyAdagradV2<float, half>(const size_t size, const float epsilon,
+                                                          const bool update_slots, float *variable, float *accumulation,
+                                                          const half *learning_rate, const float *gradient,
+                                                          const uint32_t &device_id, cudaStream_t cuda_stream);
+
+template CUDA_LIB_EXPORT void ApplyAdagradV2<half, float>(const size_t size, const float epsilon,
+                                                          const bool update_slots, half *variable, half *accumulation,
+                                                          const float *learning_rate, const half *gradient,
+                                                          const uint32_t &device_id, cudaStream_t cuda_stream);
+
+template CUDA_LIB_EXPORT void ApplyAdagradV2<half, double>(const size_t size, const float epsilon,
+                                                           const bool update_slots, half *variable, half *accumulation,
+                                                           const double *learning_rate, const half *gradient,
+                                                           const uint32_t &device_id, cudaStream_t cuda_stream);
+
+template CUDA_LIB_EXPORT void ApplyAdagradV2<double, float>(const size_t size, const float epsilon,
+                                                             const bool update_slots, double *variable,
+                                                             double *accumulation, const float *learning_rate,
+                                                             const double *gradient, const uint32_t &device_id,
+                                                             cudaStream_t cuda_stream);
+
+template CUDA_LIB_EXPORT void ApplyAdagradV2<double, half>(const size_t size, const float epsilon,
+                                                            const bool update_slots, double *variable,
+                                                            double *accumulation, const half *learning_rate,
+                                                            const double *gradient, const uint32_t &device_id,
+                                                            cudaStream_t cuda_stream);
+
+template CUDA_LIB_EXPORT void ApplyAdagradV2<float, double>(const size_t size, const float epsilon,
+                                                             const bool update_slots, float *variable,
+                                                             float *accumulation, const double *learning_rate,
+                                                             const float *gradient, const uint32_t &device_id,
+                                                             cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/adagrad_v2_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/adagrad_v2_impl.cuh
new file mode 100644
index 00000000000..8463029464f
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/adagrad_v2_impl.cuh
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ADAGRAD_V2_IMPL_CUH_
+#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ADAGRAD_V2_IMPL_CUH_
+#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
+
+template <typename T, typename S>
+CUDA_LIB_EXPORT void ApplyAdagradV2(const size_t size,
+                                    const float epsilon,
+                                    const bool update_slots,
+                                    T *variable,
+                                    T *accumulation,
+                                    const S *learning_rate,
+                                    const T *gradient,
+                                    const uint32_t &device_id,
+                                    cudaStream_t cuda_stream);
+
+#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ADAGRAD_V2_IMPL_CUH_
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/adagrad_v2_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/adagrad_v2_gpu_kernel.cc
new file mode 100644
index 00000000000..9db53a675c3
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/adagrad_v2_gpu_kernel.cc
@@ -0,0 +1,225 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#include +#include +#include +#include +#include "abstract/utils.h" +#include "mindspore/core/ops/apply_adagrad_v2.h" +#include "plugin/device/gpu/kernel/nn/adagrad_v2_gpu_kernel.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/adagrad_v2_impl.cuh" + +namespace mindspore { +namespace kernel { +void AdagradV2GpuKernelMod::InOutputResize(const BaseOperatorPtr &base_operator, + const std::vector &inputs, + const std::vector &outputs) { + input_size_list_.clear(); + output_size_list_.clear(); + workspace_size_list_.clear(); + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + t_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).first); + s_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex2).first); + + std::vector variable_shape_ = std::vector(inputs.at(kIndex0)->GetDeviceShapeAdaptively().begin(), + inputs.at(kIndex0)->GetDeviceShapeAdaptively().end()); + std::vector accumulation_shape_ = std::vector( + inputs.at(kIndex1)->GetDeviceShapeAdaptively().begin(), inputs.at(kIndex1)->GetDeviceShapeAdaptively().end()); + std::vector gradient_shape_ = std::vector(inputs.at(kIndex3)->GetDeviceShapeAdaptively().begin(), + inputs.at(kIndex3)->GetDeviceShapeAdaptively().end()); + input_elements_ = std::accumulate(variable_shape_.begin(), variable_shape_.end(), 1, std::multiplies()); + + is_null_input_ = (input_elements_ == 0); + + if (is_null_input_) { + input_size_list_.push_back(0); + input_size_list_.push_back(0); + input_size_list_.push_back(0); + input_size_list_.push_back(0); + output_size_list_.push_back(0); + output_size_list_.push_back(0); + return; + } + + variable_size_ = t_size_; + accumulation_size_ = t_size_; + learning_rate_size_ = s_size_; + gradient_size_ = t_size_; + + for (int64_t i = 0; i < static_cast(variable_shape_.size()); i++) { + variable_size_ *= variable_shape_[i]; + } + for (int64_t i = 0; i < static_cast(accumulation_shape_.size()); i++) { + accumulation_size_ *= accumulation_shape_[i]; + } + for (int64_t i = 0; i < static_cast(gradient_shape_.size()); i++) { + gradient_size_ *= gradient_shape_[i]; + } + input_size_list_.push_back(variable_size_); + input_size_list_.push_back(accumulation_size_); + input_size_list_.push_back(learning_rate_size_); + input_size_list_.push_back(gradient_size_); + output_size_list_.push_back(variable_size_); + output_size_list_.push_back(accumulation_size_); +} + +bool AdagradV2GpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs) { + auto kernel_ptr_ = std::dynamic_pointer_cast(base_operator); + kernel_name_ = kernel_ptr_->name(); + epsilon_ = kernel_ptr_->get_epsilon(); + update_slots_ = kernel_ptr_->get_update_slots(); + constexpr int INPUT_NUM = 4; + if (inputs.size() != INPUT_NUM) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs should be 4, but got " << inputs.size(); + } + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + + if (!is_match) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "' dose not support this kernel type: " << kernel_attr; + return false; + } + kernel_func_ = func_list_[index].second; + InOutputResize(base_operator, inputs, outputs); + outputs_ = outputs; + return true; +} + +int AdagradV2GpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs, + const std::map &) { + kernel_ptr_ = base_operator; + InOutputResize(base_operator, inputs, 
outputs); + outputs_ = outputs; + return KRET_OK; +} + +template +bool AdagradV2GpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs) { + T *variable = GetDeviceAddress(inputs, kIndex0); + T *accumulation = GetDeviceAddress(inputs, kIndex1); + S *learning_rate = GetDeviceAddress(inputs, kIndex2); + T *gradient = GetDeviceAddress(inputs, kIndex3); + T *variable_out = GetDeviceAddress(outputs, kIndex0); + T *accumulation_out = GetDeviceAddress(outputs, kIndex1); + ApplyAdagradV2(size_t(inputs[0]->size / sizeof(T)), epsilon_, update_slots_, variable, accumulation, learning_rate, + gradient, device_id_, reinterpret_cast(stream_ptr_)); + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(variable_out, variable, variable_size_, cudaMemcpyDeviceToDevice, + reinterpret_cast(stream_ptr_)), + "cudaMemcpyAsync output failed"); + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(accumulation_out, accumulation, accumulation_size_, cudaMemcpyDeviceToDevice, + reinterpret_cast(stream_ptr_)), + "cudaMemcpyAsync output failed"); + return true; +} + +std::vector AdagradV2GpuKernelMod::GetOpSupport() { + static std::vector support_list; + + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; +} + +std::vector> AdagradV2GpuKernelMod::func_list_ = { + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat64), + &AdagradV2GpuKernelMod::LaunchKernel}, + + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + &AdagradV2GpuKernelMod::LaunchKernel}, + + {KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + &AdagradV2GpuKernelMod::LaunchKernel}, + + {KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + &AdagradV2GpuKernelMod::LaunchKernel}, + + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + &AdagradV2GpuKernelMod::LaunchKernel}, + + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + &AdagradV2GpuKernelMod::LaunchKernel}, + + {KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + &AdagradV2GpuKernelMod::LaunchKernel}, + + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat64) + 
.AddOutputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat64), + &AdagradV2GpuKernelMod::LaunchKernel}, + + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat64), + &AdagradV2GpuKernelMod::LaunchKernel}}; + +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, ApplyAdagradV2, AdagradV2GpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/adagrad_v2_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/adagrad_v2_gpu_kernel.h new file mode 100644 index 00000000000..f4a51aafe56 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/adagrad_v2_gpu_kernel.h @@ -0,0 +1,106 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_ADAGRAD_V2_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_ADAGRAD_V2_GPU_KERNEL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "mindspore/core/ops/apply_adagrad_v2.h" +#include "kernel/common_utils.h" +#include "include/curand.h" +#include "abstract/utils.h" +#include "plugin/device/gpu/kernel/gpu_kernel.h" +#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" +#include "plugin/factory/ms_factory.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/adagrad_v2_impl.cuh" +// #include "plugin/device/gpu/kernel/cuda_impl/cuda_class/adagrad_v2_helper.h" + +namespace mindspore { +namespace kernel { +class AdagradV2GpuKernelMod : public NativeGpuKernelMod { + public: + AdagradV2GpuKernelMod() = default; + ~AdagradV2GpuKernelMod() override = default; + + bool Init(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs) override; + + int Resize(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs, const std::map &) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + if (is_null_input_) { + return true; + } + stream_ptr_ = stream_ptr; + return kernel_func_(this, inputs, workspace, outputs); + } + + std::vector GetOpSupport() override; + + void ResetResource() noexcept { + is_null_input_ = false; + t_size_ = DEFAULT_SIZE_; + s_size_ = DEFAULT_SIZE_; + input_size_list_.clear(); + output_size_list_.clear(); + workspace_size_list_.clear(); + } + + private: + template + bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs); + using ApplyAdagradV2Func = + std::function &, + const std::vector &, const std::vector &)>; + void InOutputResize(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs); + + private: + constexpr static int64_t DEFAULT_SIZE_ = 4; + + float epsilon_; + bool 
update_slots_; + + int64_t variable_size_{0}; + int64_t accumulation_size_{0}; + int64_t learning_rate_size_{0}; + int64_t gradient_size_{0}; + bool is_null_input_{false}; + std::string kernel_name_{"ApplyAdagradV2"}; + + int64_t t_size_{4}; + int64_t s_size_{4}; + int64_t input_elements_; + BaseOperatorPtr kernel_ptr_{nullptr}; + std::vector outputs_ = {}; + + ApplyAdagradV2Func kernel_func_{}; + void *stream_ptr_{nullptr}; + static std::vector> func_list_; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_ADAGRAD_V2_GPU_KERNEL_H diff --git a/mindspore/core/ops/apply_adagrad_v2.cc b/mindspore/core/ops/apply_adagrad_v2.cc index 0ee6327759d..d34edabefb2 100644 --- a/mindspore/core/ops/apply_adagrad_v2.cc +++ b/mindspore/core/ops/apply_adagrad_v2.cc @@ -1,5 +1,5 @@ /** - * Copyright 2021 Huawei Technologies Co., Ltd + * Copyright 2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ #include #include +#include #include "abstract/ops/primitive_infer_map.h" #include "ops/op_utils.h" @@ -37,9 +38,9 @@ abstract::TupleShapePtr ApplyAdagradV2InferShape(const PrimitivePtr &primitive, auto grad_shape_ptr = grad_shape->cast(); // lr must be a scalar [Number, Tensor] const int64_t kShapeSize_ = 1; - auto lr_shape_rank = SizeToLong(lr_shape.size()); - (void)CheckAndConvertUtils::CheckInteger("lr's rank'", lr_shape_rank, kLessEqual, kShapeSize_, primitive->name()); - if (lr_shape_rank == 1) { + auto lr_shape_size = lr_shape.size(); + (void)CheckAndConvertUtils::CheckInteger("lr's rank'", lr_shape_size, kLessEqual, kShapeSize_, primitive->name()); + if (lr_shape_size == 1) { (void)CheckAndConvertUtils::CheckInteger("lr_shape[0]", lr_shape[0], kEqual, kShapeSize_, primitive->name()); } // var, accum and grad must have the same shape @@ -47,8 +48,8 @@ abstract::TupleShapePtr ApplyAdagradV2InferShape(const PrimitivePtr &primitive, return std::make_shared(std::vector{var_shape, accum_shape}); } std::map same_shape_args_map; - (void)same_shape_args_map.insert(std::make_pair("accum", accum_shape)); - (void)same_shape_args_map.insert(std::make_pair("grad", grad_shape)); + same_shape_args_map.insert({"accum", accum_shape}); + same_shape_args_map.insert({"grad", grad_shape}); for (auto &elem : same_shape_args_map) { if (*elem.second != *var_shape) { MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', evaluator arg '" << elem.first @@ -64,21 +65,26 @@ TuplePtr ApplyAdagradV2InferType(const PrimitivePtr &prim, const std::vectorBuildType(); auto lr_type = input_args[kInputIndex2]->BuildType(); auto grad_type = input_args[kInputIndex3]->BuildType(); - const std::set valid_types = {kFloat16, kFloat32}; + const std::set valid_types = {kFloat}; // var, accum, grad must have the same type std::map args; - (void)args.insert(std::make_pair("var_type", var_type)); - (void)args.insert(std::make_pair("accum_type", accum_type)); - (void)args.insert(std::make_pair("grad_type", grad_type)); + (void)args.insert({"var_type", var_type}); + (void)args.insert({"accum_type", accum_type}); + (void)args.insert({"grad_type", grad_type}); (void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim->name()); // lr mustr be a scalar std::map args_lr; - (void)args_lr.insert(std::make_pair("lr_type", lr_type)); + (void)args_lr.insert({"lr_type", lr_type}); (void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_lr, valid_types, prim->name()); return 
std::make_shared(std::vector{var_type, accum_type}); } } // namespace +void ApplyAdagradV2::Init(const float epsilon, const bool update_slots) { + set_epsilon(epsilon); + set_update_slots(update_slots); +} + float ApplyAdagradV2::get_epsilon() const { auto value_ptr = this->GetAttr(kEpsilon); return GetValue(value_ptr); diff --git a/mindspore/core/ops/apply_adagrad_v2.h b/mindspore/core/ops/apply_adagrad_v2.h index 8b3a3006662..ebb2f195417 100644 --- a/mindspore/core/ops/apply_adagrad_v2.h +++ b/mindspore/core/ops/apply_adagrad_v2.h @@ -1,5 +1,5 @@ /** - * Copyright 2021 Huawei Technologies Co., Ltd + * Copyright 2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -34,8 +34,7 @@ class MIND_API ApplyAdagradV2 : public BaseOperator { public: MIND_API_BASE_MEMBER(ApplyAdagradV2); ApplyAdagradV2() : BaseOperator(kNameApplyAdagradV2) { InitIOName({"var", "accum", "lr", "grad"}, {"var", "accum"}); } - - /// \brief Set epsilon, A small value (float) added for numerical stability. + void Init(float epsilon, bool update_slots = true); void set_epsilon(const float epsilon); /// \brief Get epsilon. /// diff --git a/tests/st/ops/gpu/test_adagrad_v2_op.py b/tests/st/ops/gpu/test_adagrad_v2_op.py new file mode 100644 index 00000000000..99fdacf773b --- /dev/null +++ b/tests/st/ops/gpu/test_adagrad_v2_op.py @@ -0,0 +1,107 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor, Parameter
+from mindspore.ops import operations as P
+
+context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+
+eps_f64 = np.array([1e-5 for i in range(4)]).reshape(2, 2)
+eps_f32 = np.array([1e-4 for i in range(4)]).reshape(2, 2)
+
+
+class Net(nn.Cell):
+    def __init__(self, var_np, accum_np, epsilon=1e-6, update_slots=True):
+        super(Net, self).__init__()
+        self.apply_adagrad_v2 = P.ApplyAdagradV2(epsilon=epsilon, update_slots=update_slots)
+        self.var = Parameter(Tensor(var_np), name="var")
+        self.accum = Parameter(Tensor(accum_np), name="accum")
+
+    def construct(self, lr, grad):
+        z = self.apply_adagrad_v2(self.var, self.accum, lr, grad)
+        return z
+
+
+def main_test(var_np, accum_np, lr_np, gradient_np, epsilon_np, update_slots):
+    lr = Tensor(lr_np)
+    grad = Tensor(gradient_np)
+
+    # expected values computed with numpy
+    if update_slots:
+        expect_accum_np = accum_np + gradient_np * gradient_np
+    else:
+        expect_accum_np = accum_np
+    expect_var_np = var_np - lr_np * gradient_np / np.sqrt(expect_accum_np + epsilon_np)
+
+    net = Net(var_np, accum_np, epsilon_np, update_slots)
+    out = net(lr, grad)
+    res_var_mindspore = out[0].asnumpy()
+    res_accum_mindspore = out[1].asnumpy()
+
+    return (expect_var_np, res_var_mindspore), (expect_accum_np, res_accum_mindspore)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_apply_adagradv2_fff():
+    """
+    Feature: ApplyAdagradV2 GPU operator.
+    Description: test ApplyAdagradV2 with float32 var, accum, lr and grad against a numpy reference.
+    Expectation: var and accum outputs match the numpy reference within float32 tolerance.
+    """
+    var_np = np.array([[0.6, 0.4], [0.1, 0.5]]).astype(np.float32)
+    accum_np = np.array([[0.6, 0.5], [0.2, 0.6]]).astype(np.float32)
+
+    lr_np = np.array(0.001).astype(np.float32)
+    epsilon_np = 1e-6
+
+    update_slots = True
+
+    gradient_np = np.array([[0.3, 0.7], [0.1, 0.8]]).astype(np.float32)
+
+    var, accum = main_test(var_np, accum_np, lr_np, gradient_np, epsilon_np, update_slots)
+
+    assert np.all(abs(accum[0] - accum[1]) < eps_f32)
+    assert np.all(abs(var[0] - var[1]) < eps_f32)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_apply_adagradv2_ddd():
+    """
+    Feature: ApplyAdagradV2 GPU operator.
+    Description: test ApplyAdagradV2 with float64 var, accum, lr and grad against a numpy reference.
+    Expectation: var and accum outputs match the numpy reference within float64 tolerance.
+    """
+
+    var_np = np.array([[0.6, 0.4], [0.1, 0.5]]).astype(np.float64)
+    accum_np = np.array([[0.6, 0.5], [0.2, 0.6]]).astype(np.float64)
+
+    lr_np = np.array(0.001).astype(np.float64)
+    epsilon_np = 1e-6
+    update_slots = True
+
+    gradient_np = np.array([[0.3, 0.7], [0.1, 0.8]]).astype(np.float64)
+
+    var, accum = main_test(var_np, accum_np, lr_np, gradient_np, epsilon_np, update_slots)
+    assert np.all(abs(accum[0] - accum[1]) < eps_f64)
+    assert np.all(abs(var[0] - var[1]) < eps_f64)
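For reference, a minimal NumPy sketch of the update rule that the CUDA kernels implement and the tests above check; the helper name is illustrative only and is not part of the patch:

import numpy as np

def adagrad_v2_reference(var, accum, lr, grad, epsilon=1e-6, update_slots=True):
    # ApplyAdagradV2: accum += grad^2 (only when update_slots is True), then
    # var -= lr * grad / sqrt(accum + epsilon).
    if update_slots:
        accum = accum + grad * grad
    var = var - lr * grad / np.sqrt(accum + epsilon)
    return var, accum

# Reproduces the expected values of test_apply_adagradv2_fff above.
var = np.array([[0.6, 0.4], [0.1, 0.5]], dtype=np.float32)
accum = np.array([[0.6, 0.5], [0.2, 0.6]], dtype=np.float32)
grad = np.array([[0.3, 0.7], [0.1, 0.8]], dtype=np.float32)
print(adagrad_v2_reference(var, accum, np.float32(0.001), grad))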