!37074 [feat] [assistant] [ops] [I4ZZRT] New GPU operator implementation, including ApplyAddSign
Merge pull request !37074 from tangyibo/master
Commit: 3beefc745f
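For orientation before the diff: every kernel in this change implements the element-wise AddSign update, m = beta * m + (1 - beta) * grad; update = (alpha + sign_decay * sign(grad) * sign(m)) * grad; var = var - lr * update. A minimal NumPy sketch of that rule, useful for reading the CUDA code below (the helper name is ours, not part of the PR):

import numpy as np

def add_sign_reference(var, accum, lr, alpha, sign_decay, beta, grad):
    # Decayed sign accumulator, then a gradient step scaled by sign agreement.
    accum = beta * accum + (1.0 - beta) * grad
    update = (alpha + sign_decay * np.sign(grad) * np.sign(accum)) * grad
    var = var - lr * update
    return var, accum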
@@ -0,0 +1,223 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/apply_add_sign_impl.cuh"
#include "include/cuda_fp16.h"

// Sign of x as -1, 0, or 1, in x's own type.
template <typename T>
__device__ __forceinline__ T Sgn(T x) {
  return static_cast<T>(x != 0 ? (x > 0 ? 1 : -1) : 0);
}

// Generic grid-stride kernel: one AddSign update per element.
template <typename T, typename S, typename G>
__global__ void ApplyAddSignKernel(const size_t size, T *variable, T *accumulation, const S learning_rate,
                                   const S alpha, const S sign_decay, const S beta, const G *gradient) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
    accumulation[i] = (beta * accumulation[i]) + ((static_cast<T>(1.) - beta) * gradient[i]);
    T update = (alpha + (sign_decay * Sgn(gradient[i]) * Sgn(accumulation[i]))) * gradient[i];
    variable[i] = variable[i] - (learning_rate * update);
  }
}

// half state with float scalars: widen to float, compute, and round back to half.
template <>
__global__ void ApplyAddSignKernel(const size_t size, half *variable, half *accumulation, const float learning_rate,
                                   const float alpha, const float sign_decay, const float beta,
                                   const half *gradient) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
    accumulation[i] = __float2half((beta * __half2float(accumulation[i])) +
                                   ((1.0f - beta) * __half2float(gradient[i])));
    float update = (alpha + (sign_decay * Sgn(__half2float(gradient[i])) * Sgn(__half2float(accumulation[i])))) *
                   __half2float(gradient[i]);
    variable[i] = __float2half(__half2float(variable[i]) - (learning_rate * update));
  }
}

template <>
__global__ void ApplyAddSignKernel(const size_t size, float *variable, float *accumulation,
                                   const float learning_rate, const float alpha, const float sign_decay,
                                   const float beta, const half *gradient) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
    accumulation[i] = (beta * accumulation[i]) + ((1.0f - beta) * __half2float(gradient[i]));
    float update = (alpha + (sign_decay * Sgn(__half2float(gradient[i])) * Sgn(accumulation[i]))) *
                   __half2float(gradient[i]);
    variable[i] = variable[i] - (learning_rate * update);
  }
}

template <>
__global__ void ApplyAddSignKernel(const size_t size, float *variable, float *accumulation,
                                   const half learning_rate, const half alpha, const half sign_decay,
                                   const half beta, const float *gradient) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
    accumulation[i] = (__half2float(beta) * accumulation[i]) + ((1.0f - __half2float(beta)) * gradient[i]);
    float update = (__half2float(alpha) + (__half2float(sign_decay) * Sgn(gradient[i]) * Sgn(accumulation[i]))) *
                   gradient[i];
    variable[i] = variable[i] - (__half2float(learning_rate) * update);
  }
}

template <>
__global__ void ApplyAddSignKernel(const size_t size, float *variable, float *accumulation,
                                   const half learning_rate, const half alpha, const half sign_decay,
                                   const half beta, const half *gradient) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
    accumulation[i] = (__half2float(beta) * accumulation[i]) +
                      ((1.0f - __half2float(beta)) * __half2float(gradient[i]));
    float update = (__half2float(alpha) + (__half2float(sign_decay) * Sgn(__half2float(gradient[i])) *
                                           Sgn(accumulation[i]))) * __half2float(gradient[i]);
    variable[i] = variable[i] - __half2float(learning_rate) * update;
  }
}

template <>
__global__ void ApplyAddSignKernel(const size_t size, half *variable, half *accumulation, const half learning_rate,
                                   const half alpha, const half sign_decay, const half beta, const half *gradient) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
    accumulation[i] = __float2half((__half2float(beta) * __half2float(accumulation[i])) +
                                   ((1.0f - __half2float(beta)) * __half2float(gradient[i])));
    float update = (__half2float(alpha) + (__half2float(sign_decay) * Sgn(__half2float(gradient[i])) *
                                           Sgn(__half2float(accumulation[i])))) * __half2float(gradient[i]);
    variable[i] = __float2half(__half2float(variable[i]) - __half2float(learning_rate) * update);
  }
}

template <typename T, typename S, typename G>
void ApplyAddSign(const size_t size, T *variable, T *accumulation, const S learning_rate, const S alpha,
                  const S sign_decay, const S beta, const G *gradient, const uint32_t &device_id,
                  cudaStream_t cuda_stream) {
  ApplyAddSignKernel<<<CUDA_BLOCKS(device_id, size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
    size, variable, accumulation, learning_rate, alpha, sign_decay, beta, gradient);
}

template CUDA_LIB_EXPORT void ApplyAddSign<double, double, double>(const size_t size, double *variable,
                                                                   double *accumulation, const double learning_rate,
                                                                   const double alpha, const double sign_decay,
                                                                   const double beta, const double *gradient,
                                                                   const uint32_t &device_id,
                                                                   cudaStream_t cuda_stream);

template CUDA_LIB_EXPORT void ApplyAddSign<float, float, float>(const size_t size, float *variable,
                                                                float *accumulation, const float learning_rate,
                                                                const float alpha, const float sign_decay,
                                                                const float beta, const float *gradient,
                                                                const uint32_t &device_id, cudaStream_t cuda_stream);

template CUDA_LIB_EXPORT void ApplyAddSign<float, float, half>(const size_t size, float *variable,
                                                               float *accumulation, const float learning_rate,
                                                               const float alpha, const float sign_decay,
                                                               const float beta, const half *gradient,
                                                               const uint32_t &device_id, cudaStream_t cuda_stream);

template CUDA_LIB_EXPORT void ApplyAddSign<float, half, float>(const size_t size, float *variable,
                                                               float *accumulation, const half learning_rate,
                                                               const half alpha, const half sign_decay,
                                                               const half beta, const float *gradient,
                                                               const uint32_t &device_id, cudaStream_t cuda_stream);

template CUDA_LIB_EXPORT void ApplyAddSign<float, half, half>(const size_t size, float *variable,
                                                              float *accumulation, const half learning_rate,
                                                              const half alpha, const half sign_decay,
                                                              const half beta, const half *gradient,
                                                              const uint32_t &device_id, cudaStream_t cuda_stream);

template CUDA_LIB_EXPORT void ApplyAddSign<half, half, half>(const size_t size, half *variable, half *accumulation,
                                                             const half learning_rate, const half alpha,
                                                             const half sign_decay, const half beta,
                                                             const half *gradient, const uint32_t &device_id,
                                                             cudaStream_t cuda_stream);

template CUDA_LIB_EXPORT void ApplyAddSign<half, float, half>(const size_t size, half *variable, half *accumulation,
                                                              const float learning_rate, const float alpha,
                                                              const float sign_decay, const float beta,
                                                              const half *gradient, const uint32_t &device_id,
                                                              cudaStream_t cuda_stream);
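Note that the half specializations above widen every operand with __half2float, do the arithmetic in float, and narrow the results with __float2half, rather than computing in half directly; this keeps the small (1 - beta) * grad contribution from being swallowed by fp16 rounding. The same compute-in-fp32, store-in-fp16 pattern in NumPy (an illustrative sketch, not code from this PR):

import numpy as np

def add_sign_fp16(var, accum, lr, alpha, sign_decay, beta, grad):
    # fp16 state, fp32 math: mirrors the half kernels above.
    var32, accum32, grad32 = (x.astype(np.float32) for x in (var, accum, grad))
    accum32 = beta * accum32 + (1.0 - beta) * grad32
    update = (alpha + sign_decay * np.sign(grad32) * np.sign(accum32)) * grad32
    var32 = var32 - lr * update
    return var32.astype(np.float16), accum32.astype(np.float16)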
@@ -0,0 +1,32 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_APPLY_ADD_SIGN_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_APPLY_ADD_SIGN_IMPL_CUH_
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"

template <typename T, typename S, typename G>
CUDA_LIB_EXPORT void ApplyAddSign(const size_t size, T *variable, T *accumulation, const S learning_rate,
                                  const S alpha, const S sign_decay, const S beta, const G *gradient,
                                  const uint32_t &device_id, cudaStream_t stream);

#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_APPLY_ADD_SIGN_IMPL_CUH_
@@ -0,0 +1,226 @@
/**
 * Copyright 2020-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/gpu/kernel/nn/apply_add_sign_gpu_kernel.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/apply_add_sign_impl.cuh"
#include "abstract/utils.h"
#include "kernel/common_utils.h"
#include "include/curand.h"
#include "mindspore/core/ops/apply_add_sign.h"

namespace mindspore {
namespace kernel {
bool ApplyAddSignGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                    const std::vector<KernelTensorPtr> &outputs) {
  auto kernel_ptr = std::dynamic_pointer_cast<ops::ApplyAddSign>(base_operator);
  kernel_name_ = kernel_ptr->name();
  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
  auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
  if (!is_match) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel type: " << kernel_attr;
    return false;
  }
  kernel_func_ = func_list_[index].second;
  // Byte sizes of the three distinct dtypes: var/accum (input 0), the scalars (input 2), and grad (input 6).
  t_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).first);
  s_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex2).first);
  g_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex6).first);
  return true;
}

int ApplyAddSignGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                     const std::vector<KernelTensorPtr> &outputs,
                                     const std::map<uint32_t, tensor::TensorPtr> &) {
  for (const auto &input : inputs) {
    // If any input shape contains -1, the shape is still dynamic, so just report that and do nothing.
    auto input_shape = input->GetShapeVector();
    if (!IsValidShape(input_shape)) {
      return KRET_UNKNOWN_SHAPE;
    }
  }
  ResetResource();
  std::vector<int64_t> variable_shape(inputs.at(kIndex0)->GetDeviceShapeAdaptively().begin(),
                                      inputs.at(kIndex0)->GetDeviceShapeAdaptively().end());
  std::vector<int64_t> learning_rate_shape(inputs.at(kIndex2)->GetDeviceShapeAdaptively().begin(),
                                           inputs.at(kIndex2)->GetDeviceShapeAdaptively().end());
  std::vector<int64_t> gradient_shape(inputs.at(kIndex6)->GetDeviceShapeAdaptively().begin(),
                                      inputs.at(kIndex6)->GetDeviceShapeAdaptively().end());
  t_elements_ = std::accumulate(variable_shape.begin(), variable_shape.end(), size_t(1), std::multiplies<size_t>());
  s_elements_ =
    std::accumulate(learning_rate_shape.begin(), learning_rate_shape.end(), size_t(1), std::multiplies<size_t>());
  g_elements_ = std::accumulate(gradient_shape.begin(), gradient_shape.end(), size_t(1), std::multiplies<size_t>());
  is_null_input_ = (t_elements_ == 0 || s_elements_ == 0 || g_elements_ == 0);
  if (is_null_input_) {
    return KRET_OK;
  }
  size_t variable_size = t_elements_ * t_size_;
  size_t accumulation_size = t_elements_ * t_size_;
  size_t learning_rate_size = s_elements_ * s_size_;
  size_t alpha_size = s_elements_ * s_size_;
  size_t sign_decay_size = s_elements_ * s_size_;
  size_t beta_size = s_elements_ * s_size_;
  size_t gradient_size = g_elements_ * g_size_;
  input_size_list_.emplace_back(variable_size);
  input_size_list_.emplace_back(accumulation_size);
  input_size_list_.emplace_back(learning_rate_size);
  input_size_list_.emplace_back(alpha_size);
  input_size_list_.emplace_back(sign_decay_size);
  input_size_list_.emplace_back(beta_size);
  input_size_list_.emplace_back(gradient_size);
  output_size_list_.emplace_back(variable_size);
  output_size_list_.emplace_back(accumulation_size);
  return KRET_OK;
}

void ApplyAddSignGpuKernelMod::ResetResource() noexcept {
  t_elements_ = 0;
  s_elements_ = 0;
  g_elements_ = 0;
  is_null_input_ = false;
  input_size_list_.clear();
  output_size_list_.clear();
}

template <typename T, typename S, typename G>
bool ApplyAddSignGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
                                            const std::vector<AddressPtr> &workspace,
                                            const std::vector<AddressPtr> &outputs) {
  T *variable = GetDeviceAddress<T>(inputs, 0);
  T *accumulation = GetDeviceAddress<T>(inputs, 1);
  S *learning_rate = GetDeviceAddress<S>(inputs, 2);
  S *alpha = GetDeviceAddress<S>(inputs, 3);
  S *sign_decay = GetDeviceAddress<S>(inputs, 4);
  S *beta = GetDeviceAddress<S>(inputs, 5);
  G *gradient = GetDeviceAddress<G>(inputs, 6);
  T *variable_out = GetDeviceAddress<T>(outputs, 0);
  T *accumulation_out = GetDeviceAddress<T>(outputs, 1);
  // The four scalar hyper-parameters live on the device; copy them to the host so they can be
  // passed to the kernel by value.
  S learning_rate_0 = 0.;
  S alpha_0 = 0.;
  S sign_decay_0 = 0.;
  S beta_0 = 0.;
  CHECK_CUDA_RET_WITH_ERROR_NOTRACE(
    cudaMemcpyAsync(&learning_rate_0, learning_rate, s_elements_ * s_size_, cudaMemcpyDeviceToHost,
                    reinterpret_cast<cudaStream_t>(stream_ptr_)),
    "cudaMemcpyAsync learning_rate failed");
  CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpyAsync(&alpha_0, alpha, s_elements_ * s_size_, cudaMemcpyDeviceToHost,
                                                    reinterpret_cast<cudaStream_t>(stream_ptr_)),
                                    "cudaMemcpyAsync alpha failed");
  CHECK_CUDA_RET_WITH_ERROR_NOTRACE(
    cudaMemcpyAsync(&sign_decay_0, sign_decay, s_elements_ * s_size_, cudaMemcpyDeviceToHost,
                    reinterpret_cast<cudaStream_t>(stream_ptr_)),
    "cudaMemcpyAsync sign_decay failed");
  CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpyAsync(&beta_0, beta, s_elements_ * s_size_, cudaMemcpyDeviceToHost,
                                                    reinterpret_cast<cudaStream_t>(stream_ptr_)),
                                    "cudaMemcpyAsync beta failed");
  // The copies above are asynchronous and the scalars are read on the host next, so wait for them.
  CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_ptr_)),
                                    "cudaStreamSynchronize failed");
  ApplyAddSign(t_elements_, variable, accumulation, learning_rate_0, alpha_0, sign_decay_0, beta_0, gradient,
               device_id_, reinterpret_cast<cudaStream_t>(stream_ptr_));
  CHECK_CUDA_RET_WITH_ERROR_NOTRACE(
    cudaMemcpyAsync(variable_out, variable, outputs.at(kIndex0)->size, cudaMemcpyDeviceToDevice,
                    reinterpret_cast<cudaStream_t>(stream_ptr_)),
    "cudaMemcpyAsync variable output failed");
  CHECK_CUDA_RET_WITH_ERROR_NOTRACE(
    cudaMemcpyAsync(accumulation_out, accumulation, outputs.at(kIndex1)->size, cudaMemcpyDeviceToDevice,
                    reinterpret_cast<cudaStream_t>(stream_ptr_)),
    "cudaMemcpyAsync accumulation output failed");
  return true;
}

std::vector<std::pair<KernelAttr, ApplyAddSignGpuKernelMod::ApplyAddSignFunc>> ApplyAddSignGpuKernelMod::func_list_ = {
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat64)
     .AddInputAttr(kNumberTypeFloat64)
     .AddInputAttr(kNumberTypeFloat64)
     .AddInputAttr(kNumberTypeFloat64)
     .AddInputAttr(kNumberTypeFloat64)
     .AddInputAttr(kNumberTypeFloat64)
     .AddInputAttr(kNumberTypeFloat64)
     .AddOutputAttr(kNumberTypeFloat64)
     .AddOutputAttr(kNumberTypeFloat64),
   &ApplyAddSignGpuKernelMod::LaunchKernel<double, double, double>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddOutputAttr(kNumberTypeFloat32)
     .AddOutputAttr(kNumberTypeFloat32),
   &ApplyAddSignGpuKernelMod::LaunchKernel<float, float, float>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeFloat16)
     .AddOutputAttr(kNumberTypeFloat16)
     .AddOutputAttr(kNumberTypeFloat16),
   &ApplyAddSignGpuKernelMod::LaunchKernel<half, half, half>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat16)
     .AddOutputAttr(kNumberTypeFloat16)
     .AddOutputAttr(kNumberTypeFloat16),
   &ApplyAddSignGpuKernelMod::LaunchKernel<half, float, half>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat16)
     .AddOutputAttr(kNumberTypeFloat32)
     .AddOutputAttr(kNumberTypeFloat32),
   &ApplyAddSignGpuKernelMod::LaunchKernel<float, float, half>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeFloat32)
     .AddOutputAttr(kNumberTypeFloat32)
     .AddOutputAttr(kNumberTypeFloat32),
   &ApplyAddSignGpuKernelMod::LaunchKernel<float, half, float>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeFloat16)
     .AddOutputAttr(kNumberTypeFloat32)
     .AddOutputAttr(kNumberTypeFloat32),
   &ApplyAddSignGpuKernelMod::LaunchKernel<float, half, half>}};

std::vector<KernelAttr> ApplyAddSignGpuKernelMod::GetOpSupport() {
  static std::vector<KernelAttr> support_list;
  (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
                       [](const std::pair<KernelAttr, ApplyAddSignFunc> &pair) { return pair.first; });
  return support_list;
}
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, ApplyAddSign, ApplyAddSignGpuKernelMod);
}  // namespace kernel
}  // namespace mindspore
@@ -0,0 +1,84 @@
/**
 * Copyright 2020-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_APPLY_ADD_SIGN_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_APPLY_ADD_SIGN_GPU_KERNEL_H_

#include <vector>
#include <string>
#include <memory>
#include <algorithm>
#include <functional>
#include <utility>
#include <map>
#include <optional>
#include "mindspore/core/ops/apply_add_sign.h"
#include "abstract/utils.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/apply_add_sign_impl.cuh"

namespace mindspore {
namespace kernel {
class ApplyAddSignGpuKernelMod : public NativeGpuKernelMod {
 public:
  ApplyAddSignGpuKernelMod() { ResetResource(); }
  ~ApplyAddSignGpuKernelMod() override = default;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    if (is_null_input_) {
      return true;
    }
    stream_ptr_ = stream_ptr;
    return kernel_func_(this, inputs, workspace, outputs);
  }

  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
             const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;

  std::vector<KernelAttr> GetOpSupport() override;

  void ResetResource() noexcept;

 private:
  template <typename T, typename S, typename G>
  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                    const std::vector<AddressPtr> &outputs);
  using ApplyAddSignFunc = std::function<bool(ApplyAddSignGpuKernelMod *, const std::vector<AddressPtr> &,
                                              const std::vector<AddressPtr> &, const std::vector<AddressPtr> &)>;

  size_t t_size_{1};
  size_t s_size_{1};
  size_t g_size_{1};
  size_t t_elements_{0};
  size_t s_elements_{0};
  size_t g_elements_{0};
  bool is_null_input_{false};
  void *stream_ptr_{nullptr};
  ApplyAddSignFunc kernel_func_{};
  std::optional<bool> is_input_dynamic_shape_{};
  static std::vector<std::pair<KernelAttr, ApplyAddSignFunc>> func_list_;
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_APPLY_ADD_SIGN_GPU_KERNEL_H_
@@ -100,7 +100,7 @@ TuplePtr ApplyAddSignInferType(const PrimitivePtr &prim, const std::vector<Abstr
   auto sign_decay_type = input_args[kInputIndex4]->BuildType();
   auto beta_type = input_args[kInputIndex5]->BuildType();
   auto grad_type = input_args[kInputIndex6]->BuildType();
-  const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
+  const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
   std::map<std::string, TypePtr> args;
   (void)args.insert(std::make_pair("var_type", var_type));
   (void)args.insert(std::make_pair("m_type", m_type));
@@ -190,6 +190,7 @@ constexpr auto kGatherDGrad = "GatherDGrad";
 constexpr auto kGatherDGradV2 = "GatherDGradV2";
 
 // NN
+constexpr auto kApplyAddSign = "ApplyAddSign";
 constexpr auto kAdaptiveMaxPool3D = "AdaptiveMaxPool3D";
 constexpr auto kFractionalMaxPool3DWithFixedKsize = "FractionalMaxPool3DWithFixedKsize";
 constexpr auto kFractionalMaxPool3DGradWithFixedKsize = "FractionalMaxPool3DGradWithFixedKsize";
@@ -6503,7 +6503,7 @@ class ApplyAddSign(Primitive):
         RuntimeError: If the data type of `var`, `accum` and `grad` conversion of Parameter is not supported.
 
     Supported Platforms:
-        ``Ascend``
+        ``Ascend`` ``GPU``
 
     Examples:
         >>> class Net(nn.Cell):
@@ -0,0 +1,65 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P
import mindspore.common.dtype as mstype


class Net(nn.Cell):
    def __init__(self, var_np, accum_np):
        super(Net, self).__init__()
        self.apply_addsign = P.ApplyAddSign()
        self.var = Parameter(Tensor(var_np), name="var")
        self.accum = Parameter(Tensor(accum_np), name="m")

    def construct(self, lr, alpha, sign_decay, beta, grad):
        z = self.apply_addsign(self.var, self.accum, lr, alpha, sign_decay, beta, grad)
        return z


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_apply_addsign_graph_float32():
    """
    Feature: ApplyAddSign gpu kernel.
    Description: test the ApplyAddSign.
    Expectation: match to np benchmark.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    var_np = np.array([[0.6, 0.4], [0.1, 0.5]]).astype(np.float32)
    accum_np = np.array([[0.6, 0.5], [0.2, 0.6]]).astype(np.float32)
    gradient_np = np.array([[0.3, 0.7], [0.1, 0.8]]).astype(np.float32)
    expect_accum_np = 0.9 * accum_np + (1.0 - 0.9) * gradient_np
    expect_update = (1.0 + 0.99 * np.sign(gradient_np) * np.sign(expect_accum_np)) * gradient_np
    expect_var_np = var_np - (0.001 * expect_update)
    net = Net(var_np, accum_np)
    lr = Tensor(0.001, mstype.float32)
    alpha = Tensor(1.0, mstype.float32)
    sign_decay = Tensor(0.99, mstype.float32)
    beta = Tensor(0.9, mstype.float32)
    grad = Tensor(gradient_np)
    out = net(lr, alpha, sign_decay, beta, grad)
    res_var_mindspore = out[0].asnumpy()
    res_accum_mindspore = out[1].asnumpy()
    eps = np.full((2, 2), 1e-6)
    assert np.all(np.abs(expect_var_np - res_var_mindspore) < eps)
    assert np.all(np.abs(expect_accum_np - res_accum_mindspore) < eps)
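The test above exercises only the all-float32 registration. A float16 variant would be a natural companion; the sketch below is hypothetical (not part of this PR). It computes the reference in float32 and rounds back, mirroring the half kernels, and uses a tolerance suited to fp16's roughly three decimal digits of precision:

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_apply_addsign_graph_float16():
    """Hypothetical float16 companion to the float32 test above."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    var_np = np.array([[0.6, 0.4], [0.1, 0.5]]).astype(np.float16)
    accum_np = np.array([[0.6, 0.5], [0.2, 0.6]]).astype(np.float16)
    gradient_np = np.array([[0.3, 0.7], [0.1, 0.8]]).astype(np.float16)
    # Reference in float32, then rounded to float16, because the half kernels compute in float.
    grad32 = gradient_np.astype(np.float32)
    accum32 = 0.9 * accum_np.astype(np.float32) + 0.1 * grad32
    update = (1.0 + 0.99 * np.sign(grad32) * np.sign(accum32)) * grad32
    var32 = var_np.astype(np.float32) - 0.001 * update
    net = Net(var_np, accum_np)
    out = net(Tensor(0.001, mstype.float16), Tensor(1.0, mstype.float16),
              Tensor(0.99, mstype.float16), Tensor(0.9, mstype.float16), Tensor(gradient_np))
    assert np.all(np.abs(var32.astype(np.float16) - out[0].asnumpy()) < 1e-3)
    assert np.all(np.abs(accum32.astype(np.float16) - out[1].asnumpy()) < 1e-3)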