!35062 Refactor code style for MatchKernelHelper
Merge pull request !35062 from zhujingxuan/master
This commit is contained in:
commit
2627280d95
|
@ -315,13 +315,16 @@ inline std::map<uint32_t, tensor::TensorPtr> GetKernelDepends(const CNodePtr &cn
|
|||
template <typename Derived>
|
||||
class MatchKernelHelper {
|
||||
public:
|
||||
MatchKernelHelper() = default;
|
||||
virtual ~MatchKernelHelper() = default;
|
||||
|
||||
using KernelRunFunc = std::function<bool(Derived *, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &)>;
|
||||
virtual const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const = 0;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() {
|
||||
auto &func_list = static_cast<Derived *>(this)->GetFuncList();
|
||||
std::vector<KernelAttr> OpSupport() const {
|
||||
auto &func_list = static_cast<const Derived *>(this)->GetFuncList();
|
||||
std::vector<KernelAttr> support_list;
|
||||
(void)std::transform(func_list.begin(), func_list.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, KernelRunFunc> &pair) { return pair.first; });
|
||||
|
@ -332,7 +335,7 @@ class MatchKernelHelper {
|
|||
auto kernel_name = base_operator->name();
|
||||
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
auto &func_list = static_cast<Derived *>(this)->GetFuncList();
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, OpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(ERROR) << "The kernel '" << kernel_name << "' does not support this kernel data type: " << kernel_attr;
|
||||
return false;
|
||||
|
|
|
@ -48,7 +48,7 @@ class BCEWithLogitsLossCpuKernelMod : public NativeCpuKernelMod,
|
|||
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
|
||||
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
|
|
|
@ -54,7 +54,7 @@ class BitwiseCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper<
|
|||
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
|
||||
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
|
|
|
@ -45,7 +45,7 @@ class CTCLossV2CpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelpe
|
|||
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
|
||||
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
|
||||
|
||||
std::vector<KernelTensorPtr> GetOutputs() override { return outputs_; };
|
||||
|
||||
|
|
|
@ -45,7 +45,7 @@ class DropoutNdCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelpe
|
|||
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
|
||||
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
|
||||
|
||||
private:
|
||||
void ResetResource() noexcept;
|
||||
|
|
|
@ -48,7 +48,7 @@ class FastGeLUCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper
|
|||
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
|
||||
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
|
|
|
@ -48,7 +48,7 @@ class FastGeLUGradCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHe
|
|||
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
|
||||
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
|
|
|
@ -52,7 +52,7 @@ class GerCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper<GerC
|
|||
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
|
||||
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
|
|
|
@ -39,7 +39,7 @@ class HShrinkCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper<
|
|||
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
|
||||
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
|
|
|
@ -39,7 +39,7 @@ class HShrinkGradCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHel
|
|||
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
|
||||
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
|
|
|
@ -49,7 +49,7 @@ class IndexAddCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper
|
|||
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
|
||||
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
|
||||
|
||||
private:
|
||||
void CheckParams();
|
||||
|
|
|
@ -44,7 +44,7 @@ class IsCloseCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper<
|
|||
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
|
||||
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
|
|
|
@ -47,7 +47,7 @@ class LerpCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper<Ler
|
|||
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
|
||||
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
|
|
|
@ -45,7 +45,7 @@ class LinSpaceCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper
|
|||
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
|
||||
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
|
|
|
@ -43,7 +43,7 @@ class LowerBoundCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelp
|
|||
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
|
||||
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
|
||||
|
||||
private:
|
||||
template <typename I, typename O>
|
||||
|
|
|
@ -45,7 +45,7 @@ class MatrixBandPartCpuKernelMod : public NativeCpuKernelMod, public MatchKernel
|
|||
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
|
||||
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
|
||||
|
||||
private:
|
||||
template <typename T, typename LU>
|
||||
|
|
|
@ -46,7 +46,7 @@ class PaddingCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper<
|
|||
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
|
||||
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
|
|
|
@ -50,7 +50,7 @@ class ReLUV2CpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper<R
|
|||
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
|
||||
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
|
|
|
@ -45,7 +45,7 @@ class ReLUV3CpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper<R
|
|||
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
|
||||
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
|
|
|
@ -49,7 +49,7 @@ class ScatterNdArithmeticCpuKernelMod : public NativeCpuKernelMod,
|
|||
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
|
||||
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
|
||||
|
||||
private:
|
||||
template <typename T, typename S>
|
||||
|
|
|
@ -45,7 +45,7 @@ class SeluCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper<Sel
|
|||
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
|
||||
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
|
|
|
@ -47,7 +47,7 @@ class SoftShrinkCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelp
|
|||
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
|
||||
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
|
|
|
@ -47,7 +47,7 @@ class SoftShrinkGradCpuKernelMod : public NativeCpuKernelMod, public MatchKernel
|
|||
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
|
||||
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
|
|
|
@ -44,7 +44,7 @@ class SparseAddCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelpe
|
|||
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
|
||||
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
|
||||
|
||||
private:
|
||||
template <typename T, typename S>
|
||||
|
|
|
@ -45,7 +45,7 @@ class SparseMatirxAddCpuKernelMod : public NativeCpuKernelMod, public MatchKerne
|
|||
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
|
||||
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
|
||||
|
||||
private:
|
||||
template <typename T, typename S>
|
||||
|
|
|
@ -52,7 +52,7 @@ class UnsortedSegmentArithmeticCpuKernelMod : public NativeCpuKernelMod,
|
|||
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
|
||||
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
|
||||
|
||||
private:
|
||||
template <typename T, typename S>
|
||||
|
|
|
@ -50,7 +50,7 @@ class ScatterNdFunctorGPUKernelMod : public NativeGpuKernelMod, public MatchKern
|
|||
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
|
||||
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
|
||||
|
||||
private:
|
||||
template <typename T, typename S>
|
||||
|
|
|
@ -52,7 +52,7 @@ class TensorScatterArithmeticGpuKernelMod : public NativeGpuKernelMod,
|
|||
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
|
||||
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
|
||||
|
||||
private:
|
||||
void FreeResource();
|
||||
|
|
|
@ -45,7 +45,7 @@ class CeluGpuKernelMod : public NativeGpuKernelMod, public MatchKernelHelper<Cel
|
|||
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
|
||||
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
|
|
|
@ -48,7 +48,7 @@ class DropoutNDGpuKernelMod : public NativeGpuKernelMod, public MatchKernelHelpe
|
|||
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
|
||||
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
|
||||
|
||||
private:
|
||||
void ResetResource() noexcept;
|
||||
|
|
|
@ -45,7 +45,7 @@ class SoftShrinkGpuKernelMod : public NativeGpuKernelMod, public MatchKernelHelp
|
|||
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
|
||||
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
|
|
|
@ -45,7 +45,7 @@ class SoftShrinkGradGpuKernelMod : public NativeGpuKernelMod, public MatchKernel
|
|||
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
|
||||
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
|
|
Loading…
Reference in New Issue