!37074 [feat] [assistant] [ops] [I4ZZRT] New GPU operator implementation, including ApplyAddSign

Merge pull request !37074 from tangyibo/master
i-robot 2022-07-14 08:07:30 +00:00 committed by Gitee
commit 3beefc745f
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
8 changed files with 633 additions and 2 deletions


@@ -0,0 +1,223 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/apply_add_sign_impl.cuh"
#include "include/cuda_fp16.h"
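// Elementwise sign helper: returns -1, 0, or +1 as type T.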
template <typename T>
__device__ __forceinline__ T Sgn(T x) {
return static_cast<T>(x != 0 ? (x > 0 ? 1 : -1) : 0);
}
template <typename T, typename S, typename G>
__global__ void ApplyAddSignKernel(const size_t size,
T *variable,
T *accumulation,
const S learning_rate,
const S alpha,
const S sign_decay,
const S beta,
const G *gradient) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
accumulation[i] = (beta * accumulation[i]) + ((static_cast<T>(1.) - beta) * gradient[i]);
T update = (alpha + (sign_decay * Sgn(gradient[i]) * Sgn(accumulation[i]))) * gradient[i];
variable[i] = variable[i] - (learning_rate * update);
}
}
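// For reference, the update rule implemented above is the AddSign optimizer
// (cf. Bello et al., "Neural Optimizer Search with Reinforcement Learning", 2017):
//   accumulation = beta * accumulation + (1 - beta) * gradient
//   update       = (alpha + sign_decay * sign(gradient) * sign(accumulation)) * gradient
//   variable     = variable - learning_rate * update
// The specializations below repeat this arithmetic in float whenever half
// operands are involved, to limit fp16 rounding error.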
template <>
__global__ void ApplyAddSignKernel(const size_t size,
half *variable,
half *accumulation,
const float learning_rate,
const float alpha,
const float sign_decay,
const float beta,
const half *gradient) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
// Compute in float for precision and convert back to half exactly once per element.
float accum = (beta * __half2float(accumulation[i])) +
((static_cast<float>(1.) - beta) * __half2float(gradient[i]));
float update = (alpha + (sign_decay * Sgn(__half2float(gradient[i])) * Sgn(accum))) *
__half2float(gradient[i]);
variable[i] = __float2half(__half2float(variable[i]) - (learning_rate * update));
accumulation[i] = __float2half(accum);
}
}
template <>
__global__ void ApplyAddSignKernel(const size_t size,
float *variable,
float *accumulation,
const float learning_rate,
const float alpha,
const float sign_decay,
const float beta,
const half *gradient) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
accumulation[i] = (beta * accumulation[i]) + ((static_cast<float>(1.) - beta) * __half2float(gradient[i]));
float update = (alpha + (sign_decay * Sgn(__half2float(gradient[i])) * Sgn(accumulation[i]))) *
__half2float(gradient[i]);
variable[i] = variable[i] - (learning_rate * update);
}
}
template <>
__global__ void ApplyAddSignKernel(const size_t size,
float *variable,
float *accumulation,
const half learning_rate,
const half alpha,
const half sign_decay,
const half beta,
const float *gradient) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
accumulation[i] = (__half2float(beta) * accumulation[i]) +
((static_cast<float>(1.) - __half2float(beta)) * gradient[i]);
float update = (__half2float(alpha) + (__half2float(sign_decay) * Sgn(gradient[i]) * Sgn(accumulation[i]))) *
gradient[i];
variable[i] = variable[i] - (__half2float(learning_rate) * update);
}
}
template <>
__global__ void ApplyAddSignKernel(const size_t size,
float *variable,
float *accumulation,
const half learning_rate,
const half alpha,
const half sign_decay,
const half beta,
const half *gradient) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
accumulation[i] = (__half2float(beta) * accumulation[i]) + ((static_cast<float>(1.) - __half2float(beta)) *
__half2float(gradient[i]));
float update = (__half2float(alpha) + (__half2float(sign_decay) * Sgn(__half2float(gradient[i])) *
Sgn(accumulation[i]))) * __half2float(gradient[i]);
variable[i] = variable[i] - __half2float(learning_rate) * update;
}
}
template <>
__global__ void ApplyAddSignKernel(const size_t size,
half *variable,
half *accumulation,
const half learning_rate,
const half alpha,
const half sign_decay,
const half beta,
const half *gradient) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
// Compute in float for precision and convert back to half exactly once per element.
float accum = (__half2float(beta) * __half2float(accumulation[i])) +
((static_cast<float>(1.) - __half2float(beta)) * __half2float(gradient[i]));
float update = (__half2float(alpha) + (__half2float(sign_decay) * Sgn(__half2float(gradient[i])) *
Sgn(accum))) * __half2float(gradient[i]);
variable[i] = __float2half(__half2float(variable[i]) - __half2float(learning_rate) * update);
accumulation[i] = __float2half(accum);
}
}
template <typename T, typename S, typename G>
void ApplyAddSign(const size_t size,
T *variable,
T *accumulation,
const S learning_rate,
const S alpha,
const S sign_decay,
const S beta,
const G *gradient,
const uint32_t &device_id,
cudaStream_t cuda_stream) {
ApplyAddSignKernel<<<CUDA_BLOCKS(device_id, size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
size, variable, accumulation, learning_rate, alpha, sign_decay, beta, gradient);
}
template CUDA_LIB_EXPORT void ApplyAddSign<double, double, double>(const size_t size,
double *variable,
double *accumulation,
const double learning_rate,
const double alpha,
const double sign_decay,
const double beta,
const double *gradient,
const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void ApplyAddSign<float, float, float>(const size_t size,
float *variable,
float *accumulation,
const float learning_rate,
const float alpha,
const float sign_decay,
const float beta,
const float *gradient,
const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void ApplyAddSign<float, float, half>(const size_t size,
float *variable,
float *accumulation,
const float learning_rate,
const float alpha,
const float sign_decay,
const float beta,
const half *gradient,
const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void ApplyAddSign<float, half, float>(const size_t size,
float *variable,
float *accumulation,
const half learning_rate,
const half alpha,
const half sign_decay,
const half beta,
const float *gradient,
const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void ApplyAddSign<float, half, half>(const size_t size,
float *variable,
float *accumulation,
const half learning_rate,
const half alpha,
const half sign_decay,
const half beta,
const half *gradient,
const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void ApplyAddSign<half, half, half>(const size_t size,
half *variable,
half *accumulation,
const half learning_rate,
const half alpha,
const half sign_decay,
const half beta,
const half *gradient,
const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void ApplyAddSign<half, float, half>(const size_t size,
half *variable,
half *accumulation,
const float learning_rate,
const float alpha,
const float sign_decay,
const float beta,
const half *gradient,
const uint32_t &device_id,
cudaStream_t cuda_stream);
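
For illustration only, and not part of the change set: a minimal host-side driver sketch for the exported launcher, assuming apply_add_sign_impl.cuh is on the include path, the cuda_ops library providing the <float, float, float> instantiation is linked, and device 0 is usable. Buffer names and hyper-parameter values here are illustrative; they mirror the Python test at the end of this PR.

/** Hypothetical standalone driver for the <float, float, float> instantiation. */
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/apply_add_sign_impl.cuh"

int main() {
  const size_t n = 4;
  float h_var[] = {0.6f, 0.4f, 0.1f, 0.5f};
  float h_accum[] = {0.6f, 0.5f, 0.2f, 0.6f};
  float h_grad[] = {0.3f, 0.7f, 0.1f, 0.8f};
  float *d_var = nullptr, *d_accum = nullptr, *d_grad = nullptr;
  cudaMalloc(&d_var, n * sizeof(float));
  cudaMalloc(&d_accum, n * sizeof(float));
  cudaMalloc(&d_grad, n * sizeof(float));
  cudaMemcpy(d_var, h_var, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_accum, h_accum, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_grad, h_grad, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  const uint32_t device_id = 0;
  // lr = 0.001, alpha = 1.0, sign_decay = 0.99, beta = 0.9; var and accum are updated in place.
  ApplyAddSign(n, d_var, d_accum, 0.001f, 1.0f, 0.99f, 0.9f, d_grad, device_id, stream);
  cudaStreamSynchronize(stream);
  cudaMemcpy(h_var, d_var, n * sizeof(float), cudaMemcpyDeviceToHost);
  for (size_t i = 0; i < n; ++i) {
    printf("var[%zu] = %f\n", i, h_var[i]);
  }
  cudaStreamDestroy(stream);
  cudaFree(d_var);
  cudaFree(d_accum);
  cudaFree(d_grad);
  return 0;
}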


@@ -0,0 +1,32 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_APPLY_ADD_SIGN_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_APPLY_ADD_SIGN_IMPL_CUH_
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
template <typename T, typename S, typename G>
CUDA_LIB_EXPORT void ApplyAddSign(const size_t size,
T *variable,
T *accumulation,
const S learning_rate,
const S alpha,
const S sign_decay,
const S beta,
const G *gradient,
const uint32_t &device_id,
cudaStream_t stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_APPLY_ADD_SIGN_IMPL_CUH_


@@ -0,0 +1,226 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/nn/apply_add_sign_gpu_kernel.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/apply_add_sign_impl.cuh"
#include "abstract/utils.h"
#include "kernel/common_utils.h"
#include "include/curand.h"
#include "mindspore/core/ops/apply_add_sign.h"
namespace mindspore {
namespace kernel {
bool ApplyAddSignGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
auto kernel_ptr = std::dynamic_pointer_cast<ops::ApplyAddSign>(base_operator);
MS_EXCEPTION_IF_NULL(kernel_ptr);
kernel_name_ = kernel_ptr->name();
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel type: " << kernel_attr;
return false;
}
kernel_func_ = func_list_[index].second;
t_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).first);
s_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex2).first);
g_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex6).first);
return true;
}
int ApplyAddSignGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
for (const auto &input : inputs) {
// If any input shape contains -1, the shape is dynamic; just return and do nothing.
auto input_shape = input->GetShapeVector();
if (!IsValidShape(input_shape)) {
return KRET_UNKNOWN_SHAPE;
}
}
ResetResource();
std::vector<int64_t> variable_shape_ = std::vector<int64_t>(inputs.at(kIndex0)->GetDeviceShapeAdaptively().begin(),
inputs.at(kIndex0)->GetDeviceShapeAdaptively().end());
std::vector<int64_t> learning_rate_shape_ = std::vector<int64_t>(
inputs.at(kIndex2)->GetDeviceShapeAdaptively().begin(), inputs.at(kIndex2)->GetDeviceShapeAdaptively().end());
std::vector<int64_t> gradient_shape_ = std::vector<int64_t>(inputs.at(kIndex6)->GetDeviceShapeAdaptively().begin(),
inputs.at(kIndex6)->GetDeviceShapeAdaptively().end());
t_elements_ = std::accumulate(variable_shape_.begin(), variable_shape_.end(), size_t(1), std::multiplies<size_t>());
s_elements_ =
std::accumulate(learning_rate_shape_.begin(), learning_rate_shape_.end(), size_t(1), std::multiplies<size_t>());
g_elements_ = std::accumulate(gradient_shape_.begin(), gradient_shape_.end(), size_t(1), std::multiplies<size_t>());
is_null_input_ = (t_elements_ == 0 || s_elements_ == 0 || g_elements_ == 0);
if (is_null_input_) {
return 0;
}
size_t variable_size_ = t_elements_ * t_size_;
size_t accumulation_size_ = t_elements_ * t_size_;
size_t learning_rate_size_ = s_elements_ * s_size_;
size_t alpha_size_ = s_elements_ * s_size_;
size_t sign_decay_size_ = s_elements_ * s_size_;
size_t beta_size_ = s_elements_ * s_size_;
size_t gradient_size_ = g_elements_ * g_size_;
input_size_list_.emplace_back(variable_size_);
input_size_list_.emplace_back(accumulation_size_);
input_size_list_.emplace_back(learning_rate_size_);
input_size_list_.emplace_back(alpha_size_);
input_size_list_.emplace_back(sign_decay_size_);
input_size_list_.emplace_back(beta_size_);
input_size_list_.emplace_back(gradient_size_);
output_size_list_.emplace_back(variable_size_);
output_size_list_.emplace_back(accumulation_size_);
return KRET_OK;
}
void ApplyAddSignGpuKernelMod::ResetResource() noexcept {
t_elements_ = 0;
s_elements_ = 0;
g_elements_ = 0;
is_null_input_ = false;
input_size_list_.clear();
output_size_list_.clear();
}
template <typename T, typename S, typename G>
bool ApplyAddSignGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
T *variable = GetDeviceAddress<T>(inputs, 0);
T *accumulation = GetDeviceAddress<T>(inputs, 1);
S *learning_rate = GetDeviceAddress<S>(inputs, 2);
S *alpha = GetDeviceAddress<S>(inputs, 3);
S *sign_decay = GetDeviceAddress<S>(inputs, 4);
S *beta = GetDeviceAddress<S>(inputs, 5);
G *gradient = GetDeviceAddress<G>(inputs, 6);
T *variable_out = GetDeviceAddress<T>(outputs, 0);
T *accumulation_out = GetDeviceAddress<T>(outputs, 1);
S learning_rate_0 = 0.;
S alpha_0 = 0.;
S sign_decay_0 = 0.;
S beta_0 = 0.;
CHECK_CUDA_RET_WITH_ERROR_NOTRACE(
cudaMemcpyAsync(&learning_rate_0, learning_rate, s_elements_ * s_size_, cudaMemcpyDeviceToHost,
reinterpret_cast<cudaStream_t>(stream_ptr_)),
"cudaMemcpy learning_rate failed");
CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpyAsync(&alpha_0, alpha, s_elements_ * s_size_, cudaMemcpyDeviceToHost,
reinterpret_cast<cudaStream_t>(stream_ptr_)),
"cudaMemcpy alpha failed");
CHECK_CUDA_RET_WITH_ERROR_NOTRACE(
cudaMemcpyAsync(&sign_decay_0, sign_decay, s_elements_ * s_size_, cudaMemcpyDeviceToHost,
reinterpret_cast<cudaStream_t>(stream_ptr_)),
"cudaMemcpy sign_decay failed");
CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpyAsync(&beta_0, beta, s_elements_ * s_size_, cudaMemcpyDeviceToHost,
reinterpret_cast<cudaStream_t>(stream_ptr_)),
"cudaMemcpy beta failed");
// The scalars are passed to the kernel by value from the host, so the async
// device-to-host copies above must finish before their values are read.
CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_ptr_)),
"cudaStreamSynchronize failed");
ApplyAddSign(t_elements_, variable, accumulation, learning_rate_0, alpha_0, sign_decay_0, beta_0, gradient,
device_id_, reinterpret_cast<cudaStream_t>(stream_ptr_));
CHECK_CUDA_RET_WITH_ERROR_NOTRACE(
cudaMemcpyAsync(variable_out, variable, outputs.at(kIndex0)->size, cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr_)),
"cudaMemcpyAsync output failed");
CHECK_CUDA_RET_WITH_ERROR_NOTRACE(
cudaMemcpyAsync(accumulation_out, accumulation, outputs.at(kIndex1)->size, cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr_)),
"cudaMemcpyAsync output failed");
return true;
}
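// Each entry pairs a dtype signature with a LaunchKernel<T, S, G> instantiation, where T is the
// var/accum type, S the scalar hyper-parameter type, and G the gradient type.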
std::vector<std::pair<KernelAttr, ApplyAddSignGpuKernelMod::ApplyAddSignFunc>> ApplyAddSignGpuKernelMod::func_list_ = {
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
&ApplyAddSignGpuKernelMod::LaunchKernel<double, double, double>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
&ApplyAddSignGpuKernelMod::LaunchKernel<float, float, float>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
&ApplyAddSignGpuKernelMod::LaunchKernel<half, half, half>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
&ApplyAddSignGpuKernelMod::LaunchKernel<half, float, half>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
&ApplyAddSignGpuKernelMod::LaunchKernel<float, float, half>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
&ApplyAddSignGpuKernelMod::LaunchKernel<float, half, float>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
&ApplyAddSignGpuKernelMod::LaunchKernel<float, half, half>}};
std::vector<KernelAttr> ApplyAddSignGpuKernelMod::GetOpSupport() {
static std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, ApplyAddSignFunc> &pair) { return pair.first; });
return support_list;
}
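// Register this kernel with the GPU kernel factory under the op name ApplyAddSign.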
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, ApplyAddSign, ApplyAddSignGpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,84 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_APPLY_ADD_SIGN_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_APPLY_ADD_SIGN_GPU_KERNEL_H_
#include <vector>
#include <string>
#include <memory>
#include <algorithm>
#include <functional>
#include <utility>
#include <map>
#include <optional>
#include <iostream>
#include "mindspore/core/ops/apply_add_sign.h"
#include "abstract/utils.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/apply_add_sign_impl.cuh"
namespace mindspore {
namespace kernel {
class ApplyAddSignGpuKernelMod : public NativeGpuKernelMod {
public:
ApplyAddSignGpuKernelMod() { ResetResource(); }
~ApplyAddSignGpuKernelMod() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (is_null_input_) {
return true;
}
stream_ptr_ = stream_ptr;
return kernel_func_(this, inputs, workspace, outputs);
}
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
std::vector<KernelAttr> GetOpSupport() override;
void ResetResource() noexcept;
private:
template <typename T, typename S, typename G>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);
using ApplyAddSignFunc = std::function<bool(ApplyAddSignGpuKernelMod *, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &, const std::vector<AddressPtr> &)>;
size_t t_size_{1};
size_t s_size_{1};
size_t g_size_{1};
size_t t_elements_;
size_t s_elements_;
size_t g_elements_;
bool is_null_input_{false};
void *stream_ptr_{nullptr};
ApplyAddSignFunc kernel_func_{};
std::optional<bool> is_input_dynamic_shape_{};
static std::vector<std::pair<KernelAttr, ApplyAddSignFunc>> func_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_APPLY_ADD_SIGN_GPU_KERNEL_H_


@@ -100,7 +100,7 @@ TuplePtr ApplyAddSignInferType(const PrimitivePtr &prim, const std::vector<Abstr
auto sign_decay_type = input_args[kInputIndex4]->BuildType();
auto beta_type = input_args[kInputIndex5]->BuildType();
auto grad_type = input_args[kInputIndex6]->BuildType();
- const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
+ const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
std::map<std::string, TypePtr> args;
(void)args.insert(std::make_pair("var_type", var_type));
(void)args.insert(std::make_pair("m_type", m_type));


@@ -190,6 +190,7 @@ constexpr auto kGatherDGrad = "GatherDGrad";
constexpr auto kGatherDGradV2 = "GatherDGradV2";
// NN
+ constexpr auto kApplyAddSign = "ApplyAddSign";
constexpr auto kAdaptiveMaxPool3D = "AdaptiveMaxPool3D";
constexpr auto kFractionalMaxPool3DWithFixedKsize = "FractionalMaxPool3DWithFixedKsize";
constexpr auto kFractionalMaxPool3DGradWithFixedKsize = "FractionalMaxPool3DGradWithFixedKsize";


@@ -6503,7 +6503,7 @@ class ApplyAddSign(Primitive):
RuntimeError: If the data type of `var`, `accum` and `grad` conversion of Parameter is not supported.
Supported Platforms:
- ``Ascend``
+ ``Ascend`` ``GPU``
Examples:
>>> class Net(nn.Cell):


@@ -0,0 +1,65 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P
import mindspore.common.dtype as mstype
class Net(nn.Cell):
    def __init__(self, var_np, accum_np):
        super(Net, self).__init__()
        self.apply_addsign = P.ApplyAddSign()
        self.var = Parameter(Tensor(var_np), name="var")
        self.accum = Parameter(Tensor(accum_np), name="m")

    def construct(self, lr, alpha, sign_decay, beta, grad):
        z = self.apply_addsign(self.var, self.accum, lr, alpha, sign_decay, beta, grad)
        return z
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_apply_addsign_graph_float32():
    """
    Feature: ApplyAddSign gpu kernel.
    Description: test the ApplyAddSign.
    Expectation: match to np benchmark.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    var_np = np.array([[0.6, 0.4], [0.1, 0.5]]).astype(np.float32)
    accum_np = np.array([[0.6, 0.5], [0.2, 0.6]]).astype(np.float32)
    gradient_np = np.array([[0.3, 0.7], [0.1, 0.8]]).astype(np.float32)
    # NumPy benchmark of the AddSign update with beta=0.9, alpha=1.0, sign_decay=0.99, lr=0.001.
    expect_accum_np = 0.9 * accum_np + (1.0 - 0.9) * gradient_np
    expect_update = (1.0 + 0.99 * np.sign(gradient_np) * np.sign(expect_accum_np)) * gradient_np
    expect_var_np = var_np - (0.001 * expect_update)
    net = Net(var_np, accum_np)
    lr = Tensor(0.001, mstype.float32)
    alpha = Tensor(1.0, mstype.float32)
    sign_decay = Tensor(0.99, mstype.float32)
    beta = Tensor(0.9, mstype.float32)
    grad = Tensor(gradient_np)
    out = net(lr, alpha, sign_decay, beta, grad)
    res_var_mindspore = out[0].asnumpy()
    res_accum_mindspore = out[1].asnumpy()
    eps = np.full((2, 2), 1e-6)
    assert np.all(np.abs(expect_var_np - res_var_mindspore) < eps)
    assert np.all(np.abs(expect_accum_np - res_accum_mindspore) < eps)