!47514 fix some bugs of kernel object selection

Merge pull request !47514 from wYann/test_kobj
This commit is contained in:
i-robot 2023-02-02 06:23:45 +00:00 committed by Gitee
commit e29c9da571
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
7 changed files with 23 additions and 48 deletions

View File

@ -1278,10 +1278,7 @@ std::vector<KernelObjectType> CalKernelObjectTypes(const std::vector<TypeId> &ob
for (size_t i = 0; i < selected_object_types.size(); ++i) {
// Allsame/skip_check doesn't support the backoff.
bool not_backoff = ((all_same || skip_check) && (selected_object_types[i] != object_types[i]));
// Ops which support tensor also support scalar.
bool scalar_compact =
((selected_object_types[i] == kObjectTypeTensorType) && (object_types[i] == kObjectTypeNumber));
if (not_backoff || scalar_compact) {
if (not_backoff) {
(void)ret.emplace_back(TypeIdToKernelObjectTypeForTupleUnfold(object_types[i]));
} else {
(void)ret.emplace_back(TypeIdToKernelObjectType(selected_object_types[i]));

View File

@ -520,6 +520,12 @@ bool CastCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vec
return true;
}
std::vector<KernelAttr> CastCpuKernelMod::GetOpSupport() {
static std::vector<KernelAttr> support_list;
(void)std::transform(kernel_attr_lists.begin(), kernel_attr_lists.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, CastCpuKernelFuncCreator> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Cast, CastCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -48,6 +48,8 @@ class CastCpuKernelMod : public NativeCpuKernelMod {
return kernel_func_->RunFunc(inputs, workspace, outputs);
}
std::vector<KernelAttr> GetOpSupport() override;
private:
TypeId source_dtype_{kTypeUnknown};
TypeId target_dtype_{kTypeUnknown};

View File

@ -122,27 +122,13 @@ bool EighCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const
std::vector<std::tuple<KernelAttr, EighCpuKernelMod::EighFunc, EighCpuKernelMod::EighInitFunc>>
EighCpuKernelMod::func_list_ = {
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&EighCpuKernelMod::LaunchKernel<float>, &EighCpuKernelMod::InitIOFunc<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
&EighCpuKernelMod::LaunchKernel<double>, &EighCpuKernelMod::InitIOFunc<double>},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
&EighCpuKernelMod::LaunchKernel<float_complex>, &EighCpuKernelMod::InitIOFunc<float_complex>},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
&EighCpuKernelMod::LaunchKernel<double_complex>, &EighCpuKernelMod::InitIOFunc<double_complex>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&EighCpuKernelMod::LaunchKernel<float>, &EighCpuKernelMod::InitIOFunc<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
&EighCpuKernelMod::LaunchKernel<double>, &EighCpuKernelMod::InitIOFunc<double>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex64)
.AddOutputAttr(kNumberTypeComplex64)
.AddOutputAttr(kNumberTypeComplex64),
&EighCpuKernelMod::LaunchKernel<float_complex>, &EighCpuKernelMod::InitIOFunc<float_complex>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeComplex128),
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
&EighCpuKernelMod::LaunchKernel<double_complex>, &EighCpuKernelMod::InitIOFunc<double_complex>}};
std::vector<KernelAttr> EighCpuKernelMod::GetOpSupport() {

View File

@ -597,7 +597,7 @@ bool GetSelectKernelObjectTypeResult(const CNodePtr &kernel_node, KernelType ker
!common::AnfAlgo::IsGraphKernel(kernel_node));
std::vector<kernel::KernelAttr> kernel_attrs;
if (kernel::NativeGpuKernelModFactory::GetInstance().IsRegistered(kernel_name)) {
kernel_attrs = kernel::NativeGpuKernelMod::GetGpuSupportedList(kernel_name);
kernel_attrs = kernel::NativeGpuKernelModFactory::GetInstance().GetGpuSupportedList(kernel_name);
} else if (backoff_support_condition) {
// Kernel that is not supported can try to backed off on CPU and use the CPU kernel attrs to set object type.
kernel_attrs = kernel::NativeCpuKernelMod::GetCpuSupportedList(kernel_name);

View File

@ -18,21 +18,11 @@
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(Eigh, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
EighcGpuKernelMod, Complex<float>)
MS_REG_GPU_KERNEL_ONE(Eigh, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
EighcGpuKernelMod, Complex<double>)
MS_REG_GPU_KERNEL_ONE(Eigh,
KernelAttr()
.AddInputAttr(kNumberTypeComplex64)
.AddOutputAttr(kNumberTypeComplex64)
.AddOutputAttr(kNumberTypeComplex64),
EighcGpuKernelMod, Complex<float>)
MS_REG_GPU_KERNEL_ONE(Eigh,
KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeComplex128),
EighcGpuKernelMod, Complex<double>);
MS_REG_GPU_KERNEL_ONE(
Eigh, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
EighcGpuKernelMod, Complex<float>)
MS_REG_GPU_KERNEL_ONE(
Eigh, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
EighcGpuKernelMod, Complex<double>)
} // namespace kernel
} // namespace mindspore

View File

@ -18,17 +18,11 @@
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(Eigh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
EighGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(Eigh, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
EighGpuKernelMod, double)
MS_REG_GPU_KERNEL_ONE(
Eigh,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
Eigh, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
EighGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(
Eigh,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
EighGpuKernelMod, double);
Eigh, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
EighGpuKernelMod, double)
} // namespace kernel
} // namespace mindspore