!38689 [assistant][ops][I4ZZRR] New GPU operator implementation, including ApplyAdagradV2

Merge pull request !38689 from 康渊瑞/ApplyAdagradV2
This commit is contained in:
i-robot 2022-08-03 06:22:28 +00:00 committed by Gitee
commit 1e61164151
7 changed files with 731 additions and 14 deletions
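For reference, the update these kernels implement is the AdagradV2 rule: when update_slots is true the accumulator is first incremented by the squared gradient, then the variable is decremented by lr * grad / sqrt(accum + epsilon); when update_slots is false the accumulator is left as-is. The NumPy sketch below only illustrates that rule and is not part of this commit (the helper name adagrad_v2_reference is ours):

import numpy as np

def adagrad_v2_reference(var, accum, lr, grad, epsilon, update_slots=True):
    # Mirrors ApplyAdagradV2Kernel (update_slots=True) and ApplyAdagradV2Kernel_ (update_slots=False).
    if update_slots:
        accum = accum + grad * grad
    var = var - lr * grad / np.sqrt(accum + epsilon)
    return var, accum

var = np.array([[0.6, 0.4], [0.1, 0.5]], dtype=np.float32)
accum = np.array([[0.6, 0.5], [0.2, 0.6]], dtype=np.float32)
grad = np.array([[0.3, 0.7], [0.1, 0.8]], dtype=np.float32)
print(adagrad_v2_reference(var, accum, np.float32(0.001), grad, 1e-6))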

View File

@@ -0,0 +1,242 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/adagrad_v2_impl.cuh"
#include "include/cuda_fp16.h"
template <typename T>
__device__ __forceinline__ T SqrtFunc(T input) {
return sqrt(input);
}
template <>
__device__ __forceinline__ half SqrtFunc(half input) {
return hsqrt(input);
}
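// Computes one AdagradV2 step per element with a grid-stride loop: accumulate grad^2,
// then scale the update by lr / sqrt(accum + epsilon).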
template <typename T, typename S>
__global__ void ApplyAdagradV2Kernel(const size_t size, const float epsilon, T *variable, T *accumulation,
const S *learning_rate, const T *gradient) {
T grad = static_cast<T>(0);
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
grad = gradient[i];
accumulation[i] += grad * grad;
variable[i] -= learning_rate[0] * grad / (SqrtFunc(accumulation[i] + epsilon));
}
}
template <>
__global__ void ApplyAdagradV2Kernel(const size_t size, const float epsilon, half *variable, half *accumulation,
const half *learning_rate, const half *gradient) {
half grad = static_cast<half>(0);
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
grad = gradient[i];
accumulation[i] += grad * grad;
variable[i] -= learning_rate[0] * grad / (SqrtFunc(accumulation[i] + __float2half(epsilon)));
}
}
template <>
__global__ void ApplyAdagradV2Kernel(const size_t size, const float epsilon, half *variable, half *accumulation,
const float *learning_rate, const half *gradient) {
half grad = static_cast<half>(0);
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
grad = gradient[i];
accumulation[i] += grad * grad;
variable[i] -= __float2half(learning_rate[0]) * grad / (SqrtFunc(accumulation[i] + __float2half(epsilon)));
}
}
template <>
__global__ void ApplyAdagradV2Kernel(const size_t size, const float epsilon, half *variable, half *accumulation,
const double *learning_rate, const half *gradient) {
half grad = static_cast<half>(0);
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
grad = gradient[i];
accumulation[i] += grad * grad;
variable[i] -= __float2half(learning_rate[0]) * grad / (SqrtFunc(accumulation[i] + __float2half(epsilon)));
}
}
template <>
__global__ void ApplyAdagradV2Kernel(const size_t size, const float epsilon, double *variable, double *accumulation,
const half *learning_rate, const double *gradient) {
double grad = static_cast<double>(0);
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
grad = gradient[i];
accumulation[i] += grad * grad;
variable[i] -= __half2float(learning_rate[0]) * grad / (SqrtFunc(accumulation[i] + epsilon));
}
}
template <>
__global__ void ApplyAdagradV2Kernel(const size_t size, const float epsilon, double *variable, double *accumulation,
const float *learning_rate, const double *gradient) {
double grad = static_cast<double>(0);
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
grad = gradient[i];
accumulation[i] += grad * grad;
variable[i] -= learning_rate[0] * grad / (SqrtFunc(accumulation[i] + epsilon));
}
}
template <>
__global__ void ApplyAdagradV2Kernel(const size_t size, const float epsilon, float *variable, float *accumulation,
const half *learning_rate, const float *gradient) {
float grad = static_cast<float>(0);
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
grad = gradient[i];
accumulation[i] += grad * grad;
variable[i] -= __half2float(learning_rate[0]) * grad / (SqrtFunc(accumulation[i] + epsilon));
}
}
template <>
__global__ void ApplyAdagradV2Kernel(const size_t size, const float epsilon, float *variable, float *accumulation,
const double *learning_rate, const float *gradient) {
float grad = static_cast<float>(0);
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
grad = gradient[i];
accumulation[i] += grad * grad;
variable[i] -= learning_rate[0] * grad / (SqrtFunc(accumulation[i] + epsilon));
}
}
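// Variant launched when update_slots is false: the accumulator is only read, never updated.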
template <typename T, typename S>
__global__ void ApplyAdagradV2Kernel_(const size_t size, const float epsilon, T *variable, T *accumulation,
const S *learning_rate, const T *gradient) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
variable[i] -= learning_rate[0] * gradient[i] / (SqrtFunc(accumulation[i] + epsilon));
}
}
template <>
__global__ void ApplyAdagradV2Kernel_(const size_t size, const float epsilon, half *variable, half *accumulation,
const half *learning_rate, const half *gradient) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
variable[i] -= learning_rate[0] * gradient[i] / (SqrtFunc(accumulation[i] + __float2half(epsilon)));
}
}
template <>
__global__ void ApplyAdagradV2Kernel_(const size_t size, const float epsilon, half *variable, half *accumulation,
const float *learning_rate, const half *gradient) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
variable[i] -= __float2half(learning_rate[0]) * gradient[i] / (SqrtFunc(accumulation[i] + __float2half(epsilon)));
}
}
template <>
__global__ void ApplyAdagradV2Kernel_(const size_t size, const float epsilon, half *variable, half *accumulation,
const double *learning_rate, const half *gradient) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
variable[i] -= __float2half(learning_rate[0]) * gradient[i] / (SqrtFunc(accumulation[i] + __float2half(epsilon)));
}
}
template <>
__global__ void ApplyAdagradV2Kernel_(const size_t size, const float epsilon, double *variable, double *accumulation,
const half *learning_rate, const double *gradient) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
variable[i] -= __half2float(learning_rate[0]) * gradient[i] / (SqrtFunc(accumulation[i] + epsilon));
}
}
template <>
__global__ void ApplyAdagradV2Kernel_(const size_t size, const float epsilon, double *variable, double *accumulation,
const float *learning_rate, const double *gradient) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
variable[i] -= learning_rate[0] * gradient[i] / (SqrtFunc(accumulation[i] + epsilon));
}
}
template <>
__global__ void ApplyAdagradV2Kernel_(const size_t size, const float epsilon, float *variable, float *accumulation,
const half *learning_rate, const float *gradient) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
variable[i] -= __half2float(learning_rate[0]) * gradient[i] / (SqrtFunc(accumulation[i] + epsilon));
}
}
template <>
__global__ void ApplyAdagradV2Kernel_(const size_t size, const float epsilon, float *variable, float *accumulation,
const double *learning_rate, const float *gradient) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
variable[i] -= learning_rate[0] * gradient[i] / (SqrtFunc(accumulation[i] + epsilon));
}
}
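// Host-side entry point: selects the kernel variant according to update_slots and launches it on the caller's stream.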
template <typename T, typename S>
void ApplyAdagradV2(const size_t size, const float epsilon, const bool update_slots, T *variable, T *accumulation,
const S *learning_rate, const T *gradient, const uint32_t &device_id, cudaStream_t cuda_stream) {
if (update_slots) {
ApplyAdagradV2Kernel<<<CUDA_BLOCKS(device_id, size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
size, epsilon, variable, accumulation, learning_rate, gradient);
} else {
ApplyAdagradV2Kernel_<<<CUDA_BLOCKS(device_id, size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
size, epsilon, variable, accumulation, learning_rate, gradient);
}
}
template CUDA_LIB_EXPORT void ApplyAdagradV2<double, double>(const size_t size, const float epsilon,
const bool update_slots, double *variable,
double *accumulation, const double *learning_rate,
const double *gradient, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void ApplyAdagradV2<float, float>(const size_t size, const float epsilon,
const bool update_slots, float *variable,
float *accumulation, const float *learning_rate,
const float *gradient, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void ApplyAdagradV2<half, half>(const size_t size, const float epsilon,
const bool update_slots, half *variable, half *accumulation,
const half *learning_rate, const half *gradient,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void ApplyAdagradV2<float, half>(const size_t size, const float epsilon,
const bool update_slots, float *variable, float *accumulation,
const half *learning_rate, const float *gradient,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void ApplyAdagradV2<half, float>(const size_t size, const float epsilon,
const bool update_slots, half *variable, half *accumulation,
const float *learning_rate, const half *gradient,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void ApplyAdagradV2<half, double>(const size_t size, const float epsilon,
const bool update_slots, half *variable, half *accumulation,
const double *learning_rate, const half *gradient,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void ApplyAdagradV2<double, float>(const size_t size, const float epsilon,
const bool update_slots, double *variable,
double *accumulation, const float *learning_rate,
const double *gradient, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void ApplyAdagradV2<double, half>(const size_t size, const float epsilon,
const bool update_slots, double *variable,
double *accumulation, const half *learning_rate,
const double *gradient, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void ApplyAdagradV2<float, double>(const size_t size, const float epsilon,
const bool update_slots, float *variable,
float *accumulation, const double *learning_rate,
const float *gradient, const uint32_t &device_id,
cudaStream_t cuda_stream);

View File

@@ -0,0 +1,32 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ADAGRAD_V2_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ADAGRAD_V2_IMPL_CUH_
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
template <typename T, typename S>
CUDA_LIB_EXPORT void ApplyAdagradV2(const size_t size,
const float epsilon,
const bool update_slots,
T *variable,
T *accumulation,
const S *learning_rate,
const T *gradient,
const uint32_t &device_id,
cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ADAGRAD_V2_IMPL_CUH_

View File

@@ -0,0 +1,225 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <mindspore/core/abstract/utils.h>
#include <memory>
#include <utility>
#include <algorithm>
#include "abstract/utils.h"
#include "mindspore/core/ops/apply_adagrad_v2.h"
#include "plugin/device/gpu/kernel/nn/adagrad_v2_gpu_kernel.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/adagrad_v2_impl.cuh"
namespace mindspore {
namespace kernel {
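// Rebuilds the input/output size lists from the current device shapes; an empty var tensor marks the kernel as a no-op.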
void AdagradV2GpuKernelMod::InOutputResize(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
t_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).first);
s_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex2).first);
std::vector<int64_t> variable_shape_ = std::vector<int64_t>(inputs.at(kIndex0)->GetDeviceShapeAdaptively().begin(),
inputs.at(kIndex0)->GetDeviceShapeAdaptively().end());
std::vector<int64_t> accumulation_shape_ = std::vector<int64_t>(
inputs.at(kIndex1)->GetDeviceShapeAdaptively().begin(), inputs.at(kIndex1)->GetDeviceShapeAdaptively().end());
std::vector<int64_t> gradient_shape_ = std::vector<int64_t>(inputs.at(kIndex3)->GetDeviceShapeAdaptively().begin(),
inputs.at(kIndex3)->GetDeviceShapeAdaptively().end());
input_elements_ = std::accumulate(variable_shape_.begin(), variable_shape_.end(), 1, std::multiplies<int64_t>());
is_null_input_ = (input_elements_ == 0);
if (is_null_input_) {
input_size_list_.push_back(0);
input_size_list_.push_back(0);
input_size_list_.push_back(0);
input_size_list_.push_back(0);
output_size_list_.push_back(0);
output_size_list_.push_back(0);
return;
}
variable_size_ = t_size_;
accumulation_size_ = t_size_;
learning_rate_size_ = s_size_;
gradient_size_ = t_size_;
for (int64_t i = 0; i < static_cast<int64_t>(variable_shape_.size()); i++) {
variable_size_ *= variable_shape_[i];
}
for (int64_t i = 0; i < static_cast<int64_t>(accumulation_shape_.size()); i++) {
accumulation_size_ *= accumulation_shape_[i];
}
for (int64_t i = 0; i < static_cast<int64_t>(gradient_shape_.size()); i++) {
gradient_size_ *= gradient_shape_[i];
}
input_size_list_.push_back(variable_size_);
input_size_list_.push_back(accumulation_size_);
input_size_list_.push_back(learning_rate_size_);
input_size_list_.push_back(gradient_size_);
output_size_list_.push_back(variable_size_);
output_size_list_.push_back(accumulation_size_);
}
bool AdagradV2GpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
auto kernel_ptr_ = std::dynamic_pointer_cast<ops::ApplyAdagradV2>(base_operator);
kernel_name_ = kernel_ptr_->name();
epsilon_ = kernel_ptr_->get_epsilon();
update_slots_ = kernel_ptr_->get_update_slots();
constexpr int INPUT_NUM = 4;
if (inputs.size() != INPUT_NUM) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs should be 4, but got " << inputs.size();
}
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel type: " << kernel_attr;
return false;
}
kernel_func_ = func_list_[index].second;
InOutputResize(base_operator, inputs, outputs);
outputs_ = outputs;
return true;
}
int AdagradV2GpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
kernel_ptr_ = base_operator;
InOutputResize(base_operator, inputs, outputs);
outputs_ = outputs;
return KRET_OK;
}
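// Runs the CUDA implementation in place on var/accum, then copies the updated values into the kernel outputs.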
template <typename T, typename S>
bool AdagradV2GpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
T *variable = GetDeviceAddress<T>(inputs, kIndex0);
T *accumulation = GetDeviceAddress<T>(inputs, kIndex1);
S *learning_rate = GetDeviceAddress<S>(inputs, kIndex2);
T *gradient = GetDeviceAddress<T>(inputs, kIndex3);
T *variable_out = GetDeviceAddress<T>(outputs, kIndex0);
T *accumulation_out = GetDeviceAddress<T>(outputs, kIndex1);
ApplyAdagradV2(size_t(inputs[0]->size / sizeof(T)), epsilon_, update_slots_, variable, accumulation, learning_rate,
gradient, device_id_, reinterpret_cast<cudaStream_t>(stream_ptr_));
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(variable_out, variable, variable_size_, cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr_)),
"cudaMemcpyAsync output failed");
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(accumulation_out, accumulation, accumulation_size_, cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr_)),
"cudaMemcpyAsync output failed");
return true;
}
std::vector<KernelAttr> AdagradV2GpuKernelMod::GetOpSupport() {
static std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, ApplyAdagradV2Func> &pair) { return pair.first; });
return support_list;
}
std::vector<std::pair<KernelAttr, AdagradV2GpuKernelMod::ApplyAdagradV2Func>> AdagradV2GpuKernelMod::func_list_ = {
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
&AdagradV2GpuKernelMod::LaunchKernel<double, double>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
&AdagradV2GpuKernelMod::LaunchKernel<float, float>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
&AdagradV2GpuKernelMod::LaunchKernel<half, half>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
&AdagradV2GpuKernelMod::LaunchKernel<half, float>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
&AdagradV2GpuKernelMod::LaunchKernel<float, half>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
&AdagradV2GpuKernelMod::LaunchKernel<float, double>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
&AdagradV2GpuKernelMod::LaunchKernel<half, double>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
&AdagradV2GpuKernelMod::LaunchKernel<double, float>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
&AdagradV2GpuKernelMod::LaunchKernel<double, half>}};
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, ApplyAdagradV2, AdagradV2GpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@@ -0,0 +1,106 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_ADAGRAD_V2_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_ADAGRAD_V2_GPU_KERNEL_H_
#include <vector>
#include <string>
#include <functional>
#include <map>
#include <utility>
#include <memory>
#include <algorithm>
#include <iostream>
#include "mindspore/core/ops/apply_adagrad_v2.h"
#include "kernel/common_utils.h"
#include "include/curand.h"
#include "abstract/utils.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/adagrad_v2_impl.cuh"
// #include "plugin/device/gpu/kernel/cuda_impl/cuda_class/adagrad_v2_helper.h"
namespace mindspore {
namespace kernel {
class AdagradV2GpuKernelMod : public NativeGpuKernelMod {
public:
AdagradV2GpuKernelMod() = default;
~AdagradV2GpuKernelMod() override = default;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (is_null_input_) {
return true;
}
stream_ptr_ = stream_ptr;
return kernel_func_(this, inputs, workspace, outputs);
}
std::vector<KernelAttr> GetOpSupport() override;
void ResetResource() noexcept {
is_null_input_ = false;
t_size_ = DEFAULT_SIZE_;
s_size_ = DEFAULT_SIZE_;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
private:
template <typename T, typename S>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);
using ApplyAdagradV2Func =
std::function<bool(AdagradV2GpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
void InOutputResize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs);
private:
constexpr static int64_t DEFAULT_SIZE_ = 4;
float epsilon_;
bool update_slots_;
int64_t variable_size_{0};
int64_t accumulation_size_{0};
int64_t learning_rate_size_{0};
int64_t gradient_size_{0};
bool is_null_input_{false};
std::string kernel_name_{"ApplyAdagradV2"};
int64_t t_size_{4};
int64_t s_size_{4};
int64_t input_elements_;
BaseOperatorPtr kernel_ptr_{nullptr};
std::vector<KernelTensorPtr> outputs_ = {};
ApplyAdagradV2Func kernel_func_{};
void *stream_ptr_{nullptr};
static std::vector<std::pair<KernelAttr, ApplyAdagradV2Func>> func_list_;
};
} // namespace kernel
} // namespace mindspore
#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_ADAGRAD_V2_GPU_KERNEL_H_

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -18,6 +18,7 @@
#include <algorithm>
#include <set>
#include <utility>
#include "abstract/ops/primitive_infer_map.h"
#include "ops/op_utils.h"
@@ -37,9 +38,9 @@ abstract::TupleShapePtr ApplyAdagradV2InferShape(const PrimitivePtr &primitive,
auto grad_shape_ptr = grad_shape->cast<abstract::ShapePtr>();
// lr must be a scalar [Number, Tensor]
const int64_t kShapeSize_ = 1;
auto lr_shape_rank = SizeToLong(lr_shape.size());
(void)CheckAndConvertUtils::CheckInteger("lr's rank'", lr_shape_rank, kLessEqual, kShapeSize_, primitive->name());
if (lr_shape_rank == 1) {
auto lr_shape_size = lr_shape.size();
(void)CheckAndConvertUtils::CheckInteger("lr's rank'", lr_shape_size, kLessEqual, kShapeSize_, primitive->name());
if (lr_shape_size == 1) {
(void)CheckAndConvertUtils::CheckInteger("lr_shape[0]", lr_shape[0], kEqual, kShapeSize_, primitive->name());
}
// var, accum and grad must have the same shape
@@ -47,8 +48,8 @@ abstract::TupleShapePtr ApplyAdagradV2InferShape(const PrimitivePtr &primitive,
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{var_shape, accum_shape});
}
std::map<std::string, abstract::BaseShapePtr> same_shape_args_map;
(void)same_shape_args_map.insert(std::make_pair("accum", accum_shape));
(void)same_shape_args_map.insert(std::make_pair("grad", grad_shape));
same_shape_args_map.insert({"accum", accum_shape});
same_shape_args_map.insert({"grad", grad_shape});
for (auto &elem : same_shape_args_map) {
if (*elem.second != *var_shape) {
MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', evaluator arg '" << elem.first
@@ -64,21 +65,26 @@ TuplePtr ApplyAdagradV2InferType(const PrimitivePtr &prim, const std::vector<Abs
auto accum_type = input_args[kInputIndex1]->BuildType();
auto lr_type = input_args[kInputIndex2]->BuildType();
auto grad_type = input_args[kInputIndex3]->BuildType();
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
const std::set<TypePtr> valid_types = {kFloat};
// var, accum, grad must have the same type
std::map<std::string, TypePtr> args;
(void)args.insert(std::make_pair("var_type", var_type));
(void)args.insert(std::make_pair("accum_type", accum_type));
(void)args.insert(std::make_pair("grad_type", grad_type));
(void)args.insert({"var_type", var_type});
(void)args.insert({"accum_type", accum_type});
(void)args.insert({"grad_type", grad_type});
(void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim->name());
// lr must be a scalar
std::map<std::string, TypePtr> args_lr;
(void)args_lr.insert(std::make_pair("lr_type", lr_type));
(void)args_lr.insert({"lr_type", lr_type});
(void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_lr, valid_types, prim->name());
return std::make_shared<Tuple>(std::vector<TypePtr>{var_type, accum_type});
}
} // namespace
void ApplyAdagradV2::Init(const float epsilon, const bool update_slots) {
set_epsilon(epsilon);
set_update_slots(update_slots);
}
float ApplyAdagradV2::get_epsilon() const {
auto value_ptr = this->GetAttr(kEpsilon);
return GetValue<float>(value_ptr);

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -34,8 +34,7 @@ class MIND_API ApplyAdagradV2 : public BaseOperator {
public:
MIND_API_BASE_MEMBER(ApplyAdagradV2);
ApplyAdagradV2() : BaseOperator(kNameApplyAdagradV2) { InitIOName({"var", "accum", "lr", "grad"}, {"var", "accum"}); }
/// \brief Set epsilon, A small value (float) added for numerical stability.
void Init(float epsilon, bool update_slots = true);
void set_epsilon(const float epsilon);
/// \brief Get epsilon.
///

View File

@@ -0,0 +1,107 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
eps_f64 = np.array([1e-5 for i in range(4)]).reshape(2, 2)
eps_f32 = np.array([1e-4 for i in range(4)]).reshape(2, 2)
class Net(nn.Cell):
def __init__(self, var_np, accum_np, epsilon=1e-6, update_slots=True):
super(Net, self).__init__()
self.apply_adagrad_v2 = P.ApplyAdagradV2(epsilon=epsilon, update_slots=update_slots)
self.var = Parameter(Tensor(var_np), name="var")
self.accum = Parameter(Tensor(accum_np), name="accum")
def construct(self, lr, grad):
z = self.apply_adagrad_v2(self.var, self.accum, lr, grad)
return z
def main_test(var_np, accum_np, lr_np, grident_np, epsilon_np, update_slots):
lr = Tensor(lr_np)
grad = Tensor(grident_np)
# expect
if update_slots:
expect_accum_np = accum_np + grident_np * grident_np
else:
expect_accum_np = accum_np
expect_var_np = var_np - lr_np * grident_np / np.sqrt(expect_accum_np + epsilon_np)
net = Net(var_np, accum_np, epsilon_np, update_slots)
out = net(lr, grad)
res_var_mindspore = out[0].asnumpy()
res_accum_mindspore = out[1].asnumpy()
return (expect_var_np, res_var_mindspore), (expect_accum_np, res_accum_mindspore)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_apply_adagradv2_fff():
"""
Feature: ApplyAdagradV2 GPU operator.
Description: basic test with float32 var, accum, lr and grad (fff).
Expectation: results match the NumPy reference within float32 tolerance.
"""
var_np = np.array([[0.6, 0.4], [0.1, 0.5]]).astype(np.float32)
accum_np = np.array([[0.6, 0.5], [0.2, 0.6]]).astype(np.float32)
lr_np = np.array(0.001).astype(np.float32)
epsilon_np = 1e-6
update_slots = True
grident_np = np.array([[0.3, 0.7], [0.1, 0.8]]).astype(np.float32)
var, accum = main_test(var_np, accum_np, lr_np, grident_np, epsilon_np, update_slots)
assert np.all(abs(accum[0] - accum[1]) < eps_f32)
assert np.all(abs(var[0] - var[1]) < eps_f32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_apply_adagradv2_ddd():
"""
Feature: ApplyAdagradV2 GPU operator.
Description: basic test with float64 var, accum, lr and grad (ddd).
Expectation: results match the NumPy reference within float64 tolerance.
"""
var_np = np.array([[0.6, 0.4], [0.1, 0.5]]).astype(np.float64)
accum_np = np.array([[0.6, 0.5], [0.2, 0.6]]).astype(np.float64)
lr_np = np.array(0.001).astype(np.float64)
epsilon_np = 1e-6
update_slots = True
grident_np = np.array([[0.3, 0.7], [0.1, 0.8]]).astype(np.float64)
var, accum = main_test(var_np, accum_np, lr_np, grident_np, epsilon_np, update_slots)
assert np.all(abs(accum[0] - accum[1]) < eps_f64)
assert np.all(abs(var[0] - var[1]) < eps_f64)
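As a further illustration of the update_slots=False path handled by ApplyAdagradV2Kernel_ above (the accumulator is left untouched and only var changes), here is a minimal sketch modeled on the tests in this file; the test name below is ours and not part of the commit:

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_apply_adagradv2_no_slot_update_sketch():
    """
    Feature: ApplyAdagradV2 GPU operator.
    Description: sketch of the update_slots=False case with float32 inputs.
    Expectation: accum is returned unchanged and var matches the NumPy reference.
    """
    var_np = np.array([[0.6, 0.4], [0.1, 0.5]]).astype(np.float32)
    accum_np = np.array([[0.6, 0.5], [0.2, 0.6]]).astype(np.float32)
    lr_np = np.array(0.001).astype(np.float32)
    grident_np = np.array([[0.3, 0.7], [0.1, 0.8]]).astype(np.float32)
    var, accum = main_test(var_np, accum_np, lr_np, grident_np, 1e-6, False)
    # With update_slots=False the accumulator must come back unchanged.
    assert np.all(abs(accum[0] - accum[1]) < eps_f32)
    assert np.all(abs(var[0] - var[1]) < eps_f32)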