!38689 [assistant][ops][I4ZZRR] New GPU operator implementation, including ApplyAdagradV2
Merge pull request !38689 from 康渊瑞/ApplyAdagradV2
Commit: 1e61164151
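For reference, the update rule implemented by the new kernels (with update_slots=true; when it is false the accumulator is left unchanged) is:

$$\mathrm{accum} \leftarrow \mathrm{accum} + \mathrm{grad}^{2}, \qquad \mathrm{var} \leftarrow \mathrm{var} - \frac{lr \cdot \mathrm{grad}}{\sqrt{\mathrm{accum} + \epsilon}}$$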
@@ -0,0 +1,242 @@ (new file: CUDA kernel implementation for ApplyAdagradV2)
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/adagrad_v2_impl.cuh"
#include "include/cuda_fp16.h"

template <typename T>
__device__ __forceinline__ T SqrtFunc(T input) {
  return sqrt(input);
}

template <>
__device__ __forceinline__ half SqrtFunc(half input) {
  return hsqrt(input);
}

template <typename T, typename S>
__global__ void ApplyAdagradV2Kernel(const size_t size, const float epsilon, T *variable, T *accumulation,
                                     const S *learning_rate, const T *gradient) {
  T grad = static_cast<T>(0);
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
    grad = gradient[i];
    accumulation[i] += grad * grad;
    variable[i] -= learning_rate[0] * grad / (SqrtFunc(accumulation[i] + epsilon));
  }
}

template <>
__global__ void ApplyAdagradV2Kernel(const size_t size, const float epsilon, half *variable, half *accumulation,
                                     const half *learning_rate, const half *gradient) {
  half grad = static_cast<half>(0);
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
    grad = gradient[i];
    accumulation[i] += grad * grad;
    variable[i] -= learning_rate[0] * grad / (SqrtFunc(accumulation[i] + __float2half(epsilon)));
  }
}

template <>
__global__ void ApplyAdagradV2Kernel(const size_t size, const float epsilon, half *variable, half *accumulation,
                                     const float *learning_rate, const half *gradient) {
  half grad = static_cast<half>(0);
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
    grad = gradient[i];
    accumulation[i] += grad * grad;
    variable[i] -= __float2half(learning_rate[0]) * grad / (SqrtFunc(accumulation[i] + __float2half(epsilon)));
  }
}

template <>
__global__ void ApplyAdagradV2Kernel(const size_t size, const float epsilon, half *variable, half *accumulation,
                                     const double *learning_rate, const half *gradient) {
  half grad = static_cast<half>(0);
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
    grad = gradient[i];
    accumulation[i] += grad * grad;
    variable[i] -= __float2half(learning_rate[0]) * grad / (SqrtFunc(accumulation[i] + __float2half(epsilon)));
  }
}

template <>
__global__ void ApplyAdagradV2Kernel(const size_t size, const float epsilon, double *variable, double *accumulation,
                                     const half *learning_rate, const double *gradient) {
  double grad = static_cast<double>(0);
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
    grad = gradient[i];
    accumulation[i] += grad * grad;
    variable[i] -= __half2float(learning_rate[0]) * grad / (SqrtFunc(accumulation[i] + epsilon));
  }
}

template <>
__global__ void ApplyAdagradV2Kernel(const size_t size, const float epsilon, double *variable, double *accumulation,
                                     const float *learning_rate, const double *gradient) {
  double grad = static_cast<double>(0);
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
    grad = gradient[i];
    accumulation[i] += grad * grad;
    variable[i] -= learning_rate[0] * grad / (SqrtFunc(accumulation[i] + epsilon));
  }
}

template <>
__global__ void ApplyAdagradV2Kernel(const size_t size, const float epsilon, float *variable, float *accumulation,
                                     const half *learning_rate, const float *gradient) {
  float grad = static_cast<float>(0);
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
    grad = gradient[i];
    accumulation[i] += grad * grad;
    variable[i] -= __half2float(learning_rate[0]) * grad / (SqrtFunc(accumulation[i] + epsilon));
  }
}

template <>
__global__ void ApplyAdagradV2Kernel(const size_t size, const float epsilon, float *variable, float *accumulation,
                                     const double *learning_rate, const float *gradient) {
  float grad = static_cast<float>(0);
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
    grad = gradient[i];
    accumulation[i] += grad * grad;
    variable[i] -= learning_rate[0] * grad / (SqrtFunc(accumulation[i] + epsilon));
  }
}

template <typename T, typename S>
__global__ void ApplyAdagradV2Kernel_(const size_t size, const float epsilon, T *variable, T *accumulation,
                                      const S *learning_rate, const T *gradient) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
    variable[i] -= learning_rate[0] * gradient[i] / (SqrtFunc(accumulation[i] + epsilon));
  }
}

template <>
__global__ void ApplyAdagradV2Kernel_(const size_t size, const float epsilon, half *variable, half *accumulation,
                                      const half *learning_rate, const half *gradient) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
    variable[i] -= learning_rate[0] * gradient[i] / (SqrtFunc(accumulation[i] + __float2half(epsilon)));
  }
}

template <>
__global__ void ApplyAdagradV2Kernel_(const size_t size, const float epsilon, half *variable, half *accumulation,
                                      const float *learning_rate, const half *gradient) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
    variable[i] -= __float2half(learning_rate[0]) * gradient[i] / (SqrtFunc(accumulation[i] + __float2half(epsilon)));
  }
}

template <>
__global__ void ApplyAdagradV2Kernel_(const size_t size, const float epsilon, half *variable, half *accumulation,
                                      const double *learning_rate, const half *gradient) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
    variable[i] -= __float2half(learning_rate[0]) * gradient[i] / (SqrtFunc(accumulation[i] + __float2half(epsilon)));
  }
}

template <>
__global__ void ApplyAdagradV2Kernel_(const size_t size, const float epsilon, double *variable, double *accumulation,
                                      const half *learning_rate, const double *gradient) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
    variable[i] -= __half2float(learning_rate[0]) * gradient[i] / (SqrtFunc(accumulation[i] + epsilon));
  }
}

template <>
__global__ void ApplyAdagradV2Kernel_(const size_t size, const float epsilon, double *variable, double *accumulation,
                                      const float *learning_rate, const double *gradient) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
    variable[i] -= learning_rate[0] * gradient[i] / (SqrtFunc(accumulation[i] + epsilon));
  }
}

template <>
__global__ void ApplyAdagradV2Kernel_(const size_t size, const float epsilon, float *variable, float *accumulation,
                                      const half *learning_rate, const float *gradient) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
    variable[i] -= __half2float(learning_rate[0]) * gradient[i] / (SqrtFunc(accumulation[i] + epsilon));
  }
}

template <>
__global__ void ApplyAdagradV2Kernel_(const size_t size, const float epsilon, float *variable, float *accumulation,
                                      const double *learning_rate, const float *gradient) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
    variable[i] -= learning_rate[0] * gradient[i] / (SqrtFunc(accumulation[i] + epsilon));
  }
}

template <typename T, typename S>
void ApplyAdagradV2(const size_t size, const float epsilon, const bool update_slots, T *variable, T *accumulation,
                    const S *learning_rate, const T *gradient, const uint32_t &device_id, cudaStream_t cuda_stream) {
  if (update_slots) {
    ApplyAdagradV2Kernel<<<CUDA_BLOCKS(device_id, size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
      size, epsilon, variable, accumulation, learning_rate, gradient);
  } else {
    ApplyAdagradV2Kernel_<<<CUDA_BLOCKS(device_id, size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
      size, epsilon, variable, accumulation, learning_rate, gradient);
  }
}

template CUDA_LIB_EXPORT void ApplyAdagradV2<double, double>(const size_t size, const float epsilon,
                                                             const bool update_slots, double *variable,
                                                             double *accumulation, const double *learning_rate,
                                                             const double *gradient, const uint32_t &device_id,
                                                             cudaStream_t cuda_stream);

template CUDA_LIB_EXPORT void ApplyAdagradV2<float, float>(const size_t size, const float epsilon,
                                                           const bool update_slots, float *variable,
                                                           float *accumulation, const float *learning_rate,
                                                           const float *gradient, const uint32_t &device_id,
                                                           cudaStream_t cuda_stream);

template CUDA_LIB_EXPORT void ApplyAdagradV2<half, half>(const size_t size, const float epsilon,
                                                         const bool update_slots, half *variable, half *accumulation,
                                                         const half *learning_rate, const half *gradient,
                                                         const uint32_t &device_id, cudaStream_t cuda_stream);

template CUDA_LIB_EXPORT void ApplyAdagradV2<float, half>(const size_t size, const float epsilon,
                                                          const bool update_slots, float *variable, float *accumulation,
                                                          const half *learning_rate, const float *gradient,
                                                          const uint32_t &device_id, cudaStream_t cuda_stream);

template CUDA_LIB_EXPORT void ApplyAdagradV2<half, float>(const size_t size, const float epsilon,
                                                          const bool update_slots, half *variable, half *accumulation,
                                                          const float *learning_rate, const half *gradient,
                                                          const uint32_t &device_id, cudaStream_t cuda_stream);

template CUDA_LIB_EXPORT void ApplyAdagradV2<half, double>(const size_t size, const float epsilon,
                                                           const bool update_slots, half *variable, half *accumulation,
                                                           const double *learning_rate, const half *gradient,
                                                           const uint32_t &device_id, cudaStream_t cuda_stream);

template CUDA_LIB_EXPORT void ApplyAdagradV2<double, float>(const size_t size, const float epsilon,
                                                            const bool update_slots, double *variable,
                                                            double *accumulation, const float *learning_rate,
                                                            const double *gradient, const uint32_t &device_id,
                                                            cudaStream_t cuda_stream);

template CUDA_LIB_EXPORT void ApplyAdagradV2<double, half>(const size_t size, const float epsilon,
                                                           const bool update_slots, double *variable,
                                                           double *accumulation, const half *learning_rate,
                                                           const double *gradient, const uint32_t &device_id,
                                                           cudaStream_t cuda_stream);

template CUDA_LIB_EXPORT void ApplyAdagradV2<float, double>(const size_t size, const float epsilon,
                                                            const bool update_slots, float *variable,
                                                            float *accumulation, const double *learning_rate,
                                                            const float *gradient, const uint32_t &device_id,
                                                            cudaStream_t cuda_stream);
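As a reading aid (not part of the patch), a NumPy mirror of what one launch of ApplyAdagradV2 computes over the flattened buffers, covering both update_slots branches:

```python
import numpy as np


def apply_adagrad_v2_reference(var, accum, lr, grad, epsilon, update_slots=True):
    """NumPy sketch of the CUDA kernels above; updates var and accum in place."""
    if update_slots:
        accum += grad * grad  # ApplyAdagradV2Kernel: accumulate squared gradients
    # ApplyAdagradV2Kernel_ skips the accumulation update but applies the same step
    var -= lr * grad / np.sqrt(accum + epsilon)
    return var, accum
```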
@@ -0,0 +1,32 @@ (new file: plugin/device/gpu/kernel/cuda_impl/cuda_ops/adagrad_v2_impl.cuh)
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ADAGRAD_V2_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ADAGRAD_V2_IMPL_CUH_
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"

template <typename T, typename S>
CUDA_LIB_EXPORT void ApplyAdagradV2(const size_t size,
                                    const float epsilon,
                                    const bool update_slots,
                                    T *variable,
                                    T *accumulation,
                                    const S *learning_rate,
                                    const T *gradient,
                                    const uint32_t &device_id,
                                    cudaStream_t cuda_stream);

#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ADAGRAD_V2_IMPL_CUH_
@@ -0,0 +1,225 @@ (new file: GPU kernel mod implementing adagrad_v2_gpu_kernel.h)
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <mindspore/core/abstract/utils.h>
#include <memory>
#include <utility>
#include <algorithm>
#include "abstract/utils.h"
#include "mindspore/core/ops/apply_adagrad_v2.h"
#include "plugin/device/gpu/kernel/nn/adagrad_v2_gpu_kernel.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/adagrad_v2_impl.cuh"

namespace mindspore {
namespace kernel {
void AdagradV2GpuKernelMod::InOutputResize(const BaseOperatorPtr &base_operator,
                                           const std::vector<KernelTensorPtr> &inputs,
                                           const std::vector<KernelTensorPtr> &outputs) {
  input_size_list_.clear();
  output_size_list_.clear();
  workspace_size_list_.clear();
  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
  t_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).first);
  s_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex2).first);

  std::vector<int64_t> variable_shape_ = std::vector<int64_t>(inputs.at(kIndex0)->GetDeviceShapeAdaptively().begin(),
                                                              inputs.at(kIndex0)->GetDeviceShapeAdaptively().end());
  std::vector<int64_t> accumulation_shape_ = std::vector<int64_t>(
    inputs.at(kIndex1)->GetDeviceShapeAdaptively().begin(), inputs.at(kIndex1)->GetDeviceShapeAdaptively().end());
  std::vector<int64_t> gradient_shape_ = std::vector<int64_t>(inputs.at(kIndex3)->GetDeviceShapeAdaptively().begin(),
                                                              inputs.at(kIndex3)->GetDeviceShapeAdaptively().end());
  input_elements_ = std::accumulate(variable_shape_.begin(), variable_shape_.end(), 1, std::multiplies<int64_t>());

  is_null_input_ = (input_elements_ == 0);

  if (is_null_input_) {
    input_size_list_.push_back(0);
    input_size_list_.push_back(0);
    input_size_list_.push_back(0);
    input_size_list_.push_back(0);
    output_size_list_.push_back(0);
    output_size_list_.push_back(0);
    return;
  }

  variable_size_ = t_size_;
  accumulation_size_ = t_size_;
  learning_rate_size_ = s_size_;
  gradient_size_ = t_size_;

  for (int64_t i = 0; i < static_cast<int64_t>(variable_shape_.size()); i++) {
    variable_size_ *= variable_shape_[i];
  }
  for (int64_t i = 0; i < static_cast<int64_t>(accumulation_shape_.size()); i++) {
    accumulation_size_ *= accumulation_shape_[i];
  }
  for (int64_t i = 0; i < static_cast<int64_t>(gradient_shape_.size()); i++) {
    gradient_size_ *= gradient_shape_[i];
  }
  input_size_list_.push_back(variable_size_);
  input_size_list_.push_back(accumulation_size_);
  input_size_list_.push_back(learning_rate_size_);
  input_size_list_.push_back(gradient_size_);
  output_size_list_.push_back(variable_size_);
  output_size_list_.push_back(accumulation_size_);
}

bool AdagradV2GpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                 const std::vector<KernelTensorPtr> &outputs) {
  auto kernel_ptr_ = std::dynamic_pointer_cast<ops::ApplyAdagradV2>(base_operator);
  kernel_name_ = kernel_ptr_->name();
  epsilon_ = kernel_ptr_->get_epsilon();
  update_slots_ = kernel_ptr_->get_update_slots();
  constexpr int INPUT_NUM = 4;
  if (inputs.size() != INPUT_NUM) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs should be 4, but got " << inputs.size();
  }
  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
  auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());

  if (!is_match) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel type: " << kernel_attr;
    return false;
  }
  kernel_func_ = func_list_[index].second;
  InOutputResize(base_operator, inputs, outputs);
  outputs_ = outputs;
  return true;
}

int AdagradV2GpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                  const std::vector<KernelTensorPtr> &outputs,
                                  const std::map<uint32_t, tensor::TensorPtr> &) {
  kernel_ptr_ = base_operator;
  InOutputResize(base_operator, inputs, outputs);
  outputs_ = outputs;
  return KRET_OK;
}

template <typename T, typename S>
bool AdagradV2GpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
                                         const std::vector<AddressPtr> &workspace,
                                         const std::vector<AddressPtr> &outputs) {
  T *variable = GetDeviceAddress<T>(inputs, kIndex0);
  T *accumulation = GetDeviceAddress<T>(inputs, kIndex1);
  S *learning_rate = GetDeviceAddress<S>(inputs, kIndex2);
  T *gradient = GetDeviceAddress<T>(inputs, kIndex3);
  T *variable_out = GetDeviceAddress<T>(outputs, kIndex0);
  T *accumulation_out = GetDeviceAddress<T>(outputs, kIndex1);
  ApplyAdagradV2(size_t(inputs[0]->size / sizeof(T)), epsilon_, update_slots_, variable, accumulation, learning_rate,
                 gradient, device_id_, reinterpret_cast<cudaStream_t>(stream_ptr_));
  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(variable_out, variable, variable_size_, cudaMemcpyDeviceToDevice,
                                                     reinterpret_cast<cudaStream_t>(stream_ptr_)),
                                     "cudaMemcpyAsync output failed");
  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
    cudaMemcpyAsync(accumulation_out, accumulation, accumulation_size_, cudaMemcpyDeviceToDevice,
                    reinterpret_cast<cudaStream_t>(stream_ptr_)),
    "cudaMemcpyAsync output failed");
  return true;
}

std::vector<KernelAttr> AdagradV2GpuKernelMod::GetOpSupport() {
  static std::vector<KernelAttr> support_list;

  (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
                       [](const std::pair<KernelAttr, ApplyAdagradV2Func> &pair) { return pair.first; });
  return support_list;
}

std::vector<std::pair<KernelAttr, AdagradV2GpuKernelMod::ApplyAdagradV2Func>> AdagradV2GpuKernelMod::func_list_ = {
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat64)
     .AddInputAttr(kNumberTypeFloat64)
     .AddInputAttr(kNumberTypeFloat64)
     .AddInputAttr(kNumberTypeFloat64)
     .AddOutputAttr(kNumberTypeFloat64)
     .AddOutputAttr(kNumberTypeFloat64),
   &AdagradV2GpuKernelMod::LaunchKernel<double, double>},

  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddOutputAttr(kNumberTypeFloat32)
     .AddOutputAttr(kNumberTypeFloat32),
   &AdagradV2GpuKernelMod::LaunchKernel<float, float>},

  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeFloat16)
     .AddOutputAttr(kNumberTypeFloat16)
     .AddOutputAttr(kNumberTypeFloat16),
   &AdagradV2GpuKernelMod::LaunchKernel<half, half>},

  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat16)
     .AddOutputAttr(kNumberTypeFloat16)
     .AddOutputAttr(kNumberTypeFloat16),
   &AdagradV2GpuKernelMod::LaunchKernel<half, float>},

  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeFloat32)
     .AddOutputAttr(kNumberTypeFloat32)
     .AddOutputAttr(kNumberTypeFloat32),
   &AdagradV2GpuKernelMod::LaunchKernel<float, half>},

  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat64)
     .AddInputAttr(kNumberTypeFloat32)
     .AddOutputAttr(kNumberTypeFloat32)
     .AddOutputAttr(kNumberTypeFloat32),
   &AdagradV2GpuKernelMod::LaunchKernel<float, double>},

  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeFloat64)
     .AddInputAttr(kNumberTypeFloat16)
     .AddOutputAttr(kNumberTypeFloat16)
     .AddOutputAttr(kNumberTypeFloat16),
   &AdagradV2GpuKernelMod::LaunchKernel<half, double>},

  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat64)
     .AddInputAttr(kNumberTypeFloat64)
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat64)
     .AddOutputAttr(kNumberTypeFloat64)
     .AddOutputAttr(kNumberTypeFloat64),
   &AdagradV2GpuKernelMod::LaunchKernel<double, float>},

  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat64)
     .AddInputAttr(kNumberTypeFloat64)
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeFloat64)
     .AddOutputAttr(kNumberTypeFloat64)
     .AddOutputAttr(kNumberTypeFloat64),
   &AdagradV2GpuKernelMod::LaunchKernel<double, half>}};

MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, ApplyAdagradV2, AdagradV2GpuKernelMod);
}  // namespace kernel
}  // namespace mindspore
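The func_list_ above registers nine dtype combinations: var/accum/grad share one of float16/float32/float64, while lr may independently be float16, float32, or float64. A minimal sketch of exercising one mixed combination (float32 tensors with a float16 lr); the class name and values are hypothetical, this is not part of the patch, and acceptance also depends on the frontend type checks shown further below:

```python
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")


class MixedLrNet(nn.Cell):
    """Applies AdagradV2 with float32 var/accum/grad and a float16 learning rate."""

    def __init__(self, var_np, accum_np):
        super(MixedLrNet, self).__init__()
        self.op = P.ApplyAdagradV2(epsilon=1e-6)
        self.var = Parameter(Tensor(var_np), name="var")
        self.accum = Parameter(Tensor(accum_np), name="accum")

    def construct(self, lr, grad):
        return self.op(self.var, self.accum, lr, grad)


# net = MixedLrNet(np.ones((2, 2), np.float32), np.ones((2, 2), np.float32))
# var_out, accum_out = net(Tensor(np.float16(0.001)), Tensor(np.ones((2, 2), np.float32)))
```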
@@ -0,0 +1,106 @@ (new file: plugin/device/gpu/kernel/nn/adagrad_v2_gpu_kernel.h)
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_ADAGRAD_V2_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_ADAGRAD_V2_GPU_KERNEL_H_

#include <vector>
#include <string>
#include <functional>
#include <map>
#include <utility>
#include <memory>
#include <algorithm>
#include <iostream>
#include "mindspore/core/ops/apply_adagrad_v2.h"
#include "kernel/common_utils.h"
#include "include/curand.h"
#include "abstract/utils.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/adagrad_v2_impl.cuh"
// #include "plugin/device/gpu/kernel/cuda_impl/cuda_class/adagrad_v2_helper.h"

namespace mindspore {
namespace kernel {
class AdagradV2GpuKernelMod : public NativeGpuKernelMod {
 public:
  AdagradV2GpuKernelMod() = default;
  ~AdagradV2GpuKernelMod() override = default;

  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
             const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    if (is_null_input_) {
      return true;
    }
    stream_ptr_ = stream_ptr;
    return kernel_func_(this, inputs, workspace, outputs);
  }

  std::vector<KernelAttr> GetOpSupport() override;

  void ResetResource() noexcept {
    is_null_input_ = false;
    t_size_ = DEFAULT_SIZE_;
    s_size_ = DEFAULT_SIZE_;
    input_size_list_.clear();
    output_size_list_.clear();
    workspace_size_list_.clear();
  }

 private:
  template <typename T, typename S>
  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                    const std::vector<AddressPtr> &outputs);
  using ApplyAdagradV2Func =
    std::function<bool(AdagradV2GpuKernelMod *, const std::vector<kernel::AddressPtr> &,
                       const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
  void InOutputResize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                      const std::vector<KernelTensorPtr> &outputs);

 private:
  constexpr static int64_t DEFAULT_SIZE_ = 4;

  float epsilon_;
  bool update_slots_;

  int64_t variable_size_{0};
  int64_t accumulation_size_{0};
  int64_t learning_rate_size_{0};
  int64_t gradient_size_{0};
  bool is_null_input_{false};
  std::string kernel_name_{"ApplyAdagradV2"};

  int64_t t_size_{4};
  int64_t s_size_{4};
  int64_t input_elements_;
  BaseOperatorPtr kernel_ptr_{nullptr};
  std::vector<KernelTensorPtr> outputs_ = {};

  ApplyAdagradV2Func kernel_func_{};
  void *stream_ptr_{nullptr};
  static std::vector<std::pair<KernelAttr, ApplyAdagradV2Func>> func_list_;
};
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_ADAGRAD_V2_GPU_KERNEL_H
@@ -1,5 +1,5 @@ (changes to the ApplyAdagradV2 op infer implementation)
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -18,6 +18,7 @@
#include <algorithm>
#include <set>
#include <utility>

#include "abstract/ops/primitive_infer_map.h"
#include "ops/op_utils.h"
@@ -37,9 +38,9 @@ abstract::TupleShapePtr ApplyAdagradV2InferShape(const PrimitivePtr &primitive,
  auto grad_shape_ptr = grad_shape->cast<abstract::ShapePtr>();
  // lr must be a scalar [Number, Tensor]
  const int64_t kShapeSize_ = 1;
  auto lr_shape_rank = SizeToLong(lr_shape.size());
  (void)CheckAndConvertUtils::CheckInteger("lr's rank'", lr_shape_rank, kLessEqual, kShapeSize_, primitive->name());
  if (lr_shape_rank == 1) {
  auto lr_shape_size = lr_shape.size();
  (void)CheckAndConvertUtils::CheckInteger("lr's rank'", lr_shape_size, kLessEqual, kShapeSize_, primitive->name());
  if (lr_shape_size == 1) {
    (void)CheckAndConvertUtils::CheckInteger("lr_shape[0]", lr_shape[0], kEqual, kShapeSize_, primitive->name());
  }
  // var, accum and grad must have the same shape

@@ -47,8 +48,8 @@ abstract::TupleShapePtr ApplyAdagradV2InferShape(const PrimitivePtr &primitive,
    return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{var_shape, accum_shape});
  }
  std::map<std::string, abstract::BaseShapePtr> same_shape_args_map;
  (void)same_shape_args_map.insert(std::make_pair("accum", accum_shape));
  (void)same_shape_args_map.insert(std::make_pair("grad", grad_shape));
  same_shape_args_map.insert({"accum", accum_shape});
  same_shape_args_map.insert({"grad", grad_shape});
  for (auto &elem : same_shape_args_map) {
    if (*elem.second != *var_shape) {
      MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', evaluator arg '" << elem.first
@@ -64,21 +65,26 @@ TuplePtr ApplyAdagradV2InferType(const PrimitivePtr &prim, const std::vector<Abs
  auto accum_type = input_args[kInputIndex1]->BuildType();
  auto lr_type = input_args[kInputIndex2]->BuildType();
  auto grad_type = input_args[kInputIndex3]->BuildType();
  const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
  const std::set<TypePtr> valid_types = {kFloat};
  // var, accum, grad must have the same type
  std::map<std::string, TypePtr> args;
  (void)args.insert(std::make_pair("var_type", var_type));
  (void)args.insert(std::make_pair("accum_type", accum_type));
  (void)args.insert(std::make_pair("grad_type", grad_type));
  (void)args.insert({"var_type", var_type});
  (void)args.insert({"accum_type", accum_type});
  (void)args.insert({"grad_type", grad_type});
  (void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim->name());
  // lr must be a scalar
  std::map<std::string, TypePtr> args_lr;
  (void)args_lr.insert(std::make_pair("lr_type", lr_type));
  (void)args_lr.insert({"lr_type", lr_type});
  (void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_lr, valid_types, prim->name());
  return std::make_shared<Tuple>(std::vector<TypePtr>{var_type, accum_type});
}
}  // namespace

void ApplyAdagradV2::Init(const float epsilon, const bool update_slots) {
  set_epsilon(epsilon);
  set_update_slots(update_slots);
}

float ApplyAdagradV2::get_epsilon() const {
  auto value_ptr = this->GetAttr(kEpsilon);
  return GetValue<float>(value_ptr);
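The shape check above accepts lr either as a scalar (rank 0) or as a one-element tensor (rank 1 with size 1). A small illustration of that constraint (tensor construction only; a sketch, not part of the patch):

```python
import numpy as np
from mindspore import Tensor

lr_scalar = Tensor(np.array(0.001, np.float32))       # rank 0: accepted
lr_one_elem = Tensor(np.array([0.001], np.float32))   # rank 1, shape (1,): accepted
# lr_vector = Tensor(np.array([0.001, 0.002], np.float32))  # rank 1, shape (2,): rejected by the check
```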
@@ -1,5 +1,5 @@ (changes to mindspore/core/ops/apply_adagrad_v2.h)
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -34,8 +34,7 @@ class MIND_API ApplyAdagradV2 : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(ApplyAdagradV2);
  ApplyAdagradV2() : BaseOperator(kNameApplyAdagradV2) { InitIOName({"var", "accum", "lr", "grad"}, {"var", "accum"}); }

  /// \brief Set epsilon, A small value (float) added for numerical stability.
  void Init(float epsilon, bool update_slots = true);
  void set_epsilon(const float epsilon);
  /// \brief Get epsilon.
  ///
@@ -0,0 +1,107 @@ (new file: GPU ST test for ApplyAdagradV2)
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

eps_f64 = np.array([1e-5 for i in range(4)]).reshape(2, 2)
eps_f32 = np.array([1e-4 for i in range(4)]).reshape(2, 2)


class Net(nn.Cell):
    def __init__(self, var_np, accum_np, epsilon=1e-6, update_slots=True):
        super(Net, self).__init__()
        self.apply_adagrad_v2 = P.ApplyAdagradV2(epsilon=epsilon, update_slots=update_slots)
        self.var = Parameter(Tensor(var_np), name="var")
        self.accum = Parameter(Tensor(accum_np), name="accum")

    def construct(self, lr, grad):
        z = self.apply_adagrad_v2(self.var, self.accum, lr, grad)
        return z


def main_test(var_np, accum_np, lr_np, grident_np, epsilon_np, update_slots):
    lr = Tensor(lr_np)
    grad = Tensor(grident_np)

    # expect
    if update_slots:
        expect_accum_np = accum_np + grident_np * grident_np
    else:
        expect_accum_np = accum_np
    expect_var_np = var_np - lr_np * grident_np / np.sqrt(expect_accum_np + epsilon_np)

    net = Net(var_np, accum_np, epsilon_np, update_slots)
    out = net(lr, grad)
    res_var_mindspore = out[0].asnumpy()
    res_accum_mindspore = out[1].asnumpy()

    return (expect_var_np, res_var_mindspore), (expect_accum_np, res_accum_mindspore)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_apply_adagradv2_fff():
    """
    Feature: None
    Description: basic test fff
    Expectation: just test
    """
    var_np = np.array([[0.6, 0.4], [0.1, 0.5]]).astype(np.float32)
    accum_np = np.array([[0.6, 0.5], [0.2, 0.6]]).astype(np.float32)

    lr_np = np.array(0.001).astype(np.float32)
    epsilon_np = 1e-6

    update_slots = True

    grident_np = np.array([[0.3, 0.7], [0.1, 0.8]]).astype(np.float32)

    var, accum = main_test(var_np, accum_np, lr_np, grident_np, epsilon_np, update_slots)

    assert np.all(abs(accum[0] - accum[1]) < eps_f32)
    assert np.all(abs(var[0] - var[1]) < eps_f32)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_apply_adagradv2_ddd():
    """
    Feature: None
    Description: basic test ddd
    Expectation: just test
    """

    var_np = np.array([[0.6, 0.4], [0.1, 0.5]]).astype(np.float64)
    accum_np = np.array([[0.6, 0.5], [0.2, 0.6]]).astype(np.float64)

    lr_np = np.array(0.001).astype(np.float64)
    epsilon_np = 1e-6
    update_slots = True

    grident_np = np.array([[0.3, 0.7], [0.1, 0.8]]).astype(np.float64)

    var, accum = main_test(var_np, accum_np, lr_np, grident_np, epsilon_np, update_slots)
    assert np.all(abs(accum[0] - accum[1]) < eps_f64)
    assert np.all(abs(var[0] - var[1]) < eps_f64)
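main_test already models update_slots=False in its NumPy expectation; a further case exercising that path could look like the following sketch (hypothetical test name, reusing the values above; not part of the patch):

```python
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_apply_adagradv2_no_slot_update():
    """
    Feature: ApplyAdagradV2 GPU kernel.
    Description: update_slots=False leaves accum unchanged.
    Expectation: results match the NumPy expectation within eps_f32.
    """
    var_np = np.array([[0.6, 0.4], [0.1, 0.5]]).astype(np.float32)
    accum_np = np.array([[0.6, 0.5], [0.2, 0.6]]).astype(np.float32)
    lr_np = np.array(0.001).astype(np.float32)
    grident_np = np.array([[0.3, 0.7], [0.1, 0.8]]).astype(np.float32)

    var, accum = main_test(var_np, accum_np, lr_np, grident_np, 1e-6, update_slots=False)
    assert np.all(abs(accum[0] - accum[1]) < eps_f32)
    assert np.all(abs(var[0] - var[1]) < eps_f32)
```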