diff --git a/mindspore/ccsrc/kernel/common_utils.cc b/mindspore/ccsrc/kernel/common_utils.cc index f8c71b149f4..19c0a922a6f 100644 --- a/mindspore/ccsrc/kernel/common_utils.cc +++ b/mindspore/ccsrc/kernel/common_utils.cc @@ -1278,10 +1278,7 @@ std::vector CalKernelObjectTypes(const std::vector &ob for (size_t i = 0; i < selected_object_types.size(); ++i) { // Allsame/skip_check doesn't support the backoff. bool not_backoff = ((all_same || skip_check) && (selected_object_types[i] != object_types[i])); - // Ops which support tensor also support scalar. - bool scalar_compact = - ((selected_object_types[i] == kObjectTypeTensorType) && (object_types[i] == kObjectTypeNumber)); - if (not_backoff || scalar_compact) { + if (not_backoff) { (void)ret.emplace_back(TypeIdToKernelObjectTypeForTupleUnfold(object_types[i])); } else { (void)ret.emplace_back(TypeIdToKernelObjectType(selected_object_types[i])); diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/cast_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/cast_cpu_kernel.cc index 8ba284b9729..4e6303b2d20 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/cast_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/cast_cpu_kernel.cc @@ -520,6 +520,12 @@ bool CastCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vec return true; } +std::vector CastCpuKernelMod::GetOpSupport() { + static std::vector support_list; + (void)std::transform(kernel_attr_lists.begin(), kernel_attr_lists.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; +} MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Cast, CastCpuKernelMod); } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/cast_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/cast_cpu_kernel.h index c02f3904199..bef9ecb89f0 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/cast_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/cast_cpu_kernel.h @@ -48,6 +48,8 @@ class CastCpuKernelMod : public NativeCpuKernelMod { return kernel_func_->RunFunc(inputs, workspace, outputs); } + std::vector GetOpSupport() override; + private: TypeId source_dtype_{kTypeUnknown}; TypeId target_dtype_{kTypeUnknown}; diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/eigen/eigh_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/eigen/eigh_cpu_kernel.cc index 0f675ea7a3a..1423c8b726c 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/eigen/eigh_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/eigen/eigh_cpu_kernel.cc @@ -122,27 +122,13 @@ bool EighCpuKernelMod::LaunchKernel(const std::vector &inputs, const std::vector> EighCpuKernelMod::func_list_ = { - {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), &EighCpuKernelMod::LaunchKernel, &EighCpuKernelMod::InitIOFunc}, - {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), &EighCpuKernelMod::LaunchKernel, &EighCpuKernelMod::InitIOFunc}, - {KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64), + {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64), &EighCpuKernelMod::LaunchKernel, &EighCpuKernelMod::InitIOFunc}, - {KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128), - &EighCpuKernelMod::LaunchKernel, &EighCpuKernelMod::InitIOFunc}, - {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - &EighCpuKernelMod::LaunchKernel, &EighCpuKernelMod::InitIOFunc}, - {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), - &EighCpuKernelMod::LaunchKernel, &EighCpuKernelMod::InitIOFunc}, - {KernelAttr() - .AddInputAttr(kNumberTypeComplex64) - .AddOutputAttr(kNumberTypeComplex64) - .AddOutputAttr(kNumberTypeComplex64), - &EighCpuKernelMod::LaunchKernel, &EighCpuKernelMod::InitIOFunc}, - {KernelAttr() - .AddInputAttr(kNumberTypeComplex128) - .AddOutputAttr(kNumberTypeComplex128) - .AddOutputAttr(kNumberTypeComplex128), + {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128), &EighCpuKernelMod::LaunchKernel, &EighCpuKernelMod::InitIOFunc}}; std::vector EighCpuKernelMod::GetOpSupport() { diff --git a/mindspore/ccsrc/plugin/device/gpu/hal/device/kernel_info_setter.cc b/mindspore/ccsrc/plugin/device/gpu/hal/device/kernel_info_setter.cc index 82e597151b7..ef8609cdef2 100644 --- a/mindspore/ccsrc/plugin/device/gpu/hal/device/kernel_info_setter.cc +++ b/mindspore/ccsrc/plugin/device/gpu/hal/device/kernel_info_setter.cc @@ -597,7 +597,7 @@ bool GetSelectKernelObjectTypeResult(const CNodePtr &kernel_node, KernelType ker !common::AnfAlgo::IsGraphKernel(kernel_node)); std::vector kernel_attrs; if (kernel::NativeGpuKernelModFactory::GetInstance().IsRegistered(kernel_name)) { - kernel_attrs = kernel::NativeGpuKernelMod::GetGpuSupportedList(kernel_name); + kernel_attrs = kernel::NativeGpuKernelModFactory::GetInstance().GetGpuSupportedList(kernel_name); } else if (backoff_support_condition) { // Kernel that is not supported can try to backed off on CPU and use the CPU kernel attrs to set object type. kernel_attrs = kernel::NativeCpuKernelMod::GetCpuSupportedList(kernel_name); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/math/eigh_c_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/math/eigh_c_gpu_kernel.cc index 3cbc1b2260e..8f6c4334434 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/math/eigh_c_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/math/eigh_c_gpu_kernel.cc @@ -18,21 +18,11 @@ namespace mindspore { namespace kernel { -MS_REG_GPU_KERNEL_ONE(Eigh, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64), - EighcGpuKernelMod, Complex) -MS_REG_GPU_KERNEL_ONE(Eigh, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128), - EighcGpuKernelMod, Complex) -MS_REG_GPU_KERNEL_ONE(Eigh, - KernelAttr() - .AddInputAttr(kNumberTypeComplex64) - .AddOutputAttr(kNumberTypeComplex64) - .AddOutputAttr(kNumberTypeComplex64), - EighcGpuKernelMod, Complex) -MS_REG_GPU_KERNEL_ONE(Eigh, - KernelAttr() - .AddInputAttr(kNumberTypeComplex128) - .AddOutputAttr(kNumberTypeComplex128) - .AddOutputAttr(kNumberTypeComplex128), - EighcGpuKernelMod, Complex); +MS_REG_GPU_KERNEL_ONE( + Eigh, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64), + EighcGpuKernelMod, Complex) +MS_REG_GPU_KERNEL_ONE( + Eigh, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128), + EighcGpuKernelMod, Complex) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/math/eigh_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/math/eigh_gpu_kernel.cc index 97dc50a3e7d..5af91ec56df 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/math/eigh_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/math/eigh_gpu_kernel.cc @@ -18,17 +18,11 @@ namespace mindspore { namespace kernel { -MS_REG_GPU_KERNEL_ONE(Eigh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - EighGpuKernelMod, float) -MS_REG_GPU_KERNEL_ONE(Eigh, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), - EighGpuKernelMod, double) MS_REG_GPU_KERNEL_ONE( - Eigh, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + Eigh, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), EighGpuKernelMod, float) MS_REG_GPU_KERNEL_ONE( - Eigh, - KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), - EighGpuKernelMod, double); + Eigh, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + EighGpuKernelMod, double) } // namespace kernel } // namespace mindspore