[feat][assistant][I5EWNX][I5EWO5]Add datatype for Mod and Erf
This commit is contained in:
parent
62361cdd21
commit
abdf77cb59
|
@ -10,11 +10,11 @@ mindspore.ops.erf
|
|||
erf(x)=\frac{2} {\sqrt{\pi}} \int\limits_0^{x} e^{-t^{2}} dt
|
||||
|
||||
参数:
|
||||
- **x** (Tensor) - 高斯误差函数的输入Tensor。维度必须小于8,数据类型必须为float16或float32。
|
||||
- **x** (Tensor) - 高斯误差函数的输入Tensor。维度必须小于8,数据类型必须为float16,float32或float64。
|
||||
|
||||
返回:
|
||||
Tensor,具有与 `x` 相同的数据类型和shape。
|
||||
|
||||
异常:
|
||||
- **TypeError** - `x` 不是Tensor。
|
||||
- **TypeError** - `x` 的数据类型既不是float16也不是float32。
|
||||
- **TypeError** - `x` 的数据类型既不是float16,float32也不是float64。
|
||||
|
|
|
@ -1020,12 +1020,20 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, ArithmeticCpuFunc
|
|||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
SpecializeArithFunc<double>}}},
|
||||
{kMod,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
SpecializeArithFunc<int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
SpecializeArithFunc<int>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
||||
SpecializeArithFunc<int8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
|
||||
SpecializeArithFunc<uint8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
SpecializeArithFunc<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
SpecializeArithFunc<int64_t>}}},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
SpecializeArithFunc<float16>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
SpecializeArithFunc<double>}}},
|
||||
{kFloorMod,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
SpecializeArithFunc<int64_t>},
|
||||
|
|
|
@ -611,14 +611,14 @@ bool ArithmeticSelfCpuKernelFunc::RunFunc(const std::vector<kernel::AddressPtr>
|
|||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_);
|
||||
if (dtype_ == kNumberTypeFloat32) {
|
||||
LaunchKernel<float>(inputs, outputs);
|
||||
} else if (dtype_ == kNumberTypeFloat16) {
|
||||
LaunchKernelFloat16(inputs, outputs);
|
||||
} else if (dtype_ == kNumberTypeFloat64) {
|
||||
LaunchKernel<double>(inputs, outputs);
|
||||
} else if (dtype_ == kNumberTypeComplex64) {
|
||||
LaunchKernelComplex<std::complex<float>>(inputs, outputs);
|
||||
} else if (dtype_ == kNumberTypeComplex128) {
|
||||
LaunchKernelComplex<std::complex<double>>(inputs, outputs);
|
||||
} else if (dtype_ == kNumberTypeFloat16) {
|
||||
LaunchKernelFloat16(inputs, outputs);
|
||||
} else if (dtype_ == kNumberTypeInt8) {
|
||||
LaunchKernel<int8_t>(inputs, outputs);
|
||||
} else if (dtype_ == kNumberTypeInt16) {
|
||||
|
@ -751,7 +751,7 @@ void ArithmeticSelfCpuKernelFunc::LaunchKernelFloat16(const std::vector<AddressP
|
|||
{prim::kPrimAsin->name(), Asin<float16>}, {prim::kPrimACos->name(), ACos<float16>},
|
||||
{prim::kPrimSinh->name(), Sinh<float16>}, {prim::kPrimCosh->name(), Cosh<float16>},
|
||||
{prim::kPrimAsinh->name(), Asinh<float16>}, {prim::kPrimErfc->name(), Erfc<float16>},
|
||||
{prim::kPrimRsqrt->name(), Rsqrt<float16>}};
|
||||
{prim::kPrimRsqrt->name(), Rsqrt<float16>}, {prim::kPrimErf->name(), Erf<float16>}};
|
||||
|
||||
const auto func_pair = arithmeticSelfFuncMap.find(kernel_name_);
|
||||
if (arithmeticSelfFuncMap.find(kernel_name_) == arithmeticSelfFuncMap.end()) {
|
||||
|
@ -1020,7 +1020,8 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, ArithFuncCreator>
|
|||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
[]() { return std::make_shared<LogMKLKernelFunc>(); }}}},
|
||||
{kErf,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc}}},
|
||||
{kErfc,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), CreateArithSelfFunc},
|
||||
|
|
|
@ -88,7 +88,9 @@ std::map<std::string, std::vector<std::pair<KernelAttr, UnaryOpGpuKernelMod::Una
|
|||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
&UnaryOpGpuKernelMod::LaunchKernel<half>}}},
|
||||
{kErf,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
&UnaryOpGpuKernelMod::LaunchKernel<double>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
&UnaryOpGpuKernelMod::LaunchKernel<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
&UnaryOpGpuKernelMod::LaunchKernel<half>}}},
|
||||
|
|
|
@ -51,7 +51,7 @@ TypePtr ErfInferType(const PrimitivePtr &primitive, const std::vector<AbstractBa
|
|||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
|
||||
auto infer_type = input_args[0]->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(infer_type);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", infer_type, valid_types, prim_name);
|
||||
|
|
|
@ -1864,14 +1864,14 @@ def erf(x):
|
|||
|
||||
Args:
|
||||
x (Tensor): The input tensor of Gaussian error function. Its rank must be in [0, 7] inclusive
|
||||
and data type must be float16 or float32.
|
||||
and data type must be float16 float32 or float64.
|
||||
|
||||
Returns:
|
||||
Tensor, has the same shape and dtype as the `x`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `x` is not a Tensor.
|
||||
TypeError: If dtype of `x` is neither float16 nor float32.
|
||||
TypeError: If dtype of `x` is neither float16 float32 or float64.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
|
Loading…
Reference in New Issue