From 8987b3dbd1293131011089d89a51f015f59d7609 Mon Sep 17 00:00:00 2001 From: yanglf1121 Date: Fri, 25 Dec 2020 17:02:19 +0800 Subject: [PATCH] Add full support for cpu cast op --- .../kernel_compiler/cpu/cast_cpu_kernel.cc | 148 +++++- .../kernel_compiler/cpu/cast_cpu_kernel.h | 149 +++++- tests/st/ops/cpu/test_cast_op.py | 491 +++++++++--------- 3 files changed, 508 insertions(+), 280 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cast_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/cast_cpu_kernel.cc index acf0a02bf83..b0226fd40b8 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/cast_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cast_cpu_kernel.cc @@ -73,40 +73,162 @@ bool CastCPUKernel::Launch(const std::vector &inputs, using TypePair = std::function &, const std::vector &)>; std::map> mode_map; - mode_map[kNumberTypeFloat32][kNumberTypeFloat32] = LaunchCast; - mode_map[kNumberTypeFloat32][kNumberTypeInt32] = LaunchCast; - mode_map[kNumberTypeFloat32][kNumberTypeBool] = LaunchCast; - mode_map[kNumberTypeInt32][kNumberTypeFloat32] = LaunchCast; - mode_map[kNumberTypeInt32][kNumberTypeInt32] = LaunchCast; - mode_map[kNumberTypeInt32][kNumberTypeBool] = LaunchCast; + mode_map[kNumberTypeBool][kNumberTypeFloat16] = LaunchCast; mode_map[kNumberTypeBool][kNumberTypeFloat32] = LaunchCast; + mode_map[kNumberTypeBool][kNumberTypeFloat64] = LaunchCast; + mode_map[kNumberTypeBool][kNumberTypeInt8] = LaunchCast; + mode_map[kNumberTypeBool][kNumberTypeInt16] = LaunchCast; + mode_map[kNumberTypeBool][kNumberTypeInt32] = LaunchCast; + mode_map[kNumberTypeBool][kNumberTypeInt64] = LaunchCast; + mode_map[kNumberTypeBool][kNumberTypeUInt8] = LaunchCast; + mode_map[kNumberTypeBool][kNumberTypeUInt16] = LaunchCast; + mode_map[kNumberTypeBool][kNumberTypeUInt32] = LaunchCast; + mode_map[kNumberTypeBool][kNumberTypeUInt64] = LaunchCast; mode_map[kNumberTypeBool][kNumberTypeBool] = LaunchCast; - mode_map[kNumberTypeBool][kNumberTypeInt32] = LaunchCast; + mode_map[kNumberTypeFloat16][kNumberTypeFloat16] = LaunchCast; + mode_map[kNumberTypeFloat16][kNumberTypeFloat32] = LaunchCast; + mode_map[kNumberTypeFloat16][kNumberTypeFloat64] = LaunchCast; + mode_map[kNumberTypeFloat16][kNumberTypeInt8] = LaunchCast; + mode_map[kNumberTypeFloat16][kNumberTypeInt16] = LaunchCast; + mode_map[kNumberTypeFloat16][kNumberTypeInt32] = LaunchCast; + mode_map[kNumberTypeFloat16][kNumberTypeInt64] = LaunchCast; + mode_map[kNumberTypeFloat16][kNumberTypeUInt8] = LaunchCast; + mode_map[kNumberTypeFloat16][kNumberTypeUInt16] = LaunchCast; + mode_map[kNumberTypeFloat16][kNumberTypeUInt32] = LaunchCast; + mode_map[kNumberTypeFloat16][kNumberTypeUInt64] = LaunchCast; + mode_map[kNumberTypeFloat16][kNumberTypeBool] = LaunchCast; + + mode_map[kNumberTypeFloat32][kNumberTypeFloat16] = LaunchCast; + mode_map[kNumberTypeFloat32][kNumberTypeFloat32] = LaunchCast; + mode_map[kNumberTypeFloat32][kNumberTypeFloat64] = LaunchCast; + mode_map[kNumberTypeFloat32][kNumberTypeInt8] = LaunchCast; + mode_map[kNumberTypeFloat32][kNumberTypeInt16] = LaunchCast; + mode_map[kNumberTypeFloat32][kNumberTypeInt32] = LaunchCast; + mode_map[kNumberTypeFloat32][kNumberTypeInt64] = LaunchCast; + mode_map[kNumberTypeFloat32][kNumberTypeUInt8] = LaunchCast; + mode_map[kNumberTypeFloat32][kNumberTypeUInt16] = LaunchCast; + mode_map[kNumberTypeFloat32][kNumberTypeUInt32] = LaunchCast; + mode_map[kNumberTypeFloat32][kNumberTypeUInt64] = LaunchCast; + mode_map[kNumberTypeFloat32][kNumberTypeBool] = LaunchCast; + + mode_map[kNumberTypeFloat64][kNumberTypeFloat16] = LaunchCast; + mode_map[kNumberTypeFloat64][kNumberTypeFloat32] = LaunchCast; + mode_map[kNumberTypeFloat64][kNumberTypeFloat64] = LaunchCast; + mode_map[kNumberTypeFloat64][kNumberTypeInt8] = LaunchCast; + mode_map[kNumberTypeFloat64][kNumberTypeInt16] = LaunchCast; + mode_map[kNumberTypeFloat64][kNumberTypeInt32] = LaunchCast; + mode_map[kNumberTypeFloat64][kNumberTypeInt64] = LaunchCast; + mode_map[kNumberTypeFloat64][kNumberTypeUInt8] = LaunchCast; + mode_map[kNumberTypeFloat64][kNumberTypeUInt16] = LaunchCast; + mode_map[kNumberTypeFloat64][kNumberTypeUInt32] = LaunchCast; + mode_map[kNumberTypeFloat64][kNumberTypeUInt64] = LaunchCast; + mode_map[kNumberTypeFloat64][kNumberTypeBool] = LaunchCast; + + mode_map[kNumberTypeInt8][kNumberTypeFloat16] = LaunchCast; + mode_map[kNumberTypeInt8][kNumberTypeFloat32] = LaunchCast; + mode_map[kNumberTypeInt8][kNumberTypeFloat64] = LaunchCast; + mode_map[kNumberTypeInt8][kNumberTypeInt8] = LaunchCast; mode_map[kNumberTypeInt8][kNumberTypeInt16] = LaunchCast; mode_map[kNumberTypeInt8][kNumberTypeInt32] = LaunchCast; mode_map[kNumberTypeInt8][kNumberTypeInt64] = LaunchCast; + mode_map[kNumberTypeInt8][kNumberTypeUInt8] = LaunchCast; + mode_map[kNumberTypeInt8][kNumberTypeUInt16] = LaunchCast; + mode_map[kNumberTypeInt8][kNumberTypeUInt32] = LaunchCast; + mode_map[kNumberTypeInt8][kNumberTypeUInt64] = LaunchCast; + mode_map[kNumberTypeInt8][kNumberTypeBool] = LaunchCast; + + mode_map[kNumberTypeInt16][kNumberTypeFloat16] = LaunchCast; + mode_map[kNumberTypeInt16][kNumberTypeFloat32] = LaunchCast; + mode_map[kNumberTypeInt16][kNumberTypeFloat64] = LaunchCast; + mode_map[kNumberTypeInt16][kNumberTypeInt8] = LaunchCast; + mode_map[kNumberTypeInt16][kNumberTypeInt16] = LaunchCast; + mode_map[kNumberTypeInt16][kNumberTypeInt32] = LaunchCast; + mode_map[kNumberTypeInt16][kNumberTypeInt64] = LaunchCast; + mode_map[kNumberTypeInt16][kNumberTypeUInt8] = LaunchCast; + mode_map[kNumberTypeInt16][kNumberTypeUInt16] = LaunchCast; + mode_map[kNumberTypeInt16][kNumberTypeUInt32] = LaunchCast; + mode_map[kNumberTypeInt16][kNumberTypeUInt64] = LaunchCast; + mode_map[kNumberTypeInt16][kNumberTypeBool] = LaunchCast; + + mode_map[kNumberTypeInt32][kNumberTypeFloat16] = LaunchCast; + mode_map[kNumberTypeInt32][kNumberTypeFloat32] = LaunchCast; + mode_map[kNumberTypeInt32][kNumberTypeFloat64] = LaunchCast; + mode_map[kNumberTypeInt32][kNumberTypeInt8] = LaunchCast; + mode_map[kNumberTypeInt32][kNumberTypeInt16] = LaunchCast; + mode_map[kNumberTypeInt32][kNumberTypeInt32] = LaunchCast; + mode_map[kNumberTypeInt32][kNumberTypeInt64] = LaunchCast; + mode_map[kNumberTypeInt32][kNumberTypeUInt8] = LaunchCast; + mode_map[kNumberTypeInt32][kNumberTypeUInt16] = LaunchCast; + mode_map[kNumberTypeInt32][kNumberTypeUInt32] = LaunchCast; + mode_map[kNumberTypeInt32][kNumberTypeUInt64] = LaunchCast; + mode_map[kNumberTypeInt32][kNumberTypeBool] = LaunchCast; + + mode_map[kNumberTypeInt64][kNumberTypeFloat16] = LaunchCast; + mode_map[kNumberTypeInt64][kNumberTypeFloat32] = LaunchCast; + mode_map[kNumberTypeInt64][kNumberTypeFloat64] = LaunchCast; + mode_map[kNumberTypeInt64][kNumberTypeInt8] = LaunchCast; + mode_map[kNumberTypeInt64][kNumberTypeInt16] = LaunchCast; + mode_map[kNumberTypeInt64][kNumberTypeInt32] = LaunchCast; + mode_map[kNumberTypeInt64][kNumberTypeInt64] = LaunchCast; + mode_map[kNumberTypeInt64][kNumberTypeUInt8] = LaunchCast; + mode_map[kNumberTypeInt64][kNumberTypeUInt16] = LaunchCast; + mode_map[kNumberTypeInt64][kNumberTypeUInt32] = LaunchCast; + mode_map[kNumberTypeInt64][kNumberTypeUInt64] = LaunchCast; + mode_map[kNumberTypeInt64][kNumberTypeBool] = LaunchCast; + + mode_map[kNumberTypeUInt8][kNumberTypeFloat16] = LaunchCast; + mode_map[kNumberTypeUInt8][kNumberTypeFloat32] = LaunchCast; + mode_map[kNumberTypeUInt8][kNumberTypeFloat64] = LaunchCast; + mode_map[kNumberTypeUInt8][kNumberTypeInt8] = LaunchCast; mode_map[kNumberTypeUInt8][kNumberTypeInt16] = LaunchCast; mode_map[kNumberTypeUInt8][kNumberTypeInt32] = LaunchCast; mode_map[kNumberTypeUInt8][kNumberTypeInt64] = LaunchCast; + mode_map[kNumberTypeUInt8][kNumberTypeUInt8] = LaunchCast; mode_map[kNumberTypeUInt8][kNumberTypeUInt16] = LaunchCast; mode_map[kNumberTypeUInt8][kNumberTypeUInt32] = LaunchCast; mode_map[kNumberTypeUInt8][kNumberTypeUInt64] = LaunchCast; + mode_map[kNumberTypeUInt8][kNumberTypeBool] = LaunchCast; - mode_map[kNumberTypeInt16][kNumberTypeInt32] = LaunchCast; - mode_map[kNumberTypeInt16][kNumberTypeInt64] = LaunchCast; + mode_map[kNumberTypeUInt16][kNumberTypeFloat16] = LaunchCast; + mode_map[kNumberTypeUInt16][kNumberTypeFloat32] = LaunchCast; + mode_map[kNumberTypeUInt16][kNumberTypeFloat64] = LaunchCast; + mode_map[kNumberTypeUInt16][kNumberTypeInt8] = LaunchCast; + mode_map[kNumberTypeUInt16][kNumberTypeInt16] = LaunchCast; mode_map[kNumberTypeUInt16][kNumberTypeInt32] = LaunchCast; mode_map[kNumberTypeUInt16][kNumberTypeInt64] = LaunchCast; + mode_map[kNumberTypeUInt16][kNumberTypeUInt8] = LaunchCast; + mode_map[kNumberTypeUInt16][kNumberTypeUInt16] = LaunchCast; mode_map[kNumberTypeUInt16][kNumberTypeUInt32] = LaunchCast; mode_map[kNumberTypeUInt16][kNumberTypeUInt64] = LaunchCast; + mode_map[kNumberTypeUInt16][kNumberTypeBool] = LaunchCast; - mode_map[kNumberTypeInt32][kNumberTypeInt64] = LaunchCast; + mode_map[kNumberTypeUInt32][kNumberTypeFloat16] = LaunchCast; + mode_map[kNumberTypeUInt32][kNumberTypeFloat32] = LaunchCast; + mode_map[kNumberTypeUInt32][kNumberTypeFloat64] = LaunchCast; + mode_map[kNumberTypeUInt32][kNumberTypeInt8] = LaunchCast; + mode_map[kNumberTypeUInt32][kNumberTypeInt16] = LaunchCast; + mode_map[kNumberTypeUInt32][kNumberTypeInt32] = LaunchCast; mode_map[kNumberTypeUInt32][kNumberTypeInt64] = LaunchCast; + mode_map[kNumberTypeUInt32][kNumberTypeUInt8] = LaunchCast; + mode_map[kNumberTypeUInt32][kNumberTypeUInt16] = LaunchCast; + mode_map[kNumberTypeUInt32][kNumberTypeUInt32] = LaunchCast; mode_map[kNumberTypeUInt32][kNumberTypeUInt64] = LaunchCast; + mode_map[kNumberTypeUInt32][kNumberTypeBool] = LaunchCast; + + mode_map[kNumberTypeUInt64][kNumberTypeFloat16] = LaunchCast; + mode_map[kNumberTypeUInt64][kNumberTypeFloat32] = LaunchCast; + mode_map[kNumberTypeUInt64][kNumberTypeFloat64] = LaunchCast; + mode_map[kNumberTypeUInt64][kNumberTypeInt8] = LaunchCast; + mode_map[kNumberTypeUInt64][kNumberTypeInt16] = LaunchCast; + mode_map[kNumberTypeUInt64][kNumberTypeInt32] = LaunchCast; + mode_map[kNumberTypeUInt64][kNumberTypeInt64] = LaunchCast; + mode_map[kNumberTypeUInt64][kNumberTypeUInt8] = LaunchCast; + mode_map[kNumberTypeUInt64][kNumberTypeUInt16] = LaunchCast; + mode_map[kNumberTypeUInt64][kNumberTypeUInt32] = LaunchCast; + mode_map[kNumberTypeUInt64][kNumberTypeUInt64] = LaunchCast; + mode_map[kNumberTypeUInt64][kNumberTypeBool] = LaunchCast; - mode_map[kNumberTypeFloat16][kNumberTypeFloat32] = LaunchCast; - mode_map[kNumberTypeFloat16][kNumberTypeFloat64] = LaunchCast; - mode_map[kNumberTypeFloat32][kNumberTypeFloat64] = LaunchCast; mode_map[source_dtype][target_dtype](inputs, outputs); return true; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cast_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/cast_cpu_kernel.h index 157400ad5e0..7c68e8f6122 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/cast_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cast_cpu_kernel.h @@ -38,40 +38,161 @@ class CastCPUKernel : public CPUKernel { TypeId target_dtype{kTypeUnknown}; }; -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel); MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt8), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt16), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt64), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt8), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt16), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt64), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), CastCPUKernel); + +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt8), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt16), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), CastCPUKernel); + +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt8), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt16), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt32), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt64), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool), CastCPUKernel); + +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), CastCPUKernel); MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt16), CastCPUKernel); MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt32), CastCPUKernel); MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt64), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool), CastCPUKernel); + +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt8), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt32), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt64), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool), CastCPUKernel); + +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), CastCPUKernel); + +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool), CastCPUKernel); + +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt8), CastCPUKernel); MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt16), CastCPUKernel); MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt32), CastCPUKernel); MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt64), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel); MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel); MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel); MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt32), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt64), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt8), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt16), CastCPUKernel); MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt32), CastCPUKernel); MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt64), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel); MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel); MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeBool), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt8), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt16), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt32), CastCPUKernel); MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel); MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeBool), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel); -MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt8), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt16), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt32), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt64), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel); +MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeBool), CastCPUKernel); } // namespace kernel } // namespace mindspore diff --git a/tests/st/ops/cpu/test_cast_op.py b/tests/st/ops/cpu/test_cast_op.py index 13c36ee0df8..b4927b3ddca 100644 --- a/tests/st/ops/cpu/test_cast_op.py +++ b/tests/st/ops/cpu/test_cast_op.py @@ -36,314 +36,299 @@ class Net(Cell): @pytest.mark.level0 @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard -def test_cast_int32(): - x0 = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float32)) - x1 = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int32)) - x2 = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.bool)) - t = mstype.int32 +def test_cast_bool(): + tensor_to_cast = [] + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.bool))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float16))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float32))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float64))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int8))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int16))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int32))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int64))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint8))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint16))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint32))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint64))) + t = mstype.bool_ context.set_context(mode=context.GRAPH_MODE, device_target='CPU') - net = Net(t) - output = net(x0) - type0 = output.asnumpy().dtype - assert type0 == 'int32' - output = net(x1) - type1 = output.asnumpy().dtype - assert type1 == 'int32' - output = net(x2) - type2 = output.asnumpy().dtype - assert type2 == 'int32' + for tensor in tensor_to_cast: + net = Net(t) + output = net(tensor) + assert output.asnumpy().dtype == 'bool' + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_cast_float16(): + tensor_to_cast = [] + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.bool))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float16))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float32))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float64))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int8))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int16))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int32))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int64))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint8))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint16))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint32))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint64))) + t = mstype.float16 + + context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + for tensor in tensor_to_cast: + net = Net(t) + output = net(tensor) + assert output.asnumpy().dtype == 'float16' @pytest.mark.level0 @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard def test_cast_float32(): - x0 = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float32)) - x1 = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int32)) - x2 = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.bool)) + tensor_to_cast = [] + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.bool))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float16))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float32))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float64))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int8))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int16))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int32))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int64))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint8))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint16))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint32))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint64))) t = mstype.float32 context.set_context(mode=context.GRAPH_MODE, device_target='CPU') - net = Net(t) - output = net(x0) - type0 = output.asnumpy().dtype - assert type0 == 'float32' - output = net(x1) - type1 = output.asnumpy().dtype - assert type1 == 'float32' - output = net(x2) - type2 = output.asnumpy().dtype - assert type2 == 'float32' + for tensor in tensor_to_cast: + net = Net(t) + output = net(tensor) + assert output.asnumpy().dtype == 'float32' @pytest.mark.level0 @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard -def test_cast_int8_to_int16(): - x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int8)) +def test_cast_float64(): + tensor_to_cast = [] + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.bool))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float16))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float32))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float64))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int8))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int16))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int32))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int64))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint8))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint16))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint32))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint64))) + t = mstype.float64 + + context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + for tensor in tensor_to_cast: + net = Net(t) + output = net(tensor) + assert output.asnumpy().dtype == 'float64' + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_cast_int8(): + tensor_to_cast = [] + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.bool))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float16))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float32))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float64))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int8))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int16))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int32))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int64))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint8))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint16))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint32))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint64))) + t = mstype.int8 + + context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + for tensor in tensor_to_cast: + net = Net(t) + output = net(tensor) + assert output.asnumpy().dtype == 'int8' + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_cast_int16(): + tensor_to_cast = [] + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.bool))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float16))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float32))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float64))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int8))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int16))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int32))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int64))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint8))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint16))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint32))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint64))) t = mstype.int16 context.set_context(mode=context.GRAPH_MODE, device_target='CPU') - net = Net(t) - output = net(x) - dtype = output.asnumpy().dtype - assert dtype == 'int16' + for tensor in tensor_to_cast: + net = Net(t) + output = net(tensor) + assert output.asnumpy().dtype == 'int16' @pytest.mark.level0 @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard -def test_cast_int8_to_int32(): - x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int8)) +def test_cast_int32(): + tensor_to_cast = [] + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.bool))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float16))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float32))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float64))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int8))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int16))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int32))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int64))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint8))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint16))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint32))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint64))) t = mstype.int32 context.set_context(mode=context.GRAPH_MODE, device_target='CPU') - net = Net(t) - output = net(x) - dtype = output.asnumpy().dtype - assert dtype == 'int32' + for tensor in tensor_to_cast: + net = Net(t) + output = net(tensor) + assert output.asnumpy().dtype == 'int32' @pytest.mark.level0 @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard -def test_cast_int8_to_int64(): - x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int8)) +def test_cast_int64(): + tensor_to_cast = [] + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.bool))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float16))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float32))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float64))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int8))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int16))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int32))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int64))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint8))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint16))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint32))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint64))) t = mstype.int64 context.set_context(mode=context.GRAPH_MODE, device_target='CPU') - net = Net(t) - output = net(x) - dtype = output.asnumpy().dtype - assert dtype == 'int64' + for tensor in tensor_to_cast: + net = Net(t) + output = net(tensor) + assert output.asnumpy().dtype == 'int64' @pytest.mark.level0 @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard -def test_cast_uint8_to_int16(): - x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint8)) - t = mstype.int16 +def test_cast_uint8(): + tensor_to_cast = [] + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.bool))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float16))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float32))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float64))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int8))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int16))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int32))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int64))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint8))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint16))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint32))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint64))) + t = mstype.uint8 context.set_context(mode=context.GRAPH_MODE, device_target='CPU') - net = Net(t) - output = net(x) - dtype = output.asnumpy().dtype - assert dtype == 'int16' + for tensor in tensor_to_cast: + net = Net(t) + output = net(tensor) + assert output.asnumpy().dtype == 'uint8' @pytest.mark.level0 @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard -def test_cast_uint8_to_int32(): - x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint8)) - t = mstype.int32 - - context.set_context(mode=context.GRAPH_MODE, device_target='CPU') - net = Net(t) - output = net(x) - dtype = output.asnumpy().dtype - assert dtype == 'int32' - -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.env_onecard -def test_cast_uint8_to_int64(): - x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint8)) - t = mstype.int64 - - context.set_context(mode=context.GRAPH_MODE, device_target='CPU') - net = Net(t) - output = net(x) - dtype = output.asnumpy().dtype - assert dtype == 'int64' - -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.env_onecard -def test_cast_uint8_to_uint16(): - x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint8)) +def test_cast_uint16(): + tensor_to_cast = [] + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.bool))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float16))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float32))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float64))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int8))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int16))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int32))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int64))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint8))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint16))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint32))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint64))) t = mstype.uint16 context.set_context(mode=context.GRAPH_MODE, device_target='CPU') - net = Net(t) - output = net(x) - dtype = output.asnumpy().dtype - assert dtype == 'uint16' + for tensor in tensor_to_cast: + net = Net(t) + output = net(tensor) + assert output.asnumpy().dtype == 'uint16' @pytest.mark.level0 @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard -def test_cast_uint8_to_uint32(): - x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint8)) +def test_cast_uint32(): + tensor_to_cast = [] + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.bool))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float16))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float32))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float64))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int8))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int16))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int32))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int64))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint8))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint16))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint32))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint64))) t = mstype.uint32 context.set_context(mode=context.GRAPH_MODE, device_target='CPU') - net = Net(t) - output = net(x) - dtype = output.asnumpy().dtype - assert dtype == 'uint32' + for tensor in tensor_to_cast: + net = Net(t) + output = net(tensor) + assert output.asnumpy().dtype == 'uint32' @pytest.mark.level0 @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard -def test_cast_uint8_to_uint64(): - x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint8)) +def test_cast_uint64(): + tensor_to_cast = [] + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.bool))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float16))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float32))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float64))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int8))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int16))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int32))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int64))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint8))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint16))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint32))) + tensor_to_cast.append(Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint64))) t = mstype.uint64 context.set_context(mode=context.GRAPH_MODE, device_target='CPU') - net = Net(t) - output = net(x) - dtype = output.asnumpy().dtype - assert dtype == 'uint64' - -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.env_onecard -def test_cast_int16_to_int32(): - x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int16)) - t = mstype.int32 - - context.set_context(mode=context.GRAPH_MODE, device_target='CPU') - net = Net(t) - output = net(x) - dtype = output.asnumpy().dtype - assert dtype == 'int32' - -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.env_onecard -def test_cast_int16_to_int64(): - x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int16)) - t = mstype.int64 - - context.set_context(mode=context.GRAPH_MODE, device_target='CPU') - net = Net(t) - output = net(x) - dtype = output.asnumpy().dtype - assert dtype == 'int64' - -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.env_onecard -def test_cast_uint16_to_int32(): - x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint16)) - t = mstype.int32 - - context.set_context(mode=context.GRAPH_MODE, device_target='CPU') - net = Net(t) - output = net(x) - dtype = output.asnumpy().dtype - assert dtype == 'int32' - -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.env_onecard -def test_cast_uint16_to_int64(): - x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint16)) - t = mstype.int64 - - context.set_context(mode=context.GRAPH_MODE, device_target='CPU') - net = Net(t) - output = net(x) - dtype = output.asnumpy().dtype - assert dtype == 'int64' - -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.env_onecard -def test_cast_uint16_to_uint32(): - x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint16)) - t = mstype.uint32 - - context.set_context(mode=context.GRAPH_MODE, device_target='CPU') - net = Net(t) - output = net(x) - dtype = output.asnumpy().dtype - assert dtype == 'uint32' - -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.env_onecard -def test_cast_uint16_to_uint64(): - x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint16)) - t = mstype.uint64 - - context.set_context(mode=context.GRAPH_MODE, device_target='CPU') - net = Net(t) - output = net(x) - dtype = output.asnumpy().dtype - assert dtype == 'uint64' - -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.env_onecard -def test_cast_int32_to_int64(): - x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int32)) - t = mstype.int64 - - context.set_context(mode=context.GRAPH_MODE, device_target='CPU') - net = Net(t) - output = net(x) - dtype = output.asnumpy().dtype - assert dtype == 'int64' - -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.env_onecard -def test_cast_uint32_to_int64(): - x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint32)) - t = mstype.int64 - - context.set_context(mode=context.GRAPH_MODE, device_target='CPU') - net = Net(t) - output = net(x) - dtype = output.asnumpy().dtype - assert dtype == 'int64' - -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.env_onecard -def test_cast_uint32_to_uint64(): - x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint32)) - t = mstype.uint64 - - context.set_context(mode=context.GRAPH_MODE, device_target='CPU') - net = Net(t) - output = net(x) - dtype = output.asnumpy().dtype - assert dtype == 'uint64' - -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.env_onecard -def test_cast_float16_to_float32(): - x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float16)) - t = mstype.float32 - - context.set_context(mode=context.GRAPH_MODE, device_target='CPU') - net = Net(t) - output = net(x) - dtype = output.asnumpy().dtype - assert dtype == 'float32' - -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.env_onecard -def test_cast_float16_to_float64(): - x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float16)) - t = mstype.float64 - - context.set_context(mode=context.GRAPH_MODE, device_target='CPU') - net = Net(t) - output = net(x) - dtype = output.asnumpy().dtype - assert dtype == 'float64' - -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.env_onecard -def test_cast_float32_to_float64(): - x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float32)) - t = mstype.float64 - - context.set_context(mode=context.GRAPH_MODE, device_target='CPU') - net = Net(t) - output = net(x) - dtype = output.asnumpy().dtype - assert dtype == 'float64' + for tensor in tensor_to_cast: + net = Net(t) + output = net(tensor) + assert output.asnumpy().dtype == 'uint64'