supplement dtypes for some meta ops

This commit is contained in:
lilinjie 2023-04-25 19:16:53 +08:00
parent 7ba401b4f9
commit 663d578c05
28 changed files with 481 additions and 73 deletions

View File

@ -55,6 +55,7 @@
"mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/arithmetic_cpu_kernel.cc" "unreadVariable"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/hccl/mux_send_ascend_kernel.cc" "knownConditionTrueFalse"
"mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/xdivy_cpu_kernel.cc" "unreadVariable"
"mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/cast_cpu_kernel.cc" "multiCondition"
# MindData
"mindspore/mindspore/ccsrc/minddata/dataset/engine/dataset_iterator.cc" "useStlAlgorithm"

View File

@ -32,6 +32,7 @@
"mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/sequence/sequence_addn_cpu_kernel.cc" "whitespace/indent"
"mindspore/mindspore/ccsrc/pipeline/jit/resource.cc" "readability/fn_size"
"mindspore/mindspore/ccsrc/plugin/device/gpu/kernel/math/binary_ops_gpu_kernel.cc" "whitespace/indent"
"mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/cast_cpu_kernel.cc" "readability/braces"
# Modelzoo
"mindspore/model_zoo/official/cv/yolov4_tiny/infer/mxbase/src/Yolov4TinyDetection.h" "runtime/references"
"mindspore/model_zoo/official/cv/yolov4_tiny/infer/mxbase/src/PostProcess/Yolov4TinyMindsporePost.h" "runtime/references"

View File

@ -278,6 +278,7 @@ mindspore/mindspore/python/mindspore/ops/function/nn_func.py:conv3d
mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/matmul_avx512_mask_fp32.c:GemmRowxColMaskKernelFp32
mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/crop_and_resize_cpu_kernel.cc:mindspore::kernel::CropAndResizeCpuKernelMod::LaunchKernel
mindspore/mindspore/ccsrc/plugin/device/cpu/hal/device/cpu_device_address.cc:mindspore::device::cpu::CPUDeviceAddress::SyncHostToDevice
mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/cast_cpu_kernel.cc:mindspore::kernel::Cast
# AICPU migration
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/bias_add_grad.cc:aicpu::BiasAddGradCpuKernel::BiasAddGradCompute

View File

@ -25,3 +25,4 @@ mindspore.nn.Sigmoid
异常:
- **TypeError** - `input_x` 的数据类型不是float16、float32、float64、complex64或complex128。
- **TypeError** - `input_x` 不是Tensor。

View File

@ -12,7 +12,7 @@ mindspore.ops.ReverseV2
- **axis** (Union[tuple(int), list(int)]) - 指定反转的轴。
输入:
- **input_x** (Tensor) - 输入需反转的任意维度的Tensor。数据类型为数值型不包括float64。shape :math:`(N, *)` ,其中 :math:`*` 表示任意数量的附加维度。
- **input_x** (Tensor) - 输入需反转的任意维度的Tensor。shape :math:`(N, *)` ,其中 :math:`*` 表示任意数量的附加维度。
输出:
Tensorshape和数据类型与输入 `input_x` 相同。

View File

@ -8,7 +8,7 @@ mindspore.ops.Sigmoid
更多参考详见 :func:`mindspore.ops.sigmoid`
输入:
- **input_x** (Tensor) - 任意维度的Tensor数据类型为float16或float32
- **input_x** (Tensor) - 任意维度的Tensor数据类型为float16、float32、float64、complex64或complex128
输出:
Tensor数据类型和shape与 `input_x` 的相同。

View File

@ -21,7 +21,7 @@
- **sorted** (bool, 可选) - 如果为 ``True`` ,则获取的元素将按值降序排序。如果为 ``False`` ,则不对获取的元素进行排序。默认值: ``True``
输入:
- **input_x** (Tensor) - 需计算的输入CPU推理数据类型必须为float16、float32或int32GPU推理数据类型必须为float16或float32。
- **input_x** (Tensor) - 需计算的输入CPU推理数据类型为NumberGPU推理数据类型必须为float16或float32。
- **k** (int) - 指定计算最大元素的数量,必须为常量。
输出:
@ -34,4 +34,4 @@
- **TypeError** - 如果 `sorted` 不是bool。
- **TypeError** - 如果 `input_x` 不是Tensor。
- **TypeError** - 如果 `k` 不是int。
- **TypeError** - 如果 `input_x` 的数据类型不是以下之一float16、float32或int32
- **TypeError** - 如果 `input_x` 的数据类型不被支持

View File

@ -985,7 +985,10 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, ArithFuncCreator>
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc}}},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), CreateArithSelfFunc}}},
{kInv,
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), CreateArithSelfFunc},

View File

@ -21,6 +21,7 @@
#include <string>
#include <utility>
#include <algorithm>
#include <complex>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
@ -45,7 +46,25 @@ template <typename S, typename T>
void Cast(CastCpuKernelFunc<S, T> *content, const S *in, T *out, size_t size) {
auto task = [&in, &out](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = static_cast<T>(in[i]);
if constexpr (std::is_same_v<S, T>) {
out[i] = static_cast<T>(in[i]);
} else if constexpr (std::is_same_v<S, bool> && std::is_same_v<T, std::complex<float>>) {
out[i] = std::complex<float>(in[i] ? 1.0f : 0.0f, 0.0f);
} else if constexpr (std::is_same_v<S, bool> && std::is_same_v<T, std::complex<double>>) {
out[i] = std::complex<double>(in[i] ? 1.0 : 0.0, 0.0);
} else if constexpr (std::is_same_v<S, std::complex<float>> && std::is_same_v<T, bool>) {
out[i] = (std::real(in[i]) != 0.0f) || (std::imag(in[i]) != 0.0f);
} else if constexpr (std::is_same_v<S, std::complex<double>> && std::is_same_v<T, bool>) {
out[i] = (std::real(in[i]) != 0.0) || (std::imag(in[i]) != 0.0);
} else if constexpr ((std::is_same_v<S, std::complex<float>>) || (std::is_same_v<S, std::complex<double>>)) {
out[i] = static_cast<T>(std::real(in[i]));
} else if constexpr ((std::is_same_v<T, std::complex<float>>) || (std::is_same_v<T, std::complex<double>>)) {
double realValue = static_cast<double>(in[i]);
std::complex<double> complexValue(realValue, 0.0);
out[i] = (std::is_same_v<T, std::complex<float>>) ? static_cast<T>(complexValue) : complexValue;
} else {
out[i] = static_cast<T>(in[i]);
}
}
};
ParallelLaunchAutoSearch(task, size, content, &content->parallel_search_info_);
@ -79,6 +98,10 @@ static std::vector<std::pair<KernelAttr, CastCpuKernelFuncCreator>> kernel_attr_
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat16), CreateCastFunc<uint8_t, float16>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat32), CreateCastFunc<uint8_t, float>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat64), CreateCastFunc<uint8_t, double>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeComplex64),
CreateCastFunc<uint8_t, std::complex<float>>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeComplex128),
CreateCastFunc<uint8_t, std::complex<double>>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool), CreateCastFunc<uint8_t, bool>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt8), CreateCastFunc<uint16_t, uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), CreateCastFunc<uint16_t, uint16_t>},
@ -92,6 +115,10 @@ static std::vector<std::pair<KernelAttr, CastCpuKernelFuncCreator>> kernel_attr_
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat32), CreateCastFunc<uint16_t, float>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat64), CreateCastFunc<uint16_t, double>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeBool), CreateCastFunc<uint16_t, bool>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeComplex64),
CreateCastFunc<uint16_t, std::complex<float>>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeComplex128),
CreateCastFunc<uint16_t, std::complex<double>>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt8), CreateCastFunc<uint32_t, uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt16), CreateCastFunc<uint32_t, uint16_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), CreateCastFunc<uint32_t, uint32_t>},
@ -104,6 +131,10 @@ static std::vector<std::pair<KernelAttr, CastCpuKernelFuncCreator>> kernel_attr_
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat32), CreateCastFunc<uint32_t, float>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat64), CreateCastFunc<uint32_t, double>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeBool), CreateCastFunc<uint32_t, bool>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeComplex64),
CreateCastFunc<uint32_t, std::complex<float>>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeComplex128),
CreateCastFunc<uint32_t, std::complex<double>>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt8), CreateCastFunc<uint64_t, uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt16), CreateCastFunc<uint64_t, uint16_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt32), CreateCastFunc<uint64_t, uint32_t>},
@ -116,6 +147,10 @@ static std::vector<std::pair<KernelAttr, CastCpuKernelFuncCreator>> kernel_attr_
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat32), CreateCastFunc<uint64_t, float>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat64), CreateCastFunc<uint64_t, double>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeBool), CreateCastFunc<uint64_t, bool>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeComplex64),
CreateCastFunc<uint64_t, std::complex<float>>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeComplex128),
CreateCastFunc<uint64_t, std::complex<double>>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt8), CreateCastFunc<int8_t, uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt16), CreateCastFunc<int8_t, uint16_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt32), CreateCastFunc<int8_t, uint32_t>},
@ -128,6 +163,10 @@ static std::vector<std::pair<KernelAttr, CastCpuKernelFuncCreator>> kernel_attr_
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat32), CreateCastFunc<int8_t, float>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat64), CreateCastFunc<int8_t, double>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool), CreateCastFunc<int8_t, bool>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeComplex64),
CreateCastFunc<int8_t, std::complex<float>>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeComplex128),
CreateCastFunc<int8_t, std::complex<double>>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt8), CreateCastFunc<int16_t, uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt16), CreateCastFunc<int16_t, uint16_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt32), CreateCastFunc<int16_t, uint32_t>},
@ -140,6 +179,10 @@ static std::vector<std::pair<KernelAttr, CastCpuKernelFuncCreator>> kernel_attr_
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat32), CreateCastFunc<int16_t, float>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat64), CreateCastFunc<int16_t, double>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool), CreateCastFunc<int16_t, bool>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeComplex64),
CreateCastFunc<int16_t, std::complex<float>>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeComplex128),
CreateCastFunc<int16_t, std::complex<double>>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8), CreateCastFunc<int32_t, uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16), CreateCastFunc<int32_t, uint16_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32), CreateCastFunc<int32_t, uint32_t>},
@ -152,6 +195,10 @@ static std::vector<std::pair<KernelAttr, CastCpuKernelFuncCreator>> kernel_attr_
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), CreateCastFunc<int32_t, float>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64), CreateCastFunc<int32_t, double>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), CreateCastFunc<int32_t, bool>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeComplex64),
CreateCastFunc<int32_t, std::complex<float>>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeComplex128),
CreateCastFunc<int32_t, std::complex<double>>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8), CreateCastFunc<int64_t, uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16), CreateCastFunc<int64_t, uint16_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32), CreateCastFunc<int64_t, uint32_t>},
@ -164,6 +211,10 @@ static std::vector<std::pair<KernelAttr, CastCpuKernelFuncCreator>> kernel_attr_
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), CreateCastFunc<int64_t, float>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64), CreateCastFunc<int64_t, double>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool), CreateCastFunc<int64_t, bool>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeComplex64),
CreateCastFunc<int64_t, std::complex<float>>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeComplex128),
CreateCastFunc<int64_t, std::complex<double>>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt8), CreateCastFunc<float16, uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt16), CreateCastFunc<float16, uint16_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt32), CreateCastFunc<float16, uint32_t>},
@ -176,6 +227,10 @@ static std::vector<std::pair<KernelAttr, CastCpuKernelFuncCreator>> kernel_attr_
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat32), CreateCastFunc<float16, float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat64), CreateCastFunc<float16, double>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), CreateCastFunc<float16, bool>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeComplex64),
CreateCastFunc<float16, std::complex<float>>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeComplex128),
CreateCastFunc<float16, std::complex<double>>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt8), CreateCastFunc<float, uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt16), CreateCastFunc<float, uint16_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt32), CreateCastFunc<float, uint32_t>},
@ -188,6 +243,10 @@ static std::vector<std::pair<KernelAttr, CastCpuKernelFuncCreator>> kernel_attr_
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateCastFunc<float, float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat64), CreateCastFunc<float, double>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), CreateCastFunc<float, bool>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeComplex64),
CreateCastFunc<float, std::complex<float>>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeComplex128),
CreateCastFunc<float, std::complex<double>>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt8), CreateCastFunc<double, uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt16), CreateCastFunc<double, uint16_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt32), CreateCastFunc<double, uint32_t>},
@ -200,6 +259,10 @@ static std::vector<std::pair<KernelAttr, CastCpuKernelFuncCreator>> kernel_attr_
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat32), CreateCastFunc<double, float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateCastFunc<double, double>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool), CreateCastFunc<double, bool>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeComplex64),
CreateCastFunc<double, std::complex<float>>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeComplex128),
CreateCastFunc<double, std::complex<double>>},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt8), CreateCastFunc<bool, uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt16), CreateCastFunc<bool, uint16_t>},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt32), CreateCastFunc<bool, uint32_t>},
@ -212,6 +275,66 @@ static std::vector<std::pair<KernelAttr, CastCpuKernelFuncCreator>> kernel_attr_
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat32), CreateCastFunc<bool, float>},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat64), CreateCastFunc<bool, double>},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), CreateCastFunc<bool, bool>},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeComplex64),
CreateCastFunc<bool, std::complex<float>>},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeComplex128),
CreateCastFunc<bool, std::complex<double>>},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeUInt8),
CreateCastFunc<std::complex<float>, uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeUInt16),
CreateCastFunc<std::complex<float>, uint16_t>},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeUInt32),
CreateCastFunc<std::complex<float>, uint32_t>},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeUInt64),
CreateCastFunc<std::complex<float>, uint64_t>},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeInt8),
CreateCastFunc<std::complex<float>, int8_t>},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeInt16),
CreateCastFunc<std::complex<float>, int16_t>},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeInt32),
CreateCastFunc<std::complex<float>, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeInt64),
CreateCastFunc<std::complex<float>, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeFloat16),
CreateCastFunc<std::complex<float>, float16>},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeFloat32),
CreateCastFunc<std::complex<float>, float>},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeFloat64),
CreateCastFunc<std::complex<float>, double>},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeBool),
CreateCastFunc<std::complex<float>, bool>},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
CreateCastFunc<std::complex<float>, std::complex<float>>},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex128),
CreateCastFunc<std::complex<float>, std::complex<double>>},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeUInt8),
CreateCastFunc<std::complex<double>, uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeUInt16),
CreateCastFunc<std::complex<double>, uint16_t>},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeUInt32),
CreateCastFunc<std::complex<double>, uint32_t>},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeUInt64),
CreateCastFunc<std::complex<double>, uint64_t>},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeInt8),
CreateCastFunc<std::complex<double>, int8_t>},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeInt16),
CreateCastFunc<std::complex<double>, int16_t>},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeInt32),
CreateCastFunc<std::complex<double>, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeInt64),
CreateCastFunc<std::complex<double>, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeFloat16),
CreateCastFunc<std::complex<double>, float16>},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeFloat32),
CreateCastFunc<std::complex<double>, float>},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeFloat64),
CreateCastFunc<std::complex<double>, double>},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeBool),
CreateCastFunc<std::complex<double>, bool>},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex64),
CreateCastFunc<std::complex<double>, std::complex<float>>},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
CreateCastFunc<std::complex<double>, std::complex<double>>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
CreateCastFunc<uint8_t, uint8_t>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt16),
@ -236,6 +359,10 @@ static std::vector<std::pair<KernelAttr, CastCpuKernelFuncCreator>> kernel_attr_
CreateCastFunc<uint8_t, double>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool),
CreateCastFunc<uint8_t, bool>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeUInt8).AddOutputAttr(kNumberTypeComplex64),
CreateCastFunc<uint8_t, std::complex<float>>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeUInt8).AddOutputAttr(kNumberTypeComplex128),
CreateCastFunc<uint8_t, std::complex<double>>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt8),
CreateCastFunc<uint16_t, uint8_t>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
@ -260,6 +387,10 @@ static std::vector<std::pair<KernelAttr, CastCpuKernelFuncCreator>> kernel_attr_
CreateCastFunc<uint16_t, double>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeUInt16).AddOutputAttr(kNumberTypeBool),
CreateCastFunc<uint16_t, bool>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeUInt16).AddOutputAttr(kNumberTypeComplex64),
CreateCastFunc<uint16_t, std::complex<float>>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeUInt16).AddOutputAttr(kNumberTypeComplex128),
CreateCastFunc<uint16_t, std::complex<double>>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt8),
CreateCastFunc<uint32_t, uint8_t>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt16),
@ -284,6 +415,10 @@ static std::vector<std::pair<KernelAttr, CastCpuKernelFuncCreator>> kernel_attr_
CreateCastFunc<uint32_t, double>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeUInt32).AddOutputAttr(kNumberTypeBool),
CreateCastFunc<uint32_t, bool>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeUInt32).AddOutputAttr(kNumberTypeComplex64),
CreateCastFunc<uint32_t, std::complex<float>>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeUInt32).AddOutputAttr(kNumberTypeComplex128),
CreateCastFunc<uint32_t, std::complex<double>>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt8),
CreateCastFunc<uint64_t, uint8_t>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt16),
@ -308,6 +443,10 @@ static std::vector<std::pair<KernelAttr, CastCpuKernelFuncCreator>> kernel_attr_
CreateCastFunc<uint64_t, double>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeUInt64).AddOutputAttr(kNumberTypeBool),
CreateCastFunc<uint64_t, bool>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeUInt64).AddOutputAttr(kNumberTypeComplex64),
CreateCastFunc<uint64_t, std::complex<float>>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeUInt64).AddOutputAttr(kNumberTypeComplex128),
CreateCastFunc<uint64_t, std::complex<double>>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt8),
CreateCastFunc<int8_t, uint8_t>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt16),
@ -332,6 +471,10 @@ static std::vector<std::pair<KernelAttr, CastCpuKernelFuncCreator>> kernel_attr_
CreateCastFunc<int8_t, double>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeInt8).AddOutputAttr(kNumberTypeBool),
CreateCastFunc<int8_t, bool>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeInt8).AddOutputAttr(kNumberTypeComplex64),
CreateCastFunc<int8_t, std::complex<float>>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeInt8).AddOutputAttr(kNumberTypeComplex128),
CreateCastFunc<int8_t, std::complex<double>>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt8),
CreateCastFunc<int16_t, uint8_t>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt16),
@ -356,6 +499,10 @@ static std::vector<std::pair<KernelAttr, CastCpuKernelFuncCreator>> kernel_attr_
CreateCastFunc<int16_t, double>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeInt16).AddOutputAttr(kNumberTypeBool),
CreateCastFunc<int16_t, bool>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeInt16).AddOutputAttr(kNumberTypeComplex64),
CreateCastFunc<int16_t, std::complex<float>>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeInt16).AddOutputAttr(kNumberTypeComplex128),
CreateCastFunc<int16_t, std::complex<double>>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
CreateCastFunc<int32_t, uint8_t>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
@ -380,6 +527,10 @@ static std::vector<std::pair<KernelAttr, CastCpuKernelFuncCreator>> kernel_attr_
CreateCastFunc<int32_t, double>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
CreateCastFunc<int32_t, bool>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeInt32).AddOutputAttr(kNumberTypeComplex64),
CreateCastFunc<int32_t, std::complex<float>>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeInt32).AddOutputAttr(kNumberTypeComplex128),
CreateCastFunc<int32_t, std::complex<double>>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
CreateCastFunc<int64_t, uint8_t>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
@ -404,6 +555,10 @@ static std::vector<std::pair<KernelAttr, CastCpuKernelFuncCreator>> kernel_attr_
CreateCastFunc<int64_t, double>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
CreateCastFunc<int64_t, bool>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeInt64).AddOutputAttr(kNumberTypeComplex64),
CreateCastFunc<int64_t, std::complex<float>>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeInt64).AddOutputAttr(kNumberTypeComplex128),
CreateCastFunc<int64_t, std::complex<double>>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt8),
CreateCastFunc<float16, uint8_t>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt16),
@ -428,6 +583,10 @@ static std::vector<std::pair<KernelAttr, CastCpuKernelFuncCreator>> kernel_attr_
CreateCastFunc<float16, double>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
CreateCastFunc<float16, bool>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeFloat16).AddOutputAttr(kNumberTypeComplex64),
CreateCastFunc<float16, std::complex<float>>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeFloat16).AddOutputAttr(kNumberTypeComplex128),
CreateCastFunc<float16, std::complex<double>>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt8),
CreateCastFunc<float, uint8_t>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt16),
@ -452,6 +611,10 @@ static std::vector<std::pair<KernelAttr, CastCpuKernelFuncCreator>> kernel_attr_
CreateCastFunc<float, double>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
CreateCastFunc<float, bool>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeFloat32).AddOutputAttr(kNumberTypeComplex64),
CreateCastFunc<float, std::complex<float>>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeFloat32).AddOutputAttr(kNumberTypeComplex128),
CreateCastFunc<float, std::complex<double>>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt8),
CreateCastFunc<double, uint8_t>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt16),
@ -476,6 +639,10 @@ static std::vector<std::pair<KernelAttr, CastCpuKernelFuncCreator>> kernel_attr_
CreateCastFunc<double, double>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
CreateCastFunc<double, bool>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeFloat64).AddOutputAttr(kNumberTypeComplex64),
CreateCastFunc<double, std::complex<float>>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeFloat64).AddOutputAttr(kNumberTypeComplex128),
CreateCastFunc<double, std::complex<double>>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeBool).AddOutputAttr(kNumberTypeUInt8),
CreateCastFunc<bool, uint8_t>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeBool).AddOutputAttr(kNumberTypeUInt16),
@ -499,7 +666,67 @@ static std::vector<std::pair<KernelAttr, CastCpuKernelFuncCreator>> kernel_attr_
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeBool).AddOutputAttr(kNumberTypeFloat64),
CreateCastFunc<bool, double>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
CreateCastFunc<bool, bool>}};
CreateCastFunc<bool, bool>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeBool).AddOutputAttr(kNumberTypeComplex64),
CreateCastFunc<bool, std::complex<float>>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeBool).AddOutputAttr(kNumberTypeComplex128),
CreateCastFunc<bool, std::complex<double>>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeComplex64).AddOutputAttr(kNumberTypeUInt8),
CreateCastFunc<std::complex<float>, uint8_t>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeComplex64).AddOutputAttr(kNumberTypeUInt16),
CreateCastFunc<std::complex<float>, uint16_t>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeComplex64).AddOutputAttr(kNumberTypeUInt32),
CreateCastFunc<std::complex<float>, uint32_t>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeComplex64).AddOutputAttr(kNumberTypeUInt64),
CreateCastFunc<std::complex<float>, uint64_t>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeComplex64).AddOutputAttr(kNumberTypeInt8),
CreateCastFunc<std::complex<float>, int8_t>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeComplex64).AddOutputAttr(kNumberTypeInt16),
CreateCastFunc<std::complex<float>, int16_t>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeComplex64).AddOutputAttr(kNumberTypeInt32),
CreateCastFunc<std::complex<float>, int32_t>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeComplex64).AddOutputAttr(kNumberTypeInt64),
CreateCastFunc<std::complex<float>, int64_t>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeComplex64).AddOutputAttr(kNumberTypeFloat16),
CreateCastFunc<std::complex<float>, float16>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeComplex64).AddOutputAttr(kNumberTypeFloat32),
CreateCastFunc<std::complex<float>, float>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeComplex64).AddOutputAttr(kNumberTypeFloat64),
CreateCastFunc<std::complex<float>, double>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeComplex64).AddOutputAttr(kNumberTypeBool),
CreateCastFunc<std::complex<float>, bool>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
CreateCastFunc<std::complex<float>, std::complex<float>>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex128),
CreateCastFunc<std::complex<float>, std::complex<double>>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeComplex128).AddOutputAttr(kNumberTypeUInt8),
CreateCastFunc<std::complex<double>, uint8_t>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeComplex128).AddOutputAttr(kNumberTypeUInt16),
CreateCastFunc<std::complex<double>, uint16_t>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeComplex128).AddOutputAttr(kNumberTypeUInt32),
CreateCastFunc<std::complex<double>, uint32_t>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeComplex128).AddOutputAttr(kNumberTypeUInt64),
CreateCastFunc<std::complex<double>, uint64_t>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeComplex128).AddOutputAttr(kNumberTypeInt8),
CreateCastFunc<std::complex<double>, int8_t>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeComplex128).AddOutputAttr(kNumberTypeInt16),
CreateCastFunc<std::complex<double>, int16_t>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeComplex128).AddOutputAttr(kNumberTypeInt32),
CreateCastFunc<std::complex<double>, int32_t>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeComplex128).AddOutputAttr(kNumberTypeInt64),
CreateCastFunc<std::complex<double>, int64_t>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeComplex128).AddOutputAttr(kNumberTypeFloat16),
CreateCastFunc<std::complex<double>, float16>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeComplex128).AddOutputAttr(kNumberTypeFloat32),
CreateCastFunc<std::complex<double>, float>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeComplex128).AddOutputAttr(kNumberTypeFloat64),
CreateCastFunc<std::complex<double>, double>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeComplex128).AddOutputAttr(kNumberTypeBool),
CreateCastFunc<std::complex<double>, bool>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex64),
CreateCastFunc<std::complex<double>, std::complex<float>>},
{KernelAttr().AddInputAttr(kObjectTypeNumber, kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
CreateCastFunc<std::complex<double>, std::complex<double>>}};
} // namespace
bool CastCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,

View File

@ -52,11 +52,21 @@ class EltwiseCpuKernelFunc : public CpuKernelFunc {
static std::map<string, std::vector<std::pair<KernelAttr, TypeComputeFunc>>> eltwise_func_map = {
{kSigmoid,
{{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
&EltwiseCpuKernelFunc<T>::SigmoidComplex},
&EltwiseCpuKernelFunc<T>::Sigmoid},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
&EltwiseCpuKernelFunc<T>::SigmoidComplex},
&EltwiseCpuKernelFunc<T>::Sigmoid},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
&EltwiseCpuKernelFunc<T>::SigmoidComplex}}}};
&EltwiseCpuKernelFunc<T>::Sigmoid},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), &EltwiseCpuKernelFunc<T>::Sigmoid},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), &EltwiseCpuKernelFunc<T>::Sigmoid},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
&EltwiseCpuKernelFunc<T>::Sigmoid},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
&EltwiseCpuKernelFunc<T>::Sigmoid},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
&EltwiseCpuKernelFunc<T>::Sigmoid},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
&EltwiseCpuKernelFunc<T>::Sigmoid}}}};
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto iter = eltwise_func_map.find(kernel_name_);
@ -96,28 +106,28 @@ class EltwiseCpuKernelFunc : public CpuKernelFunc {
using TypeComputeFunc = std::function<void(EltwiseCpuKernelFunc *, const T *input, T *output)>;
TypeComputeFunc compute_func_{nullptr};
size_t input_element_num_{0};
void SigmoidComplex(const T *input, T *output);
void Sigmoid(const T *input, T *output);
}; // namespace
template <typename T>
void EltwiseCpuKernelFunc<T>::SigmoidComplex(const T *input, T *output) {
T one_complex{1, 0};
auto task = [&input, &output, &one_complex](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
output[i] = one_complex / (one_complex + exp(-input[i]));
}
};
ParallelLaunchAutoSearch(task, input_element_num_, this, &parallel_search_info_);
}
template <>
void EltwiseCpuKernelFunc<double>::SigmoidComplex(const double *input, double *output) {
auto task = [&input, &output](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
output[i] = 1.0 / (1.0 + exp(-input[i]));
}
};
ParallelLaunchAutoSearch(task, input_element_num_, this, &parallel_search_info_);
void EltwiseCpuKernelFunc<T>::Sigmoid(const T *input, T *output) {
if constexpr ((std::is_same_v<T, std::complex<float>>) || (std::is_same_v<T, std::complex<double>>)) {
T one_complex{1, 0};
auto task = [&input, &output, &one_complex](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
output[i] = one_complex / (one_complex + exp(-input[i]));
}
};
ParallelLaunchAutoSearch(task, input_element_num_, this, &parallel_search_info_);
} else {
T one_scalar = static_cast<double>(1);
auto task = [&input, &output, &one_scalar](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
output[i] = static_cast<T>(one_scalar / (one_scalar + exp(-static_cast<double>(input[i]))));
}
};
ParallelLaunchAutoSearch(task, input_element_num_, this, &parallel_search_info_);
}
}
struct DescParam {
@ -311,8 +321,13 @@ std::map<std::string, std::vector<std::pair<KernelAttr, EltWiseCpuKernelMod::Elt
SpecializeEltwiseFunc<complex64>},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
SpecializeEltwiseFunc<complex128>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
SpecializeEltwiseFunc<double>}}},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), SpecializeEltwiseFunc<double>},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), SpecializeEltwiseFunc<bool>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), SpecializeEltwiseFunc<int8_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), SpecializeEltwiseFunc<uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), SpecializeEltwiseFunc<int16_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), SpecializeEltwiseFunc<int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), SpecializeEltwiseFunc<int64_t>}}},
};
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Elu, []() { return std::make_shared<EltWiseCpuKernelMod>(kElu); });

View File

@ -465,11 +465,7 @@ size_t MKLCpuKernelMod::GetSize(const dnnl::memory::desc &desc) const {
dnnl::memory::data_type MKLCpuKernelMod::GetDnnlDataType(TypeId ms_type_id) const {
static const std::map<TypeId, dnnl::memory::data_type> dnnl_data_type_map = {
{kNumberTypeFloat16, dnnl::memory::data_type::f16},
{kNumberTypeFloat32, dnnl::memory::data_type::f32},
{kNumberTypeInt32, dnnl::memory::data_type::s32},
{kNumberTypeInt8, dnnl::memory::data_type::s8},
{kNumberTypeUInt8, dnnl::memory::data_type::u8}};
{kNumberTypeFloat16, dnnl::memory::data_type::f16}, {kNumberTypeFloat32, dnnl::memory::data_type::f32}};
auto iter = dnnl_data_type_map.find(ms_type_id);
if (iter == dnnl_data_type_map.end()) {
MS_LOG(WARNING) << "Dnnl do not support data type:" << TypeIdToString(ms_type_id);

View File

@ -103,6 +103,22 @@ void ReduceSum(const T *in, T *out, size_t start, size_t end, TransposeIterator
*out = value;
}
template <>
void ReduceSum(const bool *in, bool *out, size_t start, size_t end, TransposeIterator *iter) {
bool value = *out;
if (iter != nullptr) {
for (size_t i = start; i < end; i++) {
value |= in[iter->GetPos()];
iter->GenNextPos();
}
} else {
for (size_t i = start; i < end; i++) {
value |= in[i];
}
}
*out = value;
}
template <typename T>
void ReduceMean(const T *in, T *out, size_t start, size_t end, TransposeIterator *iter) {
ReduceSum(in, out, start, end, iter);
@ -110,16 +126,30 @@ void ReduceMean(const T *in, T *out, size_t start, size_t end, TransposeIterator
template <typename T>
void ReduceProd(const T *in, T *out, size_t start, size_t end, TransposeIterator *iter) {
if (iter != nullptr) {
for (size_t i = start; i < end; i++) {
*out *= in[iter->GetPos()];
iter->GenNextPos();
if constexpr (std::is_same<T, bool>::value) {
if (iter != nullptr) {
for (size_t i = start; i < end; i++) {
*out = *out && in[iter->GetPos()];
iter->GenNextPos();
}
return;
}
return;
}
for (size_t i = start; i < end; i++) {
*out *= in[i];
for (size_t i = start; i < end; i++) {
*out = *out && in[i];
}
} else {
if (iter != nullptr) {
for (size_t i = start; i < end; i++) {
*out *= in[iter->GetPos()];
iter->GenNextPos();
}
return;
}
for (size_t i = start; i < end; i++) {
*out *= in[i];
}
}
}
@ -259,6 +289,12 @@ void ReduceCpuKernelFunc<T>::ChooseFunc(const std::string &kernel_name_) {
} else if (kernel_name_ == kReduceAny) {
reduce_type_ = ReduceFuncType::kReduceAnyType;
reduce_func_ = ReduceAny<T>;
} else if (kernel_name_ == kReduceProd) {
reduce_type_ = ReduceFuncType::kReduceProdType;
reduce_func_ = ReduceProd<T>;
} else if (kernel_name_ == kReduceSum) {
reduce_type_ = ReduceFuncType::kReduceSumType;
reduce_func_ = ReduceSum<T>;
} else {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', unsupported reduce operation for bool.";
}
@ -468,6 +504,8 @@ static std::vector<std::pair<KernelAttr, SpecializeReduceFuncCreator>> kernel_ma
{REDUCE_CPU_REG(kNumberTypeUInt64, kNumberTypeInt32, uint64_t)},
{REDUCE_CPU_REG(kNumberTypeUInt64, kNumberTypeInt64, uint64_t)}};
static std::vector<std::pair<KernelAttr, SpecializeReduceFuncCreator>> kernel_sum_prod_mean_list = {
{REDUCE_CPU_REG(kNumberTypeBool, kNumberTypeInt32, bool)},
{REDUCE_CPU_REG(kNumberTypeBool, kNumberTypeInt64, bool)},
{REDUCE_CPU_REG(kNumberTypeFloat32, kNumberTypeInt32, float)},
{REDUCE_CPU_REG(kNumberTypeFloat32, kNumberTypeInt64, float)},
{REDUCE_CPU_REG(kNumberTypeFloat64, kNumberTypeInt32, double)},

View File

@ -127,10 +127,16 @@ bool ReverseV2CpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
}
std::vector<std::pair<KernelAttr, ReverseV2CpuKernelMod::ReverseV2Func>> ReverseV2CpuKernelMod::func_list_ = {
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
&ReverseV2CpuKernelMod::LaunchKernel<bool>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
&ReverseV2CpuKernelMod::LaunchKernel<uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
&ReverseV2CpuKernelMod::LaunchKernel<uint16_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
&ReverseV2CpuKernelMod::LaunchKernel<uint32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
&ReverseV2CpuKernelMod::LaunchKernel<uint64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
&ReverseV2CpuKernelMod::LaunchKernel<int16_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),

View File

@ -138,9 +138,24 @@ bool TopKCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
LaunchKernel<float>(inputs, workspaces, outputs);
} else if (dtype_ == kNumberTypeInt32) {
LaunchKernel<int>(inputs, workspaces, outputs);
} else if (dtype_ == kNumberTypeUInt32) {
LaunchKernel<uint32_t>(inputs, workspaces, outputs);
} else if (dtype_ == kNumberTypeInt8) {
LaunchKernel<int8_t>(inputs, workspaces, outputs);
} else if (dtype_ == kNumberTypeUInt8) {
LaunchKernel<uint8_t>(inputs, workspaces, outputs);
} else if (dtype_ == kNumberTypeInt16) {
LaunchKernel<int16_t>(inputs, workspaces, outputs);
} else if (dtype_ == kNumberTypeUInt16) {
LaunchKernel<uint16_t>(inputs, workspaces, outputs);
} else if (dtype_ == kNumberTypeInt64) {
LaunchKernel<int64_t>(inputs, workspaces, outputs);
} else if (dtype_ == kNumberTypeUInt64) {
LaunchKernel<uint64_t>(inputs, workspaces, outputs);
} else if (dtype_ == kNumberTypeFloat64) {
LaunchKernel<double>(inputs, workspaces, outputs);
} else {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', the dtype of input must be float16 or float32 or int32, but got "
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dtype of input must be float, int or uint, but got "
<< TypeIdToType(dtype_)->ToString();
}
return true;
@ -157,10 +172,50 @@ std::vector<KernelAttr> TopKCpuKernelMod::GetOpSupport() {
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeInt32),
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeInt32),
KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt8)
.AddOutputAttr(kNumberTypeInt32),
KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeUInt8)
.AddOutputAttr(kNumberTypeInt32),
KernelAttr()
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt16)
.AddOutputAttr(kNumberTypeInt32),
KernelAttr()
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeUInt16)
.AddOutputAttr(kNumberTypeInt32),
KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
KernelAttr()
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeUInt32)
.AddOutputAttr(kNumberTypeInt32),
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt32),
KernelAttr()
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeUInt64)
.AddOutputAttr(kNumberTypeInt32)};
return kernel_attr_list;
}

View File

@ -61,6 +61,8 @@ std::vector<std::pair<KernelAttr, ArrayReduceGpuKernelMod::ReduceFunc>> ArrayRed
{REDUCE_REGISTER(kNumberTypeBool, kNumberTypeInt32, bool)},
{REDUCE_REGISTER(kNumberTypeBool, kNumberTypeInt64, bool)}};
std::vector<std::pair<KernelAttr, ArrayReduceGpuKernelMod::ReduceFunc>> ArrayReduceGpuKernelMod::prod_list_ = {
{REDUCE_REGISTER(kNumberTypeBool, kNumberTypeInt32, bool)},
{REDUCE_REGISTER(kNumberTypeBool, kNumberTypeInt64, bool)},
{REDUCE_REGISTER(kNumberTypeFloat16, kNumberTypeInt32, half)},
{REDUCE_REGISTER(kNumberTypeFloat16, kNumberTypeInt64, half)},
{REDUCE_REGISTER(kNumberTypeFloat32, kNumberTypeInt32, float)},
@ -69,6 +71,8 @@ std::vector<std::pair<KernelAttr, ArrayReduceGpuKernelMod::ReduceFunc>> ArrayRed
{REDUCE_REGISTER(kNumberTypeFloat64, kNumberTypeInt64, double)},
{REDUCE_REGISTER(kNumberTypeInt8, kNumberTypeInt32, int8_t)},
{REDUCE_REGISTER(kNumberTypeInt8, kNumberTypeInt64, int8_t)},
{REDUCE_REGISTER(kNumberTypeUInt8, kNumberTypeInt32, uint8_t)},
{REDUCE_REGISTER(kNumberTypeUInt8, kNumberTypeInt64, uint8_t)},
{REDUCE_REGISTER(kNumberTypeInt16, kNumberTypeInt32, int16_t)},
{REDUCE_REGISTER(kNumberTypeInt16, kNumberTypeInt64, int16_t)},
{REDUCE_REGISTER(kNumberTypeInt32, kNumberTypeInt32, int32_t)},
@ -91,6 +95,8 @@ std::vector<std::pair<KernelAttr, ArrayReduceGpuKernelMod::ReduceFunc>> ArrayRed
{REDUCE_REGISTER(kNumberTypeFloat64, kNumberTypeInt64, double)},
{REDUCE_REGISTER(kNumberTypeInt8, kNumberTypeInt32, int8_t)},
{REDUCE_REGISTER(kNumberTypeInt8, kNumberTypeInt64, int8_t)},
{REDUCE_REGISTER(kNumberTypeUInt8, kNumberTypeInt32, uint8_t)},
{REDUCE_REGISTER(kNumberTypeUInt8, kNumberTypeInt64, uint8_t)},
{REDUCE_REGISTER(kNumberTypeInt16, kNumberTypeInt32, int16_t)},
{REDUCE_REGISTER(kNumberTypeInt16, kNumberTypeInt64, int16_t)},
{REDUCE_REGISTER(kNumberTypeInt32, kNumberTypeInt32, int32_t)},

View File

@ -119,6 +119,9 @@ bool ReverseV2GpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
}
std::vector<std::pair<KernelAttr, ReverseV2GpuKernelMod::ReverseV2LaunchFunc>> ReverseV2GpuKernelMod::func_list_ = {
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
&ReverseV2GpuKernelMod::LaunchKernel<bool>},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
&ReverseV2GpuKernelMod::LaunchKernel<Complex<float>>},
@ -140,6 +143,12 @@ std::vector<std::pair<KernelAttr, ReverseV2GpuKernelMod::ReverseV2LaunchFunc>> R
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
&ReverseV2GpuKernelMod::LaunchKernel<uint16_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
&ReverseV2GpuKernelMod::LaunchKernel<uint32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
&ReverseV2GpuKernelMod::LaunchKernel<uint64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
&ReverseV2GpuKernelMod::LaunchKernel<int8_t>},

View File

@ -32,18 +32,6 @@ using SlicePtrCreatorFunc =
std::function<std::unique_ptr<cukernel::GpuKernelHelperBase>(const std::string &, const uint32_t &)>;
const std::vector<std::pair<KernelAttr, SlicePtrCreatorFunc>> kernel_attr = {
{KernelAttr()
.AddInputAttr(kNumberTypeComplex64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeComplex64),
CreateSliceKernelPtr<Complex<float>, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeComplex128),
CreateSliceKernelPtr<Complex<double>, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
@ -187,7 +175,31 @@ const std::vector<std::pair<KernelAttr, SlicePtrCreatorFunc>> kernel_attr = {
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeBool),
CreateSliceKernelPtr<bool, int32_t>}};
CreateSliceKernelPtr<bool, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeComplex64),
CreateSliceKernelPtr<Complex<float>, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeComplex64),
CreateSliceKernelPtr<Complex<float>, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeComplex128),
CreateSliceKernelPtr<Complex<double>, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeComplex128),
CreateSliceKernelPtr<Complex<double>, int64_t>}};
} // namespace
bool SliceGpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,

View File

@ -15,6 +15,7 @@
*/
#include "plugin/device/gpu/kernel/arrays/unpack_gpu_kernel.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
namespace mindspore {
namespace kernel {
@ -51,5 +52,14 @@ MS_REG_GPU_KERNEL_ONE(
MS_REG_GPU_KERNEL_ONE(
Unstack, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
UnpackFwdGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(
Unstack, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
UnpackFwdGpuKernelMod, double)
MS_REG_GPU_KERNEL_ONE(
Unstack, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
UnpackFwdGpuKernelMod, utils::Complex<float>)
MS_REG_GPU_KERNEL_ONE(
Unstack, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
UnpackFwdGpuKernelMod, utils::Complex<double>)
} // namespace kernel
} // namespace mindspore

View File

@ -45,6 +45,10 @@ void CalReverseV2(const T* input, T* output, const size_t* input_shape, const in
return;
}
template CUDA_LIB_EXPORT void CalReverseV2<bool>(const bool* input, bool* output, const size_t* input_shape,
const int64_t* strides, const int64_t* axis, size_t input_size,
size_t axis_size, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalReverseV2<Complex<float>>(const Complex<float>* input, Complex<float>* output,
const size_t* input_shape, const int64_t* strides,
const int64_t* axis, size_t input_size,
@ -75,6 +79,14 @@ template CUDA_LIB_EXPORT void CalReverseV2<uint16_t>(const uint16_t* input, uint
const int64_t* strides, const int64_t* axis, size_t input_size,
size_t axis_size, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalReverseV2<uint32_t>(const uint32_t* input, uint32_t* output, const size_t* input_shape,
const int64_t* strides, const int64_t* axis, size_t input_size,
size_t axis_size, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalReverseV2<uint64_t>(const uint64_t* input, uint64_t* output, const size_t* input_shape,
const int64_t* strides, const int64_t* axis, size_t input_size,
size_t axis_size, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalReverseV2<int8_t>(const int8_t* input, int8_t* output, const size_t* input_shape,
const int64_t* strides, const int64_t* axis, size_t input_size,
size_t axis_size, cudaStream_t cuda_stream);

View File

@ -19,6 +19,8 @@
#include <cuda_runtime.h>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/unpack.cuh"
#include "include/cuda_fp16.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
template <typename T>
__global__ void Unpack(const size_t size, const size_t output_num,
const size_t dims_after_axis, T** outputs, const T* input) {
@ -73,3 +75,12 @@ template CUDA_LIB_EXPORT void UnpackKernel(const size_t size, const size_t outpu
template CUDA_LIB_EXPORT void UnpackKernel(const size_t size, const size_t output_num,
const size_t dims_after_axis, bool** outputs, const bool* input,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void UnpackKernel(const size_t size, const size_t output_num,
const size_t dims_after_axis, double** outputs, const double* input,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void UnpackKernel(const size_t size, const size_t output_num, const size_t dims_after_axis,
Complex<float> **outputs, const Complex<float> *input,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void UnpackKernel(const size_t size, const size_t output_num, const size_t dims_after_axis,
Complex<double> **outputs, const Complex<double> *input,
cudaStream_t cuda_stream);

View File

@ -499,6 +499,8 @@ std::map<std::string, std::vector<std::pair<KernelAttr, UnaryOpGpuKernelMod::Una
{kRound,
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
&UnaryOpGpuKernelMod::LaunchKernel<int>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
&UnaryOpGpuKernelMod::LaunchKernel<int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
&UnaryOpGpuKernelMod::LaunchKernel<double>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),

View File

@ -59,7 +59,7 @@ class ReduceArithmeticInfer : public abstract::OpInferBase {
{prim::kPrimReduceMin->name(), common_valid_types_with_complex_and_bool},
{prim::kPrimReduceSum->name(), common_valid_types_with_complex_and_bool},
{prim::kPrimReduceSumD->name(), common_valid_types_with_complex_and_bool},
{prim::kPrimReduceProd->name(), common_valid_types_with_complex},
{prim::kPrimReduceProd->name(), common_valid_types_with_complex_and_bool},
{prim::kPrimReduceMean->name(), common_valid_types_with_complex},
{prim::kPrimReduceMeanD->name(), common_valid_types_with_complex},
};

View File

@ -18,6 +18,7 @@
#include <vector>
#include <set>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/abstract_value.h"
#include "abstract/ops/op_infer.h"
@ -59,8 +60,8 @@ class SigmoidInfer : public abstract::OpInferBase {
auto prim_name = prim->name();
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 1, prim_name);
auto x_dtype = input_args[0]->BuildType();
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64, kComplex64, kComplex128};
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_dtype, valid_types, prim->name());
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_dtype, common_valid_types_with_complex_and_bool,
prim->name());
return x_dtype;
}
};

View File

@ -79,8 +79,8 @@ abstract::TupleShapePtr TopKInferShape(const PrimitivePtr &primitive, const std:
TuplePtr TopKInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = primitive->name();
auto output0_type = input_args[kInputIndex0]->BuildType();
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kInt32};
(void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", output0_type, valid_types, prim_name);
// const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kInt32};
(void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", output0_type, common_valid_types, prim_name);
auto k_type = input_args[kInputIndex1]->BuildType();
const std::set<TypePtr> int_types = {kInt8, kInt16, kInt32, kInt64};
(void)CheckAndConvertUtils::CheckTypeValid("k", k_type, int_types, prim_name);

View File

@ -983,13 +983,15 @@ class Sigmoid(Cell):
Sigmoid_function#/media/File:Logistic-curve.svg>`_.
Inputs:
- **input_x** (Tensor) - The input of Sigmoid with data type of float16 or float32. Tensor of any dimension.
- **input_x** (Tensor) - Tensor of any dimension, the data type is
float16, float32, float64, complex64 or complex128.
Outputs:
Tensor, with the same type and shape as the `input_x`.
Raises:
TypeError: If dtype of `input_x` is neither float16 nor float32.
TypeError: If dtype of `input_x` is not float16, float32, float64, complex64 or complex128.
TypeError: If `input_x` is not a Tensor.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

View File

@ -543,7 +543,6 @@ def matrix_band_part(x, lower, upper):
Args:
x (Tensor): Input tensor. :math:`(*, m, n)` where :math:`*` means, any number of additional dimensions.
The data type must be float16, float32, float64, int32 or int64.
lower (Union[int, Tensor]): Number of subdiagonals to keep. The data type must be int32 or int64.
If negative, keep entire lower triangle.
upper (Union[int, Tensor]): Number of superdiagonals to keep. The data type must be int32 or int64.
@ -554,7 +553,7 @@ def matrix_band_part(x, lower, upper):
Raises:
TypeError: If `x` is not a Tensor.
TypeError: If dtype of `x` is not one of float16, float32, float64, int32 or int64.
TypeError: If dtype of `x` is not valid.
TypeError: If `lower` is neither a number nor a Tensor.
TypeError: If `upper` is neither a number nor a Tensor.
TypeError: If dtype of `lower` is neither int32 nor int64.

View File

@ -1592,7 +1592,6 @@ class MatrixBandPart(Primitive):
Inputs:
- **x** (Tensor) - Input tensor. :math:`(*, m, n)` where :math:`*` means, any number of additional dimensions.
The data type must be float16, float32, float64, int32 or int64.
- **lower** (Union[int, Tensor]) - Number of subdiagonals to keep. The data type must be int32 or int64.
If negative, keep entire lower triangle.
- **upper** (Union[int, Tensor]) - Number of superdiagonals to keep. The data type must be int32 or int64.
@ -3245,7 +3244,7 @@ class ReverseV2(Primitive):
axis (Union[tuple(int), list(int)]): The indices of the dimensions to reverse.
Inputs:
- **input_x** (Tensor) - The target tensor. The data type is Number except float64.
- **input_x** (Tensor) - The target tensor.
The shape is :math:`(N, *)` where :math:`*` means, any number of additional dimensions.
Outputs:
@ -8669,7 +8668,7 @@ class TopK(Primitive):
If ``False`` , the obtained elements will not be sorted. Default: ``True`` .
Inputs:
- **input_x** (Tensor) - Input to be computed, data type must be float16, float32 or int32 on CPU,
- **input_x** (Tensor) - Input to be computed, data type can be Number on CPU,
and float16 or float32 on GPU.
- **k** (int) - The number of top elements to be computed along the last dimension, constant input is needed.
@ -8683,7 +8682,7 @@ class TopK(Primitive):
TypeError: If `sorted` is not a bool.
TypeError: If `input_x` is not a Tensor.
TypeError: If `k` is not an int.
TypeError: If dtype of `input_x` is not one of the following: float16, float32 or int32.
TypeError: If dtype of `input_x` is not supported.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

View File

@ -883,7 +883,8 @@ class Sigmoid(Primitive):
Refer to :func:`mindspore.ops.sigmoid` for more details.
Inputs:
- **input_x** (Tensor) - Tensor of any dimension, the data type is float16 or float32.
- **input_x** (Tensor) - Tensor of any dimension, the data type is
float16, float32, float64, complex64 or complex128.
Outputs:
Tensor, with the same type and shape as the input_x.