forked from mindspore-Ecosystem/mindspore
!47514 fix some bugs of kernel object selection
Merge pull request !47514 from wYann/test_kobj
This commit is contained in:
commit
e29c9da571
|
@ -1278,10 +1278,7 @@ std::vector<KernelObjectType> CalKernelObjectTypes(const std::vector<TypeId> &ob
|
|||
for (size_t i = 0; i < selected_object_types.size(); ++i) {
|
||||
// Allsame/skip_check doesn't support the backoff.
|
||||
bool not_backoff = ((all_same || skip_check) && (selected_object_types[i] != object_types[i]));
|
||||
// Ops which support tensor also support scalar.
|
||||
bool scalar_compact =
|
||||
((selected_object_types[i] == kObjectTypeTensorType) && (object_types[i] == kObjectTypeNumber));
|
||||
if (not_backoff || scalar_compact) {
|
||||
if (not_backoff) {
|
||||
(void)ret.emplace_back(TypeIdToKernelObjectTypeForTupleUnfold(object_types[i]));
|
||||
} else {
|
||||
(void)ret.emplace_back(TypeIdToKernelObjectType(selected_object_types[i]));
|
||||
|
|
|
@ -520,6 +520,12 @@ bool CastCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vec
|
|||
return true;
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> CastCpuKernelMod::GetOpSupport() {
|
||||
static std::vector<KernelAttr> support_list;
|
||||
(void)std::transform(kernel_attr_lists.begin(), kernel_attr_lists.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, CastCpuKernelFuncCreator> &pair) { return pair.first; });
|
||||
return support_list;
|
||||
}
|
||||
// NOTE(review): presumably registers CastCpuKernelMod under the op name "Cast" in the NativeCpuKernelMod
// factory; the macro is defined outside this view, so confirm against its definition.
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Cast, CastCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -48,6 +48,8 @@ class CastCpuKernelMod : public NativeCpuKernelMod {
|
|||
return kernel_func_->RunFunc(inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
TypeId source_dtype_{kTypeUnknown};
|
||||
TypeId target_dtype_{kTypeUnknown};
|
||||
|
|
|
@ -122,27 +122,13 @@ bool EighCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const
|
|||
|
||||
std::vector<std::tuple<KernelAttr, EighCpuKernelMod::EighFunc, EighCpuKernelMod::EighInitFunc>>
|
||||
EighCpuKernelMod::func_list_ = {
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
&EighCpuKernelMod::LaunchKernel<float>, &EighCpuKernelMod::InitIOFunc<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
&EighCpuKernelMod::LaunchKernel<double>, &EighCpuKernelMod::InitIOFunc<double>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
|
||||
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
|
||||
&EighCpuKernelMod::LaunchKernel<float_complex>, &EighCpuKernelMod::InitIOFunc<float_complex>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
|
||||
&EighCpuKernelMod::LaunchKernel<double_complex>, &EighCpuKernelMod::InitIOFunc<double_complex>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
&EighCpuKernelMod::LaunchKernel<float>, &EighCpuKernelMod::InitIOFunc<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
&EighCpuKernelMod::LaunchKernel<double>, &EighCpuKernelMod::InitIOFunc<double>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeComplex64)
|
||||
.AddOutputAttr(kNumberTypeComplex64)
|
||||
.AddOutputAttr(kNumberTypeComplex64),
|
||||
&EighCpuKernelMod::LaunchKernel<float_complex>, &EighCpuKernelMod::InitIOFunc<float_complex>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeComplex128)
|
||||
.AddOutputAttr(kNumberTypeComplex128)
|
||||
.AddOutputAttr(kNumberTypeComplex128),
|
||||
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
|
||||
&EighCpuKernelMod::LaunchKernel<double_complex>, &EighCpuKernelMod::InitIOFunc<double_complex>}};
|
||||
|
||||
std::vector<KernelAttr> EighCpuKernelMod::GetOpSupport() {
|
||||
|
|
|
@ -597,7 +597,7 @@ bool GetSelectKernelObjectTypeResult(const CNodePtr &kernel_node, KernelType ker
|
|||
!common::AnfAlgo::IsGraphKernel(kernel_node));
|
||||
std::vector<kernel::KernelAttr> kernel_attrs;
|
||||
if (kernel::NativeGpuKernelModFactory::GetInstance().IsRegistered(kernel_name)) {
|
||||
kernel_attrs = kernel::NativeGpuKernelMod::GetGpuSupportedList(kernel_name);
|
||||
kernel_attrs = kernel::NativeGpuKernelModFactory::GetInstance().GetGpuSupportedList(kernel_name);
|
||||
} else if (backoff_support_condition) {
|
||||
// A kernel that is not supported can try to back off to CPU and use the CPU kernel attrs to set the object type.
|
||||
kernel_attrs = kernel::NativeCpuKernelMod::GetCpuSupportedList(kernel_name);
|
||||
|
|
|
@ -18,21 +18,11 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(Eigh, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
Eigh, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
|
||||
EighcGpuKernelMod, Complex<float>)
|
||||
MS_REG_GPU_KERNEL_ONE(Eigh, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
Eigh, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
|
||||
EighcGpuKernelMod, Complex<double>)
|
||||
MS_REG_GPU_KERNEL_ONE(Eigh,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeComplex64)
|
||||
.AddOutputAttr(kNumberTypeComplex64)
|
||||
.AddOutputAttr(kNumberTypeComplex64),
|
||||
EighcGpuKernelMod, Complex<float>)
|
||||
MS_REG_GPU_KERNEL_ONE(Eigh,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeComplex128)
|
||||
.AddOutputAttr(kNumberTypeComplex128)
|
||||
.AddOutputAttr(kNumberTypeComplex128),
|
||||
EighcGpuKernelMod, Complex<double>);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -18,17 +18,11 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(Eigh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
Eigh, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
EighGpuKernelMod, float)
|
||||
MS_REG_GPU_KERNEL_ONE(Eigh, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
Eigh, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
EighGpuKernelMod, double)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
Eigh,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
EighGpuKernelMod, float)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
Eigh,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
EighGpuKernelMod, double);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue