!35243 Refactor the Activation and ActivationGrad for GPU.

Merge pull request !35243 from liqiliang/mishgrad-gpu
i-robot 2022-06-01 02:36:04 +00:00 committed by Gitee
commit 2838edf7a5
10 changed files with 567 additions and 380 deletions

View File

@@ -131,6 +131,22 @@ class NativeGpuKernelMod : public GpuKernelMod {
protected:
virtual void InitResource() {}
// choose the suitable datatype for cudnn/cublas
inline cudnnDataType_t GetCudnnDataType(const std::string &Type) {
auto type = kCudnnDtypeMap.find(Type);
if (type == kCudnnDtypeMap.end()) {
MS_EXCEPTION(TypeError) << Type << " is not supported.";
}
return type->second;
}
inline cudaDataType_t GetCudaDataType(const std::string &Type) {
auto type = kCudaDtypeMap.find(Type);
if (type == kCudaDtypeMap.end()) {
MS_EXCEPTION(TypeError) << Type << " is not supported.";
}
return type->second;
}
uint32_t device_id_;
static mindspore::HashMap<std::string, std::vector<KernelAttr>> support_map_;
};
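These protected helpers add the cuDNN/cuBLAS data-type lookup to the non-deprecated NativeGpuKernelMod base, so the refactored kernels below, which no longer derive from DeprecatedNativeGpuKernelMod, can resolve their cuDNN type directly from the input tensor. Usage as it appears later in this patch, inside Resize of the activation kernels:

// Usage from activation_gpu_kernel.cc / activation_grad_kernel.cc in this patch:
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(inputs.at(kIndex0)->GetDtype()));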

View File

@@ -272,7 +272,7 @@ bool UnaryOpGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::
kernel_name_ = base_operator->name();
auto iter = kernel_attr_map_.find(kernel_name_);
if (iter == kernel_attr_map_.end()) {
MS_LOG(ERROR) << "For 'Unary op', the kernel name must be in" << kernel::Map2Str(kernel_attr_map_) << ", but got "
MS_LOG(ERROR) << "For 'Unary op', the kernel name must be in " << kernel::Map2Str(kernel_attr_map_) << ", but got "
<< kernel_name_;
return false;
}
@@ -308,7 +308,7 @@ int UnaryOpGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std:
std::vector<KernelAttr> UnaryOpGpuKernelMod::GetOpSupport() {
auto iter = kernel_attr_map_.find(kernel_name_);
if (iter == kernel_attr_map_.end()) {
MS_LOG(ERROR) << "For 'Unary op', the kernel name must be in" << kernel::Map2Str(kernel_attr_map_) << ", but got "
MS_LOG(ERROR) << "For 'Unary op', the kernel name must be in " << kernel::Map2Str(kernel_attr_map_) << ", but got "
<< kernel_name_;
return std::vector<KernelAttr>{};
}

View File

@@ -15,27 +15,165 @@
*/
#include "plugin/device/gpu/kernel/nn/activation_gpu_kernel.h"
#include <memory>
#include "ops/elu.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(ReLU6, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ActivationFwdGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(ReLU6, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ActivationFwdGpuKernelMod, half)
namespace {
constexpr auto kReLU6 = "ReLU6";
constexpr auto kTanh = "Tanh";
constexpr auto kElu = "Elu";
constexpr auto kSigmoid = "Sigmoid";
} // namespace
MS_REG_GPU_KERNEL_ONE(Tanh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ActivationFwdGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(Tanh, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ActivationFwdGpuKernelMod, half)
std::map<std::string, std::vector<std::pair<KernelAttr, ActivationFwdGpuKernelMod::ActivationFunc>>>
ActivationFwdGpuKernelMod::kernel_attr_map_ = {
{kReLU6,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&ActivationFwdGpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
&ActivationFwdGpuKernelMod::LaunchKernel<half>}}},
{kTanh,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&ActivationFwdGpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
&ActivationFwdGpuKernelMod::LaunchKernel<half>}}},
{kElu,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&ActivationFwdGpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
&ActivationFwdGpuKernelMod::LaunchKernel<half>}}},
{kSigmoid,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&ActivationFwdGpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
&ActivationFwdGpuKernelMod::LaunchKernel<half>}}}};
MS_REG_GPU_KERNEL_ONE(Elu, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ActivationFwdGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(Elu, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ActivationFwdGpuKernelMod, half)
bool ActivationFwdGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
kernel_name_ = base_operator->name();
auto iter = kernel_attr_map_.find(kernel_name_);
if (iter == kernel_attr_map_.end()) {
MS_LOG(ERROR) << "For 'Activation', the kernel name must be in " << kernel::Map2Str(kernel_attr_map_)
<< ", but got " << kernel_name_;
return false;
}
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
return false;
}
kernel_func_ = kernel_attr_map_.at(kernel_name_)[index].second;
return true;
}
MS_REG_GPU_KERNEL_ONE(Sigmoid, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ActivationFwdGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(Sigmoid, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ActivationFwdGpuKernelMod, half)
int ActivationFwdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
static const std::map<std::string, cudnnActivationMode_t> activation_mode_map = {
{kReLU6, CUDNN_ACTIVATION_CLIPPED_RELU},
{kTanh, CUDNN_ACTIVATION_TANH},
{kElu, CUDNN_ACTIVATION_ELU},
{kSigmoid, CUDNN_ACTIVATION_SIGMOID}};
if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
size_t input_num = inputs.size();
if (input_num != 1) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the number of inputs must be 1, but got " << input_num;
return KRET_RESIZE_FAILED;
}
auto input_shape = inputs.at(kIndex0)->GetShapeVector();
input_shape_.clear();
(void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(input_shape_), LongToSize);
size_t input_element_num = std::accumulate(input_shape_.begin(), input_shape_.end(), 1, std::multiplies<size_t>());
is_null_input_ = (input_element_num == 0);
if (is_null_input_) {
return KRET_OK;
}
auto iter = activation_mode_map.find(kernel_name_);
if (iter == activation_mode_map.end()) {
MS_LOG(ERROR) << "For '" << kernel_name_
<< "', only support these activations: " << kernel::Map2Str(activation_mode_map) << ", but got "
<< kernel_name_;
return KRET_RESIZE_FAILED;
}
mode_ = iter->second;
cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnCreateTensorDescriptor(&data_descriptor_),
"For 'Activation', cudnnCreateTensorDescriptor failed.");
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnCreateActivationDescriptor(&activation_desc_),
"For 'Activation', cudnnCreateActivationDescriptor failed.");
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(inputs.at(kIndex0)->GetDtype()));
CheckTensorSize({input_shape_});
std::vector<size_t> shape;
double coef = (mode_ == CUDNN_ACTIVATION_CLIPPED_RELU) ? 6.0 : 0.0;
if (mode_ == CUDNN_ACTIVATION_ELU) {
auto elu_ptr = std::dynamic_pointer_cast<ops::Elu>(base_operator);
MS_EXCEPTION_IF_NULL(elu_ptr);
float alpha = elu_ptr->get_alpha();
coef = static_cast<double>(alpha);
}
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_NOT_PROPAGATE_NAN, coef),
"For 'Activation', cudnnSetActivationDescriptor failed.");
const int split_dim = 4;
if (input_shape.size() <= split_dim) {
ShapeNdTo4d(input_shape_, &shape);
if (inputs.at(kIndex0)->GetFormat() == mindspore::Format::NHWC) {
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NHWC, cudnn_data_type_, SizeToInt(shape[0]),
SizeToInt(shape[3]), SizeToInt(shape[1]), SizeToInt(shape[2])),
"For 'Activation', cudnnSetTensor4dDescriptor failed.");
} else {
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(shape[0]),
SizeToInt(shape[1]), SizeToInt(shape[2]), SizeToInt(shape[3])),
"For 'Activation', cudnnSetTensor4dDescriptor failed.");
}
} else {
CudnnSetTensorNdDescriptor(input_shape_, data_descriptor_, cudnn_data_type_, kernel_name_);
}
return KRET_OK;
}
std::vector<KernelAttr> ActivationFwdGpuKernelMod::GetOpSupport() {
auto iter = kernel_attr_map_.find(kernel_name_);
if (iter == kernel_attr_map_.end()) {
MS_LOG(ERROR) << "For 'Activation', the kernel name must be in " << kernel::Map2Str(kernel_attr_map_)
<< ", but got " << kernel_name_;
return std::vector<KernelAttr>{};
}
std::vector<KernelAttr> support_list;
(void)std::transform(iter->second.begin(), iter->second.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, ActivationFunc> &item) { return item.first; });
return support_list;
}
template <typename T>
bool ActivationFwdGpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
T *input = GetDeviceAddress<T>(inputs, kIndex0);
T *output = GetDeviceAddress<T>(outputs, kIndex0);
constexpr float alpha = 1;
constexpr float beta = 0;
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnActivationForward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_,
input, &beta, data_descriptor_, output),
"For 'Activation', cudnnActivationForward failed.");
return true;
}
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, ReLU6,
[]() { return std::make_shared<ActivationFwdGpuKernelMod>(kReLU6); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, Tanh,
[]() { return std::make_shared<ActivationFwdGpuKernelMod>(kTanh); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, Elu,
[]() { return std::make_shared<ActivationFwdGpuKernelMod>(kElu); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, Sigmoid,
[]() { return std::make_shared<ActivationFwdGpuKernelMod>(kSigmoid); });
} // namespace kernel
} // namespace mindspore
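The shape of the refactor above is a static per-kernel-name table of (KernelAttr, launcher) pairs: Init() looks the table up by operator name, matches the requested data types via MatchKernelAttr, and stores the chosen launcher in kernel_func_, while MS_KERNEL_FACTORY_REG_BY_CREATOR registers one creator lambda per operator name. A hypothetical, framework-free sketch of the same pattern follows; Attr, Launcher, ActivationMod, and Table are illustrative names, not MindSpore APIs.

// Hypothetical, self-contained illustration of the table-driven dispatch used above.
#include <cstdint>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct Attr { std::string dtype; };  // stand-in for KernelAttr

class ActivationMod {
 public:
  explicit ActivationMod(std::string name) : name_(std::move(name)) {}

  bool Init(const Attr &requested) {
    auto iter = Table().find(name_);
    if (iter == Table().end()) {
      return false;  // unknown kernel name, mirrors the MS_LOG(ERROR) path
    }
    for (const auto &entry : iter->second) {  // MatchKernelAttr analogue
      if (entry.first.dtype == requested.dtype) {
        launcher_ = entry.second;
        return true;
      }
    }
    return false;  // unsupported data type
  }

  bool Launch() { return launcher_ ? launcher_(this) : false; }

 private:
  using Launcher = std::function<bool(ActivationMod *)>;

  template <typename T>
  bool LaunchKernel() {  // the real kernel calls cudnnActivationForward here
    std::cout << name_ << ": launched with " << sizeof(T) << "-byte elements\n";
    return true;
  }

  static const std::map<std::string, std::vector<std::pair<Attr, Launcher>>> &Table() {
    static const std::map<std::string, std::vector<std::pair<Attr, Launcher>>> table = {
        {"ReLU6",
         {{Attr{"float32"}, [](ActivationMod *p) { return p->LaunchKernel<float>(); }},
          {Attr{"float16"}, [](ActivationMod *p) { return p->LaunchKernel<uint16_t>(); }}}},
        {"Tanh",
         {{Attr{"float32"}, [](ActivationMod *p) { return p->LaunchKernel<float>(); }}}}};
    return table;
  }

  std::string name_;
  Launcher launcher_;
};

int main() {
  // Factory-by-creator analogue: each registration captures its own kernel name.
  auto creator = []() { return std::make_shared<ActivationMod>("ReLU6"); };
  auto kernel = creator();
  if (kernel->Init(Attr{"float32"})) {
    kernel->Launch();
  }
  return 0;
}

The real kernels store member-function pointers such as &ActivationFwdGpuKernelMod::LaunchKernel<float> and invoke them as kernel_func_(this, inputs, outputs) from Launch.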

View File

@@ -17,9 +17,12 @@
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_ACTIVATION_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_ACTIVATION_GPU_KERNEL_H_
#include <functional>
#include <vector>
#include <map>
#include <string>
#include <map>
#include <utility>
#include <algorithm>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/kernel_constants.h"
@@ -27,142 +30,55 @@
namespace mindspore {
namespace kernel {
template <typename T>
class ActivationFwdGpuKernelMod : public DeprecatedNativeGpuKernelMod {
constexpr auto kUnKnown = "UnKnown";
class ActivationFwdGpuKernelMod : public NativeGpuKernelMod {
public:
ActivationFwdGpuKernelMod() { ResetResource(); }
~ActivationFwdGpuKernelMod() override { DestroyResource(); }
explicit ActivationFwdGpuKernelMod(const std::string &kernel_name) : kernel_name_(kernel_name) {}
~ActivationFwdGpuKernelMod() override { DestroyResource(); };
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(
const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *) override {
if (is_null_input_) {
return true;
}
T *input = GetDeviceAddress<T>(inputs, 0);
T *output = GetDeviceAddress<T>(outputs, 0);
const float alpha = 1;
const float beta = 0;
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnActivationForward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_, input,
&beta, data_descriptor_, output),
"cudnnActivationForward failed");
return true;
}
bool Init(const CNodePtr &kernel_node) override {
kernel_node_ = kernel_node;
auto node_name = common::AnfAlgo::GetCNodeName(kernel_node);
auto iter = kernel_map.find(node_name);
if (iter == kernel_map.end()) {
MS_LOG(EXCEPTION) << "Only support these activations: ReLU6, Tanh, Elu, Sigmoid currently, but got " << node_name;
}
mode_ = iter->second;
InitResource();
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) {
MS_LOG(EXCEPTION) << "For '" << node_name << "', the number of inputs must be 1, but got " << input_num;
}
auto input_shape = AnfAlgo::GetInputDeviceShapeAdaptively(kernel_node, 0);
is_null_input_ = CHECK_SHAPE_NULL(input_shape, node_name, "input");
if (is_null_input_) {
InitSizeLists();
return true;
}
CheckTensorSize({input_shape});
std::vector<size_t> shape;
double coef = (mode_ == CUDNN_ACTIVATION_CLIPPED_RELU) ? 6.0 : 0.0;
if (mode_ == CUDNN_ACTIVATION_ELU) {
float alpha = GetAttr<float>(kernel_node, "alpha");
coef = static_cast<double>(alpha);
}
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_NOT_PROPAGATE_NAN, coef),
"cudnnSetActivationDescriptor failed");
const int split_dim = 4;
if (input_shape.size() <= split_dim) {
ShapeNdTo4d(input_shape, &shape);
if (AnfAlgo::GetInputFormat(kernel_node, 0) == kOpFormat_NHWC) {
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NHWC, cudnn_data_type_, SizeToInt(shape[0]),
SizeToInt(shape[3]), SizeToInt(shape[1]), SizeToInt(shape[2])),
"cudnnSetTensor4dDescriptor failed");
} else {
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(shape[0]),
SizeToInt(shape[1]), SizeToInt(shape[2]), SizeToInt(shape[3])),
"cudnnSetTensor4dDescriptor failed");
}
} else {
CudnnSetTensorNdDescriptor(input_shape, data_descriptor_, cudnn_data_type_, kernel_node_);
}
InitSizeLists();
return true;
return kernel_func_(this, inputs, outputs);
}
void DestroyResource() noexcept override {
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyActivationDescriptor(activation_desc_),
"cudnnDestroyActivationDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(data_descriptor_),
"cudnnDestroyTensorDescriptor failed");
}
void ResetResource() noexcept override {
cudnn_handle_ = nullptr;
activation_desc_ = nullptr;
mode_ = CUDNN_ACTIVATION_SIGMOID;
data_descriptor_ = nullptr;
is_null_input_ = false;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
cudnn_data_type_ = CUDNN_DATA_FLOAT;
input_size_ = 0;
output_size_ = 0;
workspace_size_ = 0;
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnDestroyActivationDescriptor(activation_desc_),
"For 'Activation', cudnnDestroyActivationDescriptor failed.");
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnDestroyTensorDescriptor(data_descriptor_),
"For 'Activation', cudnnDestroyTensorDescriptor failed.");
}
protected:
void InitResource() override {
cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&data_descriptor_),
"cudnnCreateTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateActivationDescriptor(&activation_desc_),
"cudnnCreateActivationDescriptor failed");
}
void InitSizeLists() override {
if (!is_null_input_) {
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(data_descriptor_, &input_size_),
"cudnnGetTensorSizeInBytes failed");
output_size_ = input_size_;
}
input_size_list_.push_back(input_size_);
output_size_list_.push_back(output_size_);
workspace_size_list_.push_back(workspace_size_);
}
std::vector<KernelAttr> GetOpSupport() override;
private:
std::map<std::string, cudnnActivationMode_t> kernel_map = {{"ReLU6", CUDNN_ACTIVATION_CLIPPED_RELU},
{"Tanh", CUDNN_ACTIVATION_TANH},
{"Elu", CUDNN_ACTIVATION_ELU},
{"Sigmoid", CUDNN_ACTIVATION_SIGMOID}};
cudnnHandle_t cudnn_handle_;
cudnnActivationDescriptor_t activation_desc_;
cudnnActivationMode_t mode_;
cudnnTensorDescriptor_t data_descriptor_;
bool is_null_input_;
cudnnDataType_t cudnn_data_type_;
size_t input_size_;
size_t output_size_;
size_t workspace_size_;
template <typename T>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
using ActivationFunc = std::function<bool(ActivationFwdGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &)>;
static std::map<std::string, std::vector<std::pair<KernelAttr, ActivationFwdGpuKernelMod::ActivationFunc>>>
kernel_attr_map_;
std::string kernel_name_{kUnKnown};
ActivationFunc kernel_func_;
std::vector<size_t> input_shape_{};
bool is_null_input_{true};
cudnnHandle_t cudnn_handle_{nullptr};
cudnnActivationDescriptor_t activation_desc_{nullptr};
cudnnActivationMode_t mode_{CUDNN_ACTIVATION_SIGMOID};
cudnnTensorDescriptor_t data_descriptor_{nullptr};
cudnnDataType_t cudnn_data_type_{CUDNN_DATA_FLOAT};
};
} // namespace kernel
} // namespace mindspore

View File

@@ -15,43 +15,169 @@
*/
#include "plugin/device/gpu/kernel/nn/activation_grad_kernel.h"
#include <memory>
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
ReLU6Grad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ActivationGradGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(
ReLU6Grad,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ActivationGradGpuKernelMod, half)
namespace {
constexpr auto kReLU6Grad = "ReLU6Grad";
constexpr auto kTanhGrad = "TanhGrad";
constexpr auto kEluGrad = "EluGrad";
constexpr auto kSigmoidGrad = "SigmoidGrad";
} // namespace
MS_REG_GPU_KERNEL_ONE(
TanhGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ActivationGradGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(
TanhGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ActivationGradGpuKernelMod, half)
std::map<std::string, std::vector<std::pair<KernelAttr, ActivationGradGpuKernelMod::ActivationGradFunc>>>
ActivationGradGpuKernelMod::kernel_attr_map_ = {
{kReLU6Grad,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&ActivationGradGpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
&ActivationGradGpuKernelMod::LaunchKernel<half>}}},
{kTanhGrad,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&ActivationGradGpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
&ActivationGradGpuKernelMod::LaunchKernel<half>}}},
{kEluGrad,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&ActivationGradGpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
&ActivationGradGpuKernelMod::LaunchKernel<half>}}},
{kSigmoidGrad,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&ActivationGradGpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
&ActivationGradGpuKernelMod::LaunchKernel<half>}}}};
MS_REG_GPU_KERNEL_ONE(
EluGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ActivationGradGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(
EluGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ActivationGradGpuKernelMod, half)
bool ActivationGradGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
kernel_name_ = base_operator->name();
auto iter = kernel_attr_map_.find(kernel_name_);
if (iter == kernel_attr_map_.end()) {
MS_LOG(ERROR) << "For 'ActivationGrad', the kernel name must be in " << kernel::Map2Str(kernel_attr_map_)
<< ", but got " << kernel_name_;
return false;
}
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
return false;
}
kernel_func_ = kernel_attr_map_.at(kernel_name_)[index].second;
return true;
}
MS_REG_GPU_KERNEL_ONE(
SigmoidGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ActivationGradGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(
SigmoidGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ActivationGradGpuKernelMod, half)
int ActivationGradGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
static const std::map<std::string, cudnnActivationMode_t> activation_mode_map = {
{kReLU6Grad, CUDNN_ACTIVATION_CLIPPED_RELU},
{kTanhGrad, CUDNN_ACTIVATION_TANH},
{kEluGrad, CUDNN_ACTIVATION_ELU},
{kSigmoidGrad, CUDNN_ACTIVATION_SIGMOID}};
if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
size_t input_num = inputs.size();
if (input_num != 2) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the number of inputs must be 2, but got " << input_num;
return KRET_RESIZE_FAILED;
}
auto input_shape = inputs.at(kIndex0)->GetShapeVector();
input_shape_.clear();
(void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(input_shape_), LongToSize);
size_t input_element_num = std::accumulate(input_shape_.begin(), input_shape_.end(), 1, std::multiplies<size_t>());
is_null_input_ = (input_element_num == 0);
if (is_null_input_) {
return KRET_OK;
}
auto iter = activation_mode_map.find(kernel_name_);
if (iter == activation_mode_map.end()) {
MS_LOG(ERROR) << "For '" << kernel_name_
<< "', only support these activations: " << kernel::Map2Str(activation_mode_map) << ", but got "
<< kernel_name_;
return KRET_RESIZE_FAILED;
}
mode_ = iter->second;
cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnCreateTensorDescriptor(&data_descriptor_),
"For 'ActivationGrad', cudnnCreateTensorDescriptor failed.");
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnCreateActivationDescriptor(&activation_desc_),
"For 'ActivationGrad', cudnnCreateActivationDescriptor failed.");
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(inputs.at(kIndex0)->GetDtype()));
CheckTensorSize({input_shape_});
std::vector<size_t> shape;
double coef = (mode_ == CUDNN_ACTIVATION_CLIPPED_RELU) ? ReLU6_UP_TURNING_POINT : 0.0;
if (mode_ == CUDNN_ACTIVATION_ELU) coef = 1.0;
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_PROPAGATE_NAN, coef),
"For 'ActivationGrad', cudnnSetActivationDescriptor failed.");
const int split_dim = 4;
if (input_shape.size() <= split_dim) {
ShapeNdTo4d(input_shape_, &shape);
if (inputs.at(kIndex0)->GetFormat() == mindspore::Format::NHWC) {
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NHWC, cudnn_data_type_, SizeToInt(shape[0]),
SizeToInt(shape[3]), SizeToInt(shape[1]), SizeToInt(shape[2])),
"For 'ActivationGrad', cudnnSetTensor4dDescriptor failed.");
} else {
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(shape[0]),
SizeToInt(shape[1]), SizeToInt(shape[2]), SizeToInt(shape[3])),
"For 'ActivationGrad', cudnnSetTensor4dDescriptor failed.");
}
} else {
CudnnSetTensorNdDescriptor(input_shape_, data_descriptor_, cudnn_data_type_, kernel_name_);
}
return KRET_OK;
}
std::vector<KernelAttr> ActivationGradGpuKernelMod::GetOpSupport() {
auto iter = kernel_attr_map_.find(kernel_name_);
if (iter == kernel_attr_map_.end()) {
MS_LOG(ERROR) << "For 'ActivationGrad', the kernel name must be in " << kernel::Map2Str(kernel_attr_map_)
<< ", but got " << kernel_name_;
return std::vector<KernelAttr>{};
}
std::vector<KernelAttr> support_list;
(void)std::transform(iter->second.begin(), iter->second.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, ActivationGradFunc> &item) { return item.first; });
return support_list;
}
template <typename T>
bool ActivationGradGpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
T *dy = nullptr;
T *y = nullptr;
if (mode_ == CUDNN_ACTIVATION_ELU || mode_ == CUDNN_ACTIVATION_CLIPPED_RELU) {
dy = GetDeviceAddress<T>(inputs, 0);
y = GetDeviceAddress<T>(inputs, 1);
} else {
y = GetDeviceAddress<T>(inputs, 0);
dy = GetDeviceAddress<T>(inputs, 1);
}
T *dx = GetDeviceAddress<T>(outputs, 0);
const float alpha = 1;
const float beta = 0;
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
cudnnActivationBackward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_, y, data_descriptor_, dy,
data_descriptor_, y, &beta, data_descriptor_, dx),
"For 'ActivationGrad', cudnnActivationBackward failed.");
return true;
}
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, ReLU6Grad,
[]() { return std::make_shared<ActivationGradGpuKernelMod>(kReLU6Grad); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, TanhGrad,
[]() { return std::make_shared<ActivationGradGpuKernelMod>(kTanhGrad); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, EluGrad,
[]() { return std::make_shared<ActivationGradGpuKernelMod>(kEluGrad); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, SigmoidGrad,
[]() { return std::make_shared<ActivationGradGpuKernelMod>(kSigmoidGrad); });
} // namespace kernel
} // namespace mindspore
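Two details of LaunchKernel above are easy to miss: the operand order differs by mode (EluGrad/ReLU6Grad read (dy, y) from the inputs, TanhGrad/SigmoidGrad read (y, dy)), and the same y buffer is supplied for both the forward-output and forward-input tensor arguments of cudnnActivationBackward. For the sigmoid and tanh modes the gradient can be written purely in terms of the forward output y, so no separate x tensor is needed. A plain CPU reference for those two formulas, illustrative only and not part of the patch:

// CPU reference for two of the gradients cudnnActivationBackward produces above;
// both depend only on the forward output y.
#include <cmath>
#include <cstdio>

// Sigmoid: y = 1 / (1 + exp(-x)),  dx = dy * y * (1 - y)
float SigmoidGradRef(float dy, float y) { return dy * y * (1.0f - y); }

// Tanh: y = tanh(x),  dx = dy * (1 - y * y)
float TanhGradRef(float dy, float y) { return dy * (1.0f - y * y); }

int main() {
  const float x = 0.5f, dy = 1.0f;
  const float y_sigmoid = 1.0f / (1.0f + std::exp(-x));
  const float y_tanh = std::tanh(x);
  std::printf("sigmoid grad at x=0.5: %f\n", SigmoidGradRef(dy, y_sigmoid));
  std::printf("tanh grad at x=0.5:    %f\n", TanhGradRef(dy, y_tanh));
  return 0;
}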

View File

@@ -17,9 +17,12 @@
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_ACTIVATION_GRAD_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_ACTIVATION_GRAD_KERNEL_H_
#include <functional>
#include <vector>
#include <map>
#include <string>
#include <map>
#include <utility>
#include <algorithm>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/kernel_constants.h"
@@ -27,142 +30,55 @@
namespace mindspore {
namespace kernel {
constexpr float ReLU6_UP_TURNING_POINT = 5.999999;
template <typename T>
class ActivationGradGpuKernelMod : public DeprecatedNativeGpuKernelMod {
constexpr auto kUnKnown = "UnKnown";
class ActivationGradGpuKernelMod : public NativeGpuKernelMod {
public:
ActivationGradGpuKernelMod() { ResetResource(); }
~ActivationGradGpuKernelMod() override { DestroyResource(); }
explicit ActivationGradGpuKernelMod(const std::string &kernel_name) : kernel_name_(kernel_name) {}
~ActivationGradGpuKernelMod() override { DestroyResource(); };
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(
const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *) override {
if (is_null_input_) {
return true;
}
T *dy = nullptr;
T *y = nullptr;
if (mode_ == CUDNN_ACTIVATION_ELU || mode_ == CUDNN_ACTIVATION_CLIPPED_RELU) {
dy = GetDeviceAddress<T>(inputs, 0);
y = GetDeviceAddress<T>(inputs, 1);
} else {
y = GetDeviceAddress<T>(inputs, 0);
dy = GetDeviceAddress<T>(inputs, 1);
}
T *dx = GetDeviceAddress<T>(outputs, 0);
const float alpha = 1;
const float beta = 0;
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnActivationBackward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_, y, data_descriptor_, dy,
data_descriptor_, y, &beta, data_descriptor_, dx),
"cudnnActivationBackward failed");
return true;
}
bool Init(const CNodePtr &kernel_node) override {
kernel_node_ = kernel_node;
auto node_name = common::AnfAlgo::GetCNodeName(kernel_node);
auto iter = kernel_map.find(node_name);
if (iter == kernel_map.end()) {
MS_LOG(EXCEPTION) << "Only support these activations: ReLU6, Tanh, Elu, Sigmoid currently, but got " << node_name;
}
mode_ = iter->second;
InitResource();
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 2) {
MS_LOG(EXCEPTION) << "For '" << node_name << "', the number of inputs must be 2, but got " << input_num;
}
auto input_shape = common::AnfAlgo::GetOutputInferShape(kernel_node, 0);
is_null_input_ = CHECK_SHAPE_NULL(input_shape, node_name, "input");
if (is_null_input_) {
InitSizeLists();
return true;
}
CheckTensorSize({input_shape});
std::vector<size_t> shape;
double coef = (mode_ == CUDNN_ACTIVATION_CLIPPED_RELU) ? ReLU6_UP_TURNING_POINT : 0.0;
if (mode_ == CUDNN_ACTIVATION_ELU) coef = 1.0;
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_PROPAGATE_NAN, coef),
"SetActivationDescriptor failed");
const int split_dim = 4;
if (input_shape.size() <= split_dim) {
ShapeNdTo4d(input_shape, &shape);
if (AnfAlgo::GetInputFormat(kernel_node, 0) == kOpFormat_NHWC) {
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NHWC, cudnn_data_type_, SizeToInt(shape[0]),
SizeToInt(shape[3]), SizeToInt(shape[1]), SizeToInt(shape[2])),
"cudnnSetTensor4dDescriptor failed");
} else {
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(shape[0]),
SizeToInt(shape[1]), SizeToInt(shape[2]), SizeToInt(shape[3])),
"cudnnSetTensor4dDescriptor failed");
}
} else {
CudnnSetTensorNdDescriptor(input_shape, data_descriptor_, cudnn_data_type_, kernel_node_);
}
InitSizeLists();
return true;
return kernel_func_(this, inputs, outputs);
}
void DestroyResource() noexcept override {
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyActivationDescriptor(activation_desc_),
"cudnnDestroyActivationDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(data_descriptor_),
"cudnnDestroyTensorDescriptor failed");
}
void ResetResource() noexcept override {
cudnn_handle_ = nullptr;
activation_desc_ = nullptr;
mode_ = CUDNN_ACTIVATION_SIGMOID;
data_descriptor_ = nullptr;
is_null_input_ = false;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
cudnn_data_type_ = CUDNN_DATA_FLOAT;
input_size_ = 0;
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnDestroyActivationDescriptor(activation_desc_),
"For 'ActivationGrad', cudnnDestroyActivationDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnDestroyTensorDescriptor(data_descriptor_),
"For 'ActivationGrad', cudnnDestroyTensorDescriptor failed");
}
protected:
void InitResource() override {
cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&data_descriptor_),
"cudnnCreateTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateActivationDescriptor(&activation_desc_),
"cudnnCreateActivationDescriptor failed");
}
void InitSizeLists() override {
if (!is_null_input_) {
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(data_descriptor_, &input_size_),
"cudnnGetTensorSizeInBytes failed");
}
input_size_list_.push_back(input_size_);
output_size_list_.push_back(input_size_);
input_size_list_.push_back(input_size_);
}
std::vector<KernelAttr> GetOpSupport() override;
private:
std::map<std::string, cudnnActivationMode_t> kernel_map = {{"ReLU6Grad", CUDNN_ACTIVATION_CLIPPED_RELU},
{"TanhGrad", CUDNN_ACTIVATION_TANH},
{"EluGrad", CUDNN_ACTIVATION_ELU},
{"SigmoidGrad", CUDNN_ACTIVATION_SIGMOID}};
cudnnHandle_t cudnn_handle_;
cudnnActivationDescriptor_t activation_desc_;
cudnnActivationMode_t mode_;
cudnnTensorDescriptor_t data_descriptor_;
bool is_null_input_;
cudnnDataType_t cudnn_data_type_;
size_t input_size_;
template <typename T>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
using ActivationGradFunc = std::function<bool(ActivationGradGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &)>;
static std::map<std::string, std::vector<std::pair<KernelAttr, ActivationGradGpuKernelMod::ActivationGradFunc>>>
kernel_attr_map_;
std::string kernel_name_{kUnKnown};
ActivationGradFunc kernel_func_;
std::vector<size_t> input_shape_{};
bool is_null_input_{true};
cudnnHandle_t cudnn_handle_{nullptr};
cudnnActivationDescriptor_t activation_desc_{nullptr};
cudnnActivationMode_t mode_{CUDNN_ACTIVATION_SIGMOID};
cudnnTensorDescriptor_t data_descriptor_{nullptr};
cudnnDataType_t cudnn_data_type_{CUDNN_DATA_FLOAT};
};
} // namespace kernel
} // namespace mindspore

View File

@@ -15,12 +15,63 @@
*/
#include "plugin/device/gpu/kernel/nn/softplus_gpu_kernel.h"
#include <functional>
#include <utility>
#include <algorithm>
#include <memory>
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(Softplus, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
SoftplusGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(Softplus, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
SoftplusGpuKernelMod, half)
bool SoftplusGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
kernel_name_ = base_operator->name();
if (kernel_name_ != prim::kPrimSoftplus->name()) {
MS_LOG(ERROR) << "For 'Softplus', the kernel name must be 'Softplus', but got " << kernel_name_;
return false;
}
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
return false;
}
kernel_func_ = func_list_[index].second;
return true;
}
int SoftplusGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
auto input_shape = inputs.at(kIndex0)->GetShapeVector();
auto input_element_num = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies<size_t>());
is_null_input_ = (input_element_num == 0);
return KRET_OK;
}
template <typename T>
bool SoftplusGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
T *input_addr = GetDeviceAddress<T>(inputs, 0);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
Softplus(input_size_list_.at(0) / sizeof(T), input_addr, output_addr, reinterpret_cast<cudaStream_t>(cuda_stream_));
return true;
}
std::vector<std::pair<KernelAttr, SoftplusGpuKernelMod::SoftplusFunc>> SoftplusGpuKernelMod::func_list_ = {
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&SoftplusGpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
&SoftplusGpuKernelMod::LaunchKernel<half>}};
std::vector<KernelAttr> SoftplusGpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, SoftplusFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Softplus, SoftplusGpuKernelMod);
} // namespace kernel
} // namespace mindspore
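The Softplus launcher computes the element-wise function softplus(x) = ln(1 + e^x) over input_size_list_.at(0) / sizeof(T) elements, relying on the base KernelMod::Resize call to have filled the size lists (which is why the old InitSizeLists override is gone). A CPU reference in the usual numerically stable form; the CUDA kernel's exact numerics may differ:

// CPU reference for the element-wise function computed by the Softplus CUDA call above:
// softplus(x) = ln(1 + e^x), written in a form that avoids overflow for large x.
#include <algorithm>
#include <cmath>
#include <cstdio>

float SoftplusRef(float x) { return std::max(x, 0.0f) + std::log1p(std::exp(-std::fabs(x))); }

int main() {
  const float xs[] = {-20.0f, -1.0f, 0.0f, 1.0f, 20.0f};
  for (float x : xs) {
    std::printf("softplus(%g) = %g\n", x, SoftplusRef(x));
  }
  return 0;
}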

View File

@@ -18,6 +18,8 @@
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SOFTPLUS_GPU_KERNEL_H_
#include <vector>
#include <utility>
#include <map>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/kernel_constants.h"
@@ -25,52 +27,38 @@
namespace mindspore {
namespace kernel {
template <typename T>
class SoftplusGpuKernelMod : public DeprecatedNativeGpuKernelMod {
class SoftplusGpuKernelMod : public NativeGpuKernelMod {
public:
SoftplusGpuKernelMod() : is_null_input_(false), input_size_(0) {}
SoftplusGpuKernelMod() = default;
~SoftplusGpuKernelMod() override = default;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
const std::vector<AddressPtr> &outputs, void *cuda_stream) override {
if (is_null_input_) {
return true;
}
T *input_addr = GetDeviceAddress<T>(inputs, 0);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
Softplus(input_size_ / sizeof(T), input_addr, output_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
kernel_node_ = kernel_node;
InitResource();
input_size_ = sizeof(T);
auto input_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name, "input");
if (is_null_input_) {
InitSizeLists();
return true;
}
for (auto dim : input_shape) {
input_size_ *= dim;
}
InitSizeLists();
return true;
cuda_stream_ = cuda_stream;
return kernel_func_(this, inputs, outputs);
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_);
output_size_list_.push_back(input_size_);
}
std::vector<KernelAttr> GetOpSupport() override;
private:
bool is_null_input_;
size_t input_size_;
template <typename T>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
using SoftplusFunc = std::function<bool(SoftplusGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &)>;
static std::vector<std::pair<KernelAttr, SoftplusFunc>> func_list_;
SoftplusFunc kernel_func_;
bool is_null_input_{false};
void *cuda_stream_{nullptr};
};
} // namespace kernel
} // namespace mindspore

View File

@@ -15,16 +15,66 @@
*/
#include "plugin/device/gpu/kernel/nn/softplus_grad_gpu_kernel.h"
#include <functional>
#include <utility>
#include <algorithm>
#include <memory>
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
SoftplusGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
SoftplusGradGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(
SoftplusGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
SoftplusGradGpuKernelMod, half)
bool SoftplusGradGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
kernel_name_ = base_operator->name();
if (kernel_name_ != prim::kPrimSoftplusGrad->name()) {
MS_LOG(ERROR) << "For 'SoftplusGrad', the kernel name must be 'SoftplusGrad', but got " << kernel_name_;
return false;
}
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
return false;
}
kernel_func_ = func_list_[index].second;
return true;
}
int SoftplusGradGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
auto input_shape = inputs.at(kIndex0)->GetShapeVector();
auto input_element_num = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies<size_t>());
is_null_input_ = (input_element_num == 0);
return KRET_OK;
}
template <typename T>
bool SoftplusGradGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs) {
T *dy_addr = GetDeviceAddress<T>(inputs, 0);
T *x_addr = GetDeviceAddress<T>(inputs, 1);
T *dx_addr = GetDeviceAddress<T>(outputs, 0);
SoftplusGrad(input_size_list_.at(0) / sizeof(T), dy_addr, x_addr, dx_addr,
reinterpret_cast<cudaStream_t>(cuda_stream_));
return true;
}
std::vector<std::pair<KernelAttr, SoftplusGradGpuKernelMod::SoftplusGradFunc>> SoftplusGradGpuKernelMod::func_list_ = {
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&SoftplusGradGpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
&SoftplusGradGpuKernelMod::LaunchKernel<half>}};
std::vector<KernelAttr> SoftplusGradGpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, SoftplusGradFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, SoftplusGrad, SoftplusGradGpuKernelMod);
} // namespace kernel
} // namespace mindspore
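Since d/dx softplus(x) = sigmoid(x), the backward kernel computes dx = dy / (1 + e^(-x)), with dy read from input 0 and x from input 1 as in LaunchKernel above. A minimal CPU reference, illustrative only:

// CPU reference for SoftplusGrad: d/dx softplus(x) = sigmoid(x), so dx = dy * sigmoid(x).
#include <cmath>
#include <cstdio>

float SoftplusGradRef(float dy, float x) { return dy / (1.0f + std::exp(-x)); }

int main() {
  std::printf("softplus'(1.0) = %f\n", SoftplusGradRef(1.0f, 1.0f));
  return 0;
}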

View File

@@ -18,6 +18,8 @@
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SOFTPLUS_GRAD_KERNEL_H_
#include <vector>
#include <utility>
#include <map>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/kernel_constants.h"
@@ -25,54 +27,38 @@
namespace mindspore {
namespace kernel {
template <typename T>
class SoftplusGradGpuKernelMod : public DeprecatedNativeGpuKernelMod {
class SoftplusGradGpuKernelMod : public NativeGpuKernelMod {
public:
SoftplusGradGpuKernelMod() : input_size_(0) {}
SoftplusGradGpuKernelMod() = default;
~SoftplusGradGpuKernelMod() override = default;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
const std::vector<AddressPtr> &outputs, void *cuda_stream) override {
if (is_null_input_) {
return true;
}
T *dy_addr = GetDeviceAddress<T>(inputs, 0);
T *x_addr = GetDeviceAddress<T>(inputs, 1);
T *dx_addr = GetDeviceAddress<T>(outputs, 0);
SoftplusGrad(input_size_ / sizeof(T), dy_addr, x_addr, dx_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
kernel_node_ = kernel_node;
InitResource();
input_size_ = sizeof(T);
auto input_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name, "input");
if (is_null_input_) {
InitSizeLists();
return true;
}
for (auto dim : input_shape) {
input_size_ *= dim;
}
InitSizeLists();
return true;
cuda_stream_ = cuda_stream;
return kernel_func_(this, inputs, outputs);
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_);
input_size_list_.push_back(input_size_);
output_size_list_.push_back(input_size_);
}
std::vector<KernelAttr> GetOpSupport() override;
private:
bool is_null_input_;
size_t input_size_;
template <typename T>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
using SoftplusGradFunc = std::function<bool(SoftplusGradGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &)>;
static std::vector<std::pair<KernelAttr, SoftplusGradFunc>> func_list_;
SoftplusGradFunc kernel_func_;
bool is_null_input_{false};
void *cuda_stream_{nullptr};
};
} // namespace kernel
} // namespace mindspore