[feat][assistant][I5EWNX][I5EWO5] Add data types for Mod and Erf

hj-ustb 2022-11-02 15:28:32 +08:00
parent 62361cdd21
commit abdf77cb59
6 changed files with 24 additions and 13 deletions


@@ -10,11 +10,11 @@ mindspore.ops.erf
 erf(x)=\frac{2}{\sqrt{\pi}} \int\limits_0^{x} e^{-t^{2}} dt
 Parameters:
-    - **x** (Tensor) - The input Tensor of the Gaussian error function. Its rank must be less than 8 and its data type must be float16 or float32.
+    - **x** (Tensor) - The input Tensor of the Gaussian error function. Its rank must be less than 8 and its data type must be float16, float32 or float64.
 Returns:
     Tensor, with the same data type and shape as `x`.
 Raises:
     - **TypeError** - `x` is not a Tensor.
-    - **TypeError** - The data type of `x` is neither float16 nor float32.
+    - **TypeError** - The data type of `x` is not float16, float32 or float64.
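A minimal usage sketch of the widened dtype support (assuming a MindSpore build that includes this commit; input values are illustrative):

import numpy as np
from mindspore import Tensor, ops

# float64 input, accepted by ops.erf only after this change
x = Tensor(np.array([-1.0, 0.0, 1.0], dtype=np.float64))
y = ops.erf(x)
print(y.dtype)  # Float64 -- the output keeps the input dtype and shape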


@@ -1020,12 +1020,20 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, ArithmeticCpuFunc
   {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
    SpecializeArithFunc<double>}}},
  {kMod,
-  {{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+  {{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+    SpecializeArithFunc<int64_t>},
+   {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
     SpecializeArithFunc<int>},
+   {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
+    SpecializeArithFunc<int8_t>},
+   {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
+    SpecializeArithFunc<uint8_t>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
     SpecializeArithFunc<float>},
-   {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
-    SpecializeArithFunc<int64_t>}}},
+   {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+    SpecializeArithFunc<float16>},
+   {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+    SpecializeArithFunc<double>}}},
  {kFloorMod,
   {{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
     SpecializeArithFunc<int64_t>},
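The kMod table now covers int8, uint8, int32, int64, float16, float32 and float64. A hedged sketch of exercising the new entries through the ops.Mod primitive on CPU (assuming ops.Mod dispatches to this table and that ms.set_context is available to select the backend):

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

ms.set_context(device_target="CPU")
mod = ops.Mod()

# int8/uint8 pairs now match dedicated kernels instead of falling through
a = Tensor(np.array([7, 5, 9], dtype=np.int8))
b = Tensor(np.array([3, 3, 4], dtype=np.int8))
print(mod(a, b))  # element-wise remainder, int8 result

# float16 is likewise newly registered
h = Tensor(np.array([7.5, 2.5], dtype=np.float16))
print(mod(h, Tensor(np.array([2.0, 2.0], dtype=np.float16))))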


@@ -611,14 +611,14 @@ bool ArithmeticSelfCpuKernelFunc::RunFunc(const std::vector<kernel::AddressPtr>
   CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_);
   if (dtype_ == kNumberTypeFloat32) {
     LaunchKernel<float>(inputs, outputs);
+  } else if (dtype_ == kNumberTypeFloat16) {
+    LaunchKernelFloat16(inputs, outputs);
   } else if (dtype_ == kNumberTypeFloat64) {
     LaunchKernel<double>(inputs, outputs);
   } else if (dtype_ == kNumberTypeComplex64) {
     LaunchKernelComplex<std::complex<float>>(inputs, outputs);
   } else if (dtype_ == kNumberTypeComplex128) {
     LaunchKernelComplex<std::complex<double>>(inputs, outputs);
-  } else if (dtype_ == kNumberTypeFloat16) {
-    LaunchKernelFloat16(inputs, outputs);
   } else if (dtype_ == kNumberTypeInt8) {
     LaunchKernel<int8_t>(inputs, outputs);
   } else if (dtype_ == kNumberTypeInt16) {
@@ -751,7 +751,7 @@ void ArithmeticSelfCpuKernelFunc::LaunchKernelFloat16(const std::vector<AddressP
     {prim::kPrimAsin->name(), Asin<float16>},   {prim::kPrimACos->name(), ACos<float16>},
     {prim::kPrimSinh->name(), Sinh<float16>},   {prim::kPrimCosh->name(), Cosh<float16>},
     {prim::kPrimAsinh->name(), Asinh<float16>}, {prim::kPrimErfc->name(), Erfc<float16>},
-    {prim::kPrimRsqrt->name(), Rsqrt<float16>}};
+    {prim::kPrimRsqrt->name(), Rsqrt<float16>}, {prim::kPrimErf->name(), Erf<float16>}};
   const auto func_pair = arithmeticSelfFuncMap.find(kernel_name_);
   if (arithmeticSelfFuncMap.find(kernel_name_) == arithmeticSelfFuncMap.end()) {
@@ -1020,7 +1020,8 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, ArithFuncCreator>
   {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
    []() { return std::make_shared<LogMKLKernelFunc>(); }}}},
  {kErf,
-  {{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
+  {{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), CreateArithSelfFunc},
+   {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
    {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc}}},
  {kErfc,
   {{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), CreateArithSelfFunc},
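With Erf<float16> added to the float16 function map and a float16 entry added to the kErf attribute list, all three float dtypes should now resolve on CPU. A quick sketch under the same assumptions as above:

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

ms.set_context(device_target="CPU")

for np_dtype in (np.float16, np.float32, np.float64):
    x = Tensor(np.array([0.5, 1.0], dtype=np_dtype))
    print(np_dtype.__name__, "->", ops.erf(x).dtype)  # dtype preserved per entry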


@@ -88,7 +88,9 @@ std::map<std::string, std::vector<std::pair<KernelAttr, UnaryOpGpuKernelMod::Una
   {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
    &UnaryOpGpuKernelMod::LaunchKernel<half>}}},
  {kErf,
-  {{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  {{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+    &UnaryOpGpuKernelMod::LaunchKernel<double>},
+   {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
     &UnaryOpGpuKernelMod::LaunchKernel<float>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
     &UnaryOpGpuKernelMod::LaunchKernel<half>}}},
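On GPU the new float64 attribute routes to LaunchKernel<double>. A sketch of invoking it (assuming a CUDA-enabled MindSpore build):

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

ms.set_context(device_target="GPU")

x = Tensor(np.array([0.25, 0.75], dtype=np.float64))
print(ops.erf(x))  # served by the double kernel registered above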


@@ -51,7 +51,7 @@ TypePtr ErfInferType(const PrimitivePtr &primitive, const std::vector<AbstractBa
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
-  const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
+  const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
   auto infer_type = input_args[0]->BuildType();
   MS_EXCEPTION_IF_NULL(infer_type);
   (void)CheckAndConvertUtils::CheckTensorTypeValid("x", infer_type, valid_types, prim_name);
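This infer-type check is what surfaces the documented TypeError for unsupported dtypes; a sketch of the failure path (values illustrative):

import numpy as np
from mindspore import Tensor, ops

x = Tensor(np.array([1, 2, 3], dtype=np.int32))
try:
    ops.erf(x)  # int32 is not in {kFloat16, kFloat32, kFloat64}
except TypeError as e:
    print(e)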


@@ -1864,14 +1864,14 @@ def erf(x):
     Args:
         x (Tensor): The input tensor of Gaussian error function. Its rank must be in [0, 7] inclusive
-            and data type must be float16 or float32.
+            and data type must be float16, float32 or float64.
     Returns:
         Tensor, has the same shape and dtype as `x`.
     Raises:
         TypeError: If `x` is not a Tensor.
-        TypeError: If dtype of `x` is neither float16 nor float32.
+        TypeError: If dtype of `x` is not float16, float32 or float64.
     Supported Platforms:
         ``Ascend`` ``GPU`` ``CPU``
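An end-to-end example matching the updated docstring (a sketch; printed values are approximate):

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

x = Tensor(np.array([-1.0, 0.0, 1.0]), ms.float16)
y = ops.erf(x)
print(y)  # approximately [-0.843, 0.0, 0.843], dtype Float16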