forked from mindspore-Ecosystem/mindspore
!41929 [feat][assistant][I5EWNU] Add datatype for Elu
Merge pull request !41929 from 胡静/elu
This commit is contained in:
commit
dd5391d37c
|
@ -22,12 +22,12 @@ mindspore.nn.ELU
|
|||
- **alpha** (`float`) - ELU的alpha值,数据类型为浮点数。默认值:1.0。
|
||||
|
||||
输入:
|
||||
- **x** (Tensor) - 用于计算ELU的任意维度的Tensor,数据类型为float16或float32。
|
||||
- **x** (Tensor) - 用于计算ELU的任意维度的Tensor,数据类型为float16,float32或float64。
|
||||
|
||||
输出:
|
||||
Tensor,数据类型和shape与 `x` 相同。
|
||||
|
||||
异常:
|
||||
- **TypeError** - `alpha` 不是浮点数。
|
||||
- **TypeError** - `x` 的数据类型既不是float16也不是float32。
|
||||
- **TypeError** - `x` 的数据类型既不是float16,float32也不是float64。
|
||||
- **ValueError** - `alpha` 不等于1.0。
|
||||
|
|
|
@ -21,12 +21,12 @@ mindspore.ops.Elu
|
|||
- **alpha** (float) - Elu的alpha值,数据类型为浮点数。目前只支持alpha等于1.0,默认值:1.0。
|
||||
|
||||
输入:
|
||||
- **input_x** (Tensor) - 用于计算Elu的任意维度的Tensor,数据类型为float16或float32。
|
||||
- **input_x** (Tensor) - 用于计算Elu的任意维度的Tensor,数据类型为float16,float32或float64。
|
||||
|
||||
输出:
|
||||
Tensor,shape和数据类型与 `x` 相同。
|
||||
|
||||
异常:
|
||||
- **TypeError** - `alpha` 不是float。
|
||||
- **TypeError** - `x` 的数据类型既不是float16也不是float32。
|
||||
- **TypeError** - `x` 的数据类型既不是float16,float32也不是float64。
|
||||
- **ValueError** - `alpha` 不等于1.0。
|
||||
|
|
|
@ -46,7 +46,9 @@ std::map<std::string, std::vector<std::pair<KernelAttr, ActivationFwdGpuKernelMo
|
|||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
&ActivationFwdGpuKernelMod::LaunchKernel<half>}}},
|
||||
{kElu,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
&ActivationFwdGpuKernelMod::LaunchKernel<double>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
&ActivationFwdGpuKernelMod::LaunchKernel<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
&ActivationFwdGpuKernelMod::LaunchKernel<half>}}},
|
||||
|
@ -99,7 +101,8 @@ bool ActivationFwdGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const
|
|||
mode_ = mode_iter->second;
|
||||
|
||||
const auto dtype = inputs.at(kIndex0)->GetDtype();
|
||||
if ((dtype == kNumberTypeFloat64) || (dtype == kNumberTypeComplex64) || (dtype == kNumberTypeComplex128)) {
|
||||
if ((dtype == kNumberTypeFloat64 && kernel_name_ != kElu) || (dtype == kNumberTypeComplex64) ||
|
||||
(dtype == kNumberTypeComplex128)) {
|
||||
is_additional_dtype_ = true;
|
||||
}
|
||||
return true;
|
||||
|
@ -185,7 +188,6 @@ bool ActivationFwdGpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPt
|
|||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
T *input = GetDeviceAddress<T>(inputs, kIndex0);
|
||||
T *output = GetDeviceAddress<T>(outputs, kIndex0);
|
||||
|
||||
if (is_additional_dtype_) {
|
||||
if (kernel_name_ == kTanh) {
|
||||
Tanh(input, output, input_size_list_[0] / sizeof(T), reinterpret_cast<cudaStream_t>(cuda_stream_));
|
||||
|
@ -195,11 +197,21 @@ bool ActivationFwdGpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPt
|
|||
return true;
|
||||
}
|
||||
|
||||
constexpr float alpha = 1;
|
||||
constexpr float beta = 0;
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnActivationForward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_,
|
||||
input, &beta, data_descriptor_, output),
|
||||
"For 'Activation', cudnnActivationForward failed.");
|
||||
if constexpr (std::is_same_v<T, double>) {
|
||||
constexpr double alpha = 1.0;
|
||||
constexpr double beta = 0.0;
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
|
||||
cudnnActivationForward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_, input, &beta, data_descriptor_,
|
||||
output),
|
||||
"For 'Activation', cudnnActivationForward failed.");
|
||||
} else {
|
||||
constexpr float alpha = 1.0;
|
||||
constexpr float beta = 0.0;
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
|
||||
cudnnActivationForward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_, input, &beta, data_descriptor_,
|
||||
output),
|
||||
"For 'Activation', cudnnActivationForward failed.");
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -808,14 +808,15 @@ class Elu(Primitive):
|
|||
alpha (float): The alpha value of ELU, the data type is float. Only support '1.0' currently. Default: 1.0.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - The input of ELU is a Tensor of any dimension with data type of float16 or float32.
|
||||
- **input_x** (Tensor) - The input of ELU is a Tensor of any dimension with data type of
|
||||
float16, float32 or float64.
|
||||
|
||||
Outputs:
|
||||
Tensor, has the same shape and data type as `input_x`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `alpha` is not a float.
|
||||
TypeError: If dtype of `input_x` is neither float16 nor float32.
|
||||
TypeError: If dtype of `input_x` is neither float16, float32 nor float64.
|
||||
ValueError: If `alpha` is not equal to 1.0.
|
||||
|
||||
Supported Platforms:
|
||||
|
|
Loading…
Reference in New Issue