!33924 [GPU][OP] add hard shrink gpu kernel and hard shrink grad gpu kernel

Merge pull request !33924 from yangruoqi713/hshrink
This commit is contained in:
i-robot 2022-05-10 03:14:14 +00:00 committed by Gitee
commit ca4c85f2d2
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
16 changed files with 681 additions and 61 deletions

View File

@ -17,7 +17,7 @@
#include "plugin/device/cpu/kernel/hshrink_cpu_kernel.h"
#include <algorithm>
#include "mindspore/core/ops/hshrink.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
@ -35,17 +35,18 @@ std::vector<KernelAttr> HShrinkCpuKernelMod::GetOpSupport() {
bool HShrinkCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
auto kernel_ptr = std::dynamic_pointer_cast<ops::HShrink>(base_operator);
if (!kernel_ptr) {
MS_LOG(ERROR) << "Cast HShrink ops failed!";
return false;
}
kernel_name_ = kernel_ptr->name();
kernel_name_ = base_operator->name();
if (inputs.size() != kHShrinkInputsNum || outputs.size() != kHShrinkOutputsNum) {
MS_LOG(ERROR) << kernel_name_ << ": input and output size should be " << kHShrinkInputsNum << " and "
<< kHShrinkOutputsNum << ", but get " << inputs.size() << " and " << outputs.size();
return false;
}
auto kernel_ptr = std::dynamic_pointer_cast<ops::HShrink>(base_operator);
if (!kernel_ptr) {
MS_LOG(ERROR) << "Cast HShrink ops failed!";
return false;
}
lambd_ = kernel_ptr->get_lambd();
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
@ -70,13 +71,13 @@ bool HShrinkCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &in
MS_ERROR_IF_NULL_W_RET_VAL(output, false);
size_t lens = inputs[0]->size > 0 ? static_cast<size_t>(inputs[0]->size / sizeof(T)) : 1;
auto task = [input, output, this](size_t start, size_t end) {
const float &lambd = this->lambd_;
auto task = [input, output, &lambd](size_t start, size_t end) {
const T positive_lambd = static_cast<T>(lambd);
const T negative_lambd = static_cast<T>(-1 * lambd);
const T zero = static_cast<T>(0);
for (size_t i = start; i < end; i++) {
if (input[i] >= static_cast<T>(-1 * this->lambd_) && input[i] <= static_cast<T>(this->lambd_)) {
output[i] = static_cast<T>(0);
} else {
output[i] = input[i];
}
output[i] = (input[i] >= negative_lambd && input[i] <= positive_lambd) ? zero : input[i];
}
};
ParallelLaunchAutoSearch(task, lens, this, &parallel_search_info_);

View File

@ -17,12 +17,9 @@
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_HSHRINK_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_HSHRINK_CPU_KERNEL_H_
#include <memory>
#include <unordered_map>
#include <vector>
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {

View File

@ -17,7 +17,7 @@
#include "plugin/device/cpu/kernel/hshrink_grad_cpu_kernel.h"
#include <algorithm>
#include "mindspore/core/ops/grad/hshrink_grad.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
@ -35,17 +35,18 @@ std::vector<KernelAttr> HShrinkGradCpuKernelMod::GetOpSupport() {
bool HShrinkGradCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
auto kernel_ptr = std::dynamic_pointer_cast<ops::HShrinkGrad>(base_operator);
if (!kernel_ptr) {
MS_LOG(ERROR) << "Cast HShrinkGrad ops failed!";
return false;
}
kernel_name_ = kernel_ptr->name();
kernel_name_ = base_operator->name();
if (inputs.size() != kHShrinkGradInputsNum || outputs.size() != kHShrinkGradOutputsNum) {
MS_LOG(ERROR) << kernel_name_ << ": input and output size should be " << kHShrinkGradInputsNum << " and "
<< kHShrinkGradOutputsNum << ", but get " << inputs.size() << " and " << outputs.size();
return false;
}
auto kernel_ptr = std::dynamic_pointer_cast<ops::HShrinkGrad>(base_operator);
if (!kernel_ptr) {
MS_LOG(ERROR) << "Cast HShrinkGrad ops failed!";
return false;
}
lambd_ = kernel_ptr->get_lambd();
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
@ -72,13 +73,13 @@ bool HShrinkGradCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr>
MS_ERROR_IF_NULL_W_RET_VAL(dx, false);
size_t lens = inputs[0]->size > 0 ? static_cast<size_t>(inputs[0]->size / sizeof(T)) : 1;
auto task = [dy, x, dx, this](size_t start, size_t end) {
const float &lambd = this->lambd_;
auto task = [dy, x, dx, &lambd](size_t start, size_t end) {
const T positive_lambd = static_cast<T>(lambd);
const T negative_lambd = static_cast<T>(-1 * lambd);
const T zero = static_cast<T>(0);
for (size_t i = start; i < end; i++) {
if (x[i] >= static_cast<T>(-1 * this->lambd_) && x[i] <= static_cast<T>(this->lambd_)) {
dx[i] = static_cast<T>(0);
} else {
dx[i] = dy[i];
}
dx[i] = (x[i] >= negative_lambd && x[i] <= positive_lambd) ? zero : dy[i];
}
};
ParallelLaunchAutoSearch(task, lens, this, &parallel_search_info_);

View File

@ -17,12 +17,9 @@
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_HSHRINK_GRAD_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_HSHRINK_GRAD_CPU_KERNEL_H_
#include <memory>
#include <unordered_map>
#include <vector>
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {

View File

@ -0,0 +1,63 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/hshrink_impl.cuh"
#include "include/cuda_fp16.h"
template <typename T>
__global__ void HShrinkKernel(size_t size, const T *input, const float lambd, T *output) {
const T positive_lambd = static_cast<T>(lambd);
const T negative_lambd = static_cast<T>(-1 * lambd);
const T zero = static_cast<T>(0);
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
output[pos] = (input[pos] >= negative_lambd && input[pos] <= positive_lambd) ? zero : input[pos];
}
}
template <typename T>
__global__ void HShrinkGradKernel(size_t size, const T *dout, const T *x, const float lambd, T *output) {
const T positive_lambd = static_cast<T>(lambd);
const T negative_lambd = static_cast<T>(-1 * lambd);
const T zero = static_cast<T>(0);
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
output[pos] = (x[pos] >= negative_lambd && x[pos] <= positive_lambd) ? zero : dout[pos];
}
}
template <typename T>
void CalHShrink(const size_t &size, const T *input, const float lambd, T *output, const uint32_t &device_id,
cudaStream_t cuda_stream) {
HShrinkKernel<<<CUDA_BLOCKS(device_id, size), CUDA_THREADS(device_id), 0, cuda_stream>>>(size, input, lambd, output);
}
template <typename T>
void CalHShrinkGrad(const size_t &size, const T *dout, const T *x, const float lambd, T *output,
const uint32_t &device_id, cudaStream_t cuda_stream) {
HShrinkGradKernel<<<CUDA_BLOCKS(device_id, size), CUDA_THREADS(device_id), 0, cuda_stream>>>(size, dout, x, lambd,
output);
}
template CUDA_LIB_EXPORT void CalHShrink<half>(const size_t &size, const half *input, const float lambd, half *output,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalHShrink<float>(const size_t &size, const float *input, const float lambd,
float *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalHShrinkGrad<half>(const size_t &size, const half *dout, const half *x,
const float lambd, half *output, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalHShrinkGrad<float>(const size_t &size, const float *dout, const float *x,
const float lambd, float *output, const uint32_t &device_id,
cudaStream_t cuda_stream);

View File

@ -0,0 +1,29 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_HSHRINK_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_HSHRINK_IMPL_CUH_
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
template <typename T>
CUDA_LIB_EXPORT void CalHShrink(const size_t &size, const T *input, const float lambd, T *output,
const uint32_t &device_id, cudaStream_t cuda_stream);
template <typename T>
CUDA_LIB_EXPORT void CalHShrinkGrad(const size_t &size, const T *dout, const T *x, const float lambd, T *output,
const uint32_t &device_id, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_HSHRINK_IMPL_CUH_

View File

@ -0,0 +1,120 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/nn/hshrink_gpu_kernel.h"
#include <algorithm>
#include <functional>
#include "mindspore/core/ops/hshrink.h"
#include "abstract/utils.h"
#include "plugin/factory/ms_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/hshrink_impl.cuh"
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kHShrinkInputsNum = 1;
constexpr size_t kHShrinkOutputsNum = 1;
} // namespace
bool HShrinkGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
kernel_name_ = base_operator->name();
if (inputs.size() != kHShrinkInputsNum || outputs.size() != kHShrinkOutputsNum) {
MS_LOG(ERROR) << kernel_name_ << ": input and output size should be " << kHShrinkInputsNum << " and "
<< kHShrinkOutputsNum << ", but get " << inputs.size() << " and " << outputs.size();
return false;
}
auto kernel_ptr = std::dynamic_pointer_cast<ops::HShrink>(base_operator);
if (!kernel_ptr) {
MS_LOG(ERROR) << "Cast HShrink ops failed!";
return false;
}
lambd_ = kernel_ptr->get_lambd();
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "' does not support this kernel type: " << kernel_attr;
return false;
}
kernel_func_ = func_list_[index].second;
unit_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).first);
return true;
}
int HShrinkGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
ResetResource();
for (const auto &input : inputs) {
// If any input shape contains -1, means input shape is dynamic, so just return do nothing.
auto input_shape = input->GetShapeVector();
if (!IsValidShape(input_shape)) {
return KRET_INVALID_SHAPE;
}
}
auto input_shape = inputs.at(kIndex0)->GetShapeVector();
(void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(input_shape_), LongToSize);
input_elements_ = std::accumulate(input_shape_.begin(), input_shape_.end(), 1, std::multiplies<size_t>());
if (input_elements_ == 0) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "' input size must be greater than zero.";
return KRET_RESIZE_FAILED;
}
InitSizeLists();
return KRET_OK;
}
void HShrinkGpuKernelMod::ResetResource() noexcept {
input_elements_ = 0;
input_shape_.clear();
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
void HShrinkGpuKernelMod::InitSizeLists() {
size_t input_size = input_elements_ * unit_size_;
input_size_list_.push_back(input_size);
output_size_list_.push_back(input_size);
}
template <typename T>
bool HShrinkGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
T *input = GetDeviceAddress<T>(inputs, kIndex0);
T *output = GetDeviceAddress<T>(outputs, kIndex0);
CalHShrink(input_elements_, input, lambd_, output, device_id_, reinterpret_cast<cudaStream_t>(cuda_stream_));
return true;
}
std::vector<std::pair<KernelAttr, HShrinkGpuKernelMod::HShrinkFunc>> HShrinkGpuKernelMod::func_list_ = {
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
&HShrinkGpuKernelMod::LaunchKernel<half>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&HShrinkGpuKernelMod::LaunchKernel<float>}};
std::vector<KernelAttr> HShrinkGpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, HShrinkFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, HShrink, HShrinkGpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,64 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_HSHRINK_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_HSHRINK_GPU_KERNEL_H_
#include <vector>
#include <map>
#include <utility>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
namespace mindspore {
namespace kernel {
class HShrinkGpuKernelMod : public NativeGpuKernelMod {
public:
HShrinkGpuKernelMod() { ResetResource(); }
~HShrinkGpuKernelMod() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *cuda_stream) override {
cuda_stream_ = cuda_stream;
return kernel_func_(this, inputs, outputs);
}
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
std::vector<KernelAttr> GetOpSupport() override;
private:
void ResetResource() noexcept;
void InitSizeLists();
template <typename T>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
using HShrinkFunc = std::function<bool(HShrinkGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &)>;
std::vector<size_t> input_shape_;
size_t unit_size_{1};
size_t input_elements_{};
void *cuda_stream_{nullptr};
float lambd_ = 0.f;
HShrinkFunc kernel_func_{};
static std::vector<std::pair<KernelAttr, HShrinkFunc>> func_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_HSHRINK_GPU_KERNEL_H_

View File

@ -0,0 +1,122 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/nn/hshrink_grad_gpu_kernel.h"
#include <algorithm>
#include <functional>
#include "mindspore/core/ops/grad/hshrink_grad.h"
#include "abstract/utils.h"
#include "plugin/factory/ms_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/hshrink_impl.cuh"
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kHShrinkGradInputsNum = 2;
constexpr size_t kHShrinkGradOutputsNum = 1;
} // namespace
bool HShrinkGradGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
kernel_name_ = base_operator->name();
if (inputs.size() != kHShrinkGradInputsNum || outputs.size() != kHShrinkGradOutputsNum) {
MS_LOG(ERROR) << kernel_name_ << ": input and output size should be " << kHShrinkGradInputsNum << " and "
<< kHShrinkGradOutputsNum << ", but get " << inputs.size() << " and " << outputs.size();
return false;
}
auto kernel_ptr = std::dynamic_pointer_cast<ops::HShrinkGrad>(base_operator);
if (!kernel_ptr) {
MS_LOG(ERROR) << "Cast HShrinkGrad ops failed!";
return false;
}
lambd_ = kernel_ptr->get_lambd();
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "' does not support this kernel type: " << kernel_attr;
return false;
}
kernel_func_ = func_list_[index].second;
unit_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).first);
return true;
}
int HShrinkGradGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
ResetResource();
for (const auto &input : inputs) {
// If any input shape contains -1, means input shape is dynamic, so just return do nothing.
auto input_shape = input->GetShapeVector();
if (!IsValidShape(input_shape)) {
return KRET_INVALID_SHAPE;
}
}
auto input_shape = inputs.at(kIndex0)->GetShapeVector();
(void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(input_shape_), LongToSize);
input_elements_ = std::accumulate(input_shape_.begin(), input_shape_.end(), 1, std::multiplies<size_t>());
if (input_elements_ == 0) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "' input size must be greater than zero.";
return KRET_RESIZE_FAILED;
}
InitSizeLists();
return KRET_OK;
}
void HShrinkGradGpuKernelMod::ResetResource() noexcept {
input_elements_ = 0;
input_shape_.clear();
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
void HShrinkGradGpuKernelMod::InitSizeLists() {
size_t input_size = input_elements_ * unit_size_;
input_size_list_.push_back(input_size);
output_size_list_.push_back(input_size);
}
template <typename T>
bool HShrinkGradGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs) {
T *dy = GetDeviceAddress<T>(inputs, kIndex0);
T *x = GetDeviceAddress<T>(inputs, kIndex1);
T *dx = GetDeviceAddress<T>(outputs, kIndex0);
CalHShrinkGrad(input_elements_, dy, x, lambd_, dx, device_id_, reinterpret_cast<cudaStream_t>(cuda_stream_));
return true;
}
std::vector<std::pair<KernelAttr, HShrinkGradGpuKernelMod::HShrinkGradFunc>> HShrinkGradGpuKernelMod::func_list_ = {
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
&HShrinkGradGpuKernelMod::LaunchKernel<half>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&HShrinkGradGpuKernelMod::LaunchKernel<float>}};
std::vector<KernelAttr> HShrinkGradGpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, HShrinkGradFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, HShrinkGrad, HShrinkGradGpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,64 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_HSHRINK_GRAD_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_HSHRINK_GRAD_GPU_KERNEL_H_
#include <vector>
#include <map>
#include <utility>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
namespace mindspore {
namespace kernel {
class HShrinkGradGpuKernelMod : public NativeGpuKernelMod {
public:
HShrinkGradGpuKernelMod() { ResetResource(); }
~HShrinkGradGpuKernelMod() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *cuda_stream) override {
cuda_stream_ = cuda_stream;
return kernel_func_(this, inputs, outputs);
}
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
std::vector<KernelAttr> GetOpSupport() override;
private:
void ResetResource() noexcept;
void InitSizeLists();
template <typename T>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
using HShrinkGradFunc = std::function<bool(HShrinkGradGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &)>;
std::vector<size_t> input_shape_;
size_t unit_size_{1};
size_t input_elements_{};
void *cuda_stream_{nullptr};
float lambd_ = 0.f;
HShrinkGradFunc kernel_func_{};
static std::vector<std::pair<KernelAttr, HShrinkGradFunc>> func_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_HSHRINK_GRAD_GPU_KERNEL_H_

View File

@ -2292,7 +2292,7 @@ class HShrinkGrad(Primitive):
TypeError: If dtype of `gradients` or `features` is neither float16 nor float32.
Supported Platforms:
``Ascend`` ``CPU``
``Ascend`` ``CPU`` ``GPU``
"""
@prim_attr_register

View File

@ -8455,7 +8455,7 @@ class HShrink(Primitive):
Tensor, the same shape and data type as the input.
Supported Platforms:
``Ascend`` ``CPU``
``Ascend`` ``CPU`` ``GPU``
Raises:
TypeError: If `lambd` is not a float.

View File

@ -19,36 +19,50 @@ import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.parameter import Parameter
from mindspore.ops.operations import _grad_ops as G
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
def hshrink_grad_op_np_bencmark(grad, input_x, lambd):
"""
Feature: generate a hshrink grad numpy benchmark.
Description: The input shape need to match to output shape.
Expectation: match to mindspore HShrinkGrad.
"""
result = np.zeros_like(grad, dtype=grad.dtype)
for index, _ in np.ndenumerate(grad):
if input_x[index] > lambd or input_x[index] < (-1 * lambd):
result[index] = grad[index]
else:
result[index] = 0
return result
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('dtype', [np.float16, np.float32])
def test_hshrink(dtype):
@pytest.mark.parametrize("data_shape", [(3, 4), (4, 5, 6, 7)])
@pytest.mark.parametrize("lambd", [0.5])
def test_hshrink_grad(dtype, data_shape, lambd):
"""
Feature: HShrinkGrad cpu kernel
Description: test the rightness of HShrinkGrad cpu kernel
Expectation: the output[0] is same as numpy
Expectation: the output[0] is same as hshrink_grad_op_np_bencmark output
"""
class NetHShrinkGrad(nn.Cell):
def __init__(self):
super(NetHShrinkGrad, self).__init__()
self.hard_shrink_grad = G.HShrinkGrad(lambd=0.5)
self.gradients = Parameter(Tensor(np.array([[0.02979, 0.287, 0.676],
[0.2837, 0.1216, -0.6543]], dtype=dtype)), name='gradients')
self.features = Parameter(Tensor(np.array([[0.5, 1, 2.0],
[0.0533, 0.0776, -2.1233]], dtype=dtype)), name='features')
self.hard_shrink_grad = G.HShrinkGrad(lambd)
def construct(self):
return self.hard_shrink_grad(self.gradients, self.features)
def construct(self, grad, input_x):
return self.hard_shrink_grad(grad, input_x)
grad_data = np.random.random(data_shape).astype(dtype)
input_data = np.random.uniform(
low=-1, high=1, size=data_shape).astype(dtype)
benchmark_output = hshrink_grad_op_np_bencmark(
grad_data, input_data, lambd)
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
hshrink_grad = NetHShrinkGrad()
output = hshrink_grad()
expect = np.array([[0, 0.287, 0.676],
[0, 0, -0.6543]], dtype=dtype)
assert np.allclose(output.asnumpy(), expect)
output = hshrink_grad(Tensor(grad_data), Tensor(input_data))
assert np.allclose(output.asnumpy(), benchmark_output)

View File

@ -19,17 +19,31 @@ import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.parameter import Parameter
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
def hshrink_op_np_bencmark(input_x, lambd):
"""
Feature: generate a hshrink numpy benchmark.
Description: The input shape need to match to output shape.
Expectation: match to np mindspore HShrink.
"""
result = np.zeros_like(input_x, dtype=input_x.dtype)
for index, _ in np.ndenumerate(input_x):
if input_x[index] > lambd or input_x[index] < (-1 * lambd):
result[index] = input_x[index]
else:
result[index] = 0
return result
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('dtype', [np.float16, np.float32])
def test_hshrink(dtype):
@pytest.mark.parametrize("data_shape", [(3, 4), (4, 5, 6, 7)])
@pytest.mark.parametrize("lambd", [0.5])
def test_hshrink(dtype, data_shape, lambd):
"""
Feature: HShrink cpu kernel
Description: test the rightness of HShrink cpu kernel
@ -38,15 +52,15 @@ def test_hshrink(dtype):
class NetHShrink(nn.Cell):
def __init__(self):
super(NetHShrink, self).__init__()
self.hard_shrink = P.HShrink(lambd=0.5)
self.x = Parameter(Tensor(np.array([[0.5, 1, 2.0],
[0.0533, 0.0776, -2.1233]], dtype=dtype)), name='x')
self.hard_shrink = P.HShrink(lambd)
def construct(self):
return self.hard_shrink(self.x)
def construct(self, input_x):
return self.hard_shrink(input_x)
input_data = np.random.uniform(
low=-1, high=1, size=data_shape).astype(dtype)
benchmark_output = hshrink_op_np_bencmark(input_data, lambd)
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
hshrink = NetHShrink()
output = hshrink()
expect = np.array([[0, 1, 2],
[0, 0, -2.1233]], dtype=dtype)
assert np.allclose(output.asnumpy(), expect)
output = hshrink(Tensor(input_data))
assert np.allclose(output.asnumpy(), benchmark_output)

View File

@ -0,0 +1,68 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.operations import _grad_ops as G
def hshrink_grad_op_np_bencmark(grad, input_x, lambd):
"""
Feature: generate a hshrink grad numpy benchmark.
Description: The input shape need to match to output shape.
Expectation: match to mindspore HShrinkGrad.
"""
result = np.zeros_like(grad, dtype=grad.dtype)
for index, _ in np.ndenumerate(grad):
if input_x[index] > lambd or input_x[index] < (-1 * lambd):
result[index] = grad[index]
else:
result[index] = 0
return result
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('dtype', [np.float16, np.float32])
@pytest.mark.parametrize("data_shape", [(3, 4), (4, 5, 6, 7)])
@pytest.mark.parametrize("lambd", [0.5])
def test_hshrink_grad(dtype, data_shape, lambd):
"""
Feature: HShrinkGrad gpu kernel
Description: test the rightness of HShrinkGrad gpu kernel
Expectation: the output[0] is same as hshrink_grad_op_np_bencmark output
"""
class NetHShrinkGrad(nn.Cell):
def __init__(self):
super(NetHShrinkGrad, self).__init__()
self.hard_shrink_grad = G.HShrinkGrad(lambd)
def construct(self, grad, input_x):
return self.hard_shrink_grad(grad, input_x)
grad_data = np.random.random(data_shape).astype(dtype)
input_data = np.random.uniform(
low=-1, high=1, size=data_shape).astype(dtype)
benchmark_output = hshrink_grad_op_np_bencmark(
grad_data, input_data, lambd)
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
hshrink_grad = NetHShrinkGrad()
output = hshrink_grad(Tensor(grad_data), Tensor(input_data))
assert np.allclose(output.asnumpy(), benchmark_output)

View File

@ -0,0 +1,66 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
def hshrink_op_np_bencmark(input_x, lambd):
"""
Feature: generate a hshrink numpy benchmark.
Description: The input shape need to match to output shape.
Expectation: match to nn mindspore HShrink.
"""
result = np.zeros_like(input_x, dtype=input_x.dtype)
for index, _ in np.ndenumerate(input_x):
if input_x[index] > lambd or input_x[index] < (-1 * lambd):
result[index] = input_x[index]
else:
result[index] = 0
return result
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('dtype', [np.float16, np.float32])
@pytest.mark.parametrize("data_shape", [(3, 4), (4, 5, 6, 7)])
@pytest.mark.parametrize("lambd", [0.5])
def test_hshrink(dtype, data_shape, lambd):
"""
Feature: HShrink gpu kernel
Description: test the rightness of HShrink gpu kernel
Expectation: the output[0] is same as hshrink_op_np_bencmark output
"""
class NetHShrink(nn.Cell):
def __init__(self):
super(NetHShrink, self).__init__()
self.hard_shrink = P.HShrink(lambd)
def construct(self, input_x):
return self.hard_shrink(input_x)
input_data = np.random.uniform(
low=-1, high=1, size=data_shape).astype(dtype)
benchmark_output = hshrink_op_np_bencmark(input_data, lambd)
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
hshrink = NetHShrink()
output = hshrink(Tensor(input_data))
assert np.allclose(output.asnumpy(), benchmark_output)