[feat][assistant][I5EWNX][I5EWO5]Add datatype for Mod and Erf

This commit is contained in:
hj-ustb 2022-11-02 15:28:32 +08:00
parent 62361cdd21
commit abdf77cb59
6 changed files with 24 additions and 13 deletions

View File

@ -10,11 +10,11 @@ mindspore.ops.erf
erf(x)=\frac{2} {\sqrt{\pi}} \int\limits_0^{x} e^{-t^{2}} dt
参数:
- **x** (Tensor) - 高斯误差函数的输入Tensor。维度必须小于8数据类型必须为float16或float32
- **x** (Tensor) - 高斯误差函数的输入Tensor。维度必须小于8数据类型必须为float16,float32或float64
返回:
Tensor具有与 `x` 相同的数据类型和shape。
异常:
- **TypeError** - `x` 不是Tensor。
- **TypeError** - `x` 的数据类型既不是float16也不是float32
- **TypeError** - `x` 的数据类型既不是float16,float32也不是float64

View File

@ -1020,12 +1020,20 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, ArithmeticCpuFunc
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
SpecializeArithFunc<double>}}},
{kMod,
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
{{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
SpecializeArithFunc<int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
SpecializeArithFunc<int>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
SpecializeArithFunc<int8_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
SpecializeArithFunc<uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
SpecializeArithFunc<float>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
SpecializeArithFunc<int64_t>}}},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
SpecializeArithFunc<float16>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
SpecializeArithFunc<double>}}},
{kFloorMod,
{{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
SpecializeArithFunc<int64_t>},

View File

@ -611,14 +611,14 @@ bool ArithmeticSelfCpuKernelFunc::RunFunc(const std::vector<kernel::AddressPtr>
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_);
if (dtype_ == kNumberTypeFloat32) {
LaunchKernel<float>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat16) {
LaunchKernelFloat16(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat64) {
LaunchKernel<double>(inputs, outputs);
} else if (dtype_ == kNumberTypeComplex64) {
LaunchKernelComplex<std::complex<float>>(inputs, outputs);
} else if (dtype_ == kNumberTypeComplex128) {
LaunchKernelComplex<std::complex<double>>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat16) {
LaunchKernelFloat16(inputs, outputs);
} else if (dtype_ == kNumberTypeInt8) {
LaunchKernel<int8_t>(inputs, outputs);
} else if (dtype_ == kNumberTypeInt16) {
@ -751,7 +751,7 @@ void ArithmeticSelfCpuKernelFunc::LaunchKernelFloat16(const std::vector<AddressP
{prim::kPrimAsin->name(), Asin<float16>}, {prim::kPrimACos->name(), ACos<float16>},
{prim::kPrimSinh->name(), Sinh<float16>}, {prim::kPrimCosh->name(), Cosh<float16>},
{prim::kPrimAsinh->name(), Asinh<float16>}, {prim::kPrimErfc->name(), Erfc<float16>},
{prim::kPrimRsqrt->name(), Rsqrt<float16>}};
{prim::kPrimRsqrt->name(), Rsqrt<float16>}, {prim::kPrimErf->name(), Erf<float16>}};
const auto func_pair = arithmeticSelfFuncMap.find(kernel_name_);
if (arithmeticSelfFuncMap.find(kernel_name_) == arithmeticSelfFuncMap.end()) {
@ -1020,7 +1020,8 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, ArithFuncCreator>
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
[]() { return std::make_shared<LogMKLKernelFunc>(); }}}},
{kErf,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
{{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc}}},
{kErfc,
{{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), CreateArithSelfFunc},

View File

@ -88,7 +88,9 @@ std::map<std::string, std::vector<std::pair<KernelAttr, UnaryOpGpuKernelMod::Una
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
&UnaryOpGpuKernelMod::LaunchKernel<half>}}},
{kErf,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
{{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
&UnaryOpGpuKernelMod::LaunchKernel<double>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&UnaryOpGpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
&UnaryOpGpuKernelMod::LaunchKernel<half>}}},

View File

@ -51,7 +51,7 @@ TypePtr ErfInferType(const PrimitivePtr &primitive, const std::vector<AbstractBa
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
auto infer_type = input_args[0]->BuildType();
MS_EXCEPTION_IF_NULL(infer_type);
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", infer_type, valid_types, prim_name);

View File

@ -1864,14 +1864,14 @@ def erf(x):
Args:
x (Tensor): The input tensor of Gaussian error function. Its rank must be in [0, 7] inclusive
and data type must be float16 or float32.
and data type must be float16 float32 or float64.
Returns:
Tensor, has the same shape and dtype as the `x`.
Raises:
TypeError: If `x` is not a Tensor.
TypeError: If dtype of `x` is neither float16 nor float32.
TypeError: If dtype of `x` is neither float16 float32 or float64.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``