!35062 Refactor code style for MatchKernelHelper

Merge pull request !35062 from zhujingxuan/master
This commit is contained in:
i-robot 2022-05-31 01:53:03 +00:00 committed by Gitee
commit 2627280d95
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
32 changed files with 37 additions and 34 deletions

View File

@ -315,13 +315,16 @@ inline std::map<uint32_t, tensor::TensorPtr> GetKernelDepends(const CNodePtr &cn
template <typename Derived>
class MatchKernelHelper {
public:
MatchKernelHelper() = default;
virtual ~MatchKernelHelper() = default;
using KernelRunFunc = std::function<bool(Derived *, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &)>;
virtual const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const = 0;
protected:
std::vector<KernelAttr> GetOpSupport() {
auto &func_list = static_cast<Derived *>(this)->GetFuncList();
std::vector<KernelAttr> OpSupport() const {
auto &func_list = static_cast<const Derived *>(this)->GetFuncList();
std::vector<KernelAttr> support_list;
(void)std::transform(func_list.begin(), func_list.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, KernelRunFunc> &pair) { return pair.first; });
@ -332,7 +335,7 @@ class MatchKernelHelper {
auto kernel_name = base_operator->name();
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto &func_list = static_cast<Derived *>(this)->GetFuncList();
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
auto [is_match, index] = MatchKernelAttr(kernel_attr, OpSupport());
if (!is_match) {
MS_LOG(ERROR) << "The kernel '" << kernel_name << "' does not support this kernel data type: " << kernel_attr;
return false;

View File

@ -48,7 +48,7 @@ class BCEWithLogitsLossCpuKernelMod : public NativeCpuKernelMod,
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
private:
template <typename T>

View File

@ -54,7 +54,7 @@ class BitwiseCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper<
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
private:
template <typename T>

View File

@ -45,7 +45,7 @@ class CTCLossV2CpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelpe
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
std::vector<KernelTensorPtr> GetOutputs() override { return outputs_; };

View File

@ -45,7 +45,7 @@ class DropoutNdCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelpe
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
private:
void ResetResource() noexcept;

View File

@ -48,7 +48,7 @@ class FastGeLUCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
private:
template <typename T>

View File

@ -48,7 +48,7 @@ class FastGeLUGradCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHe
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
private:
template <typename T>

View File

@ -52,7 +52,7 @@ class GerCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper<GerC
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
private:
template <typename T>

View File

@ -39,7 +39,7 @@ class HShrinkCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper<
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
private:
template <typename T>

View File

@ -39,7 +39,7 @@ class HShrinkGradCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHel
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
private:
template <typename T>

View File

@ -49,7 +49,7 @@ class IndexAddCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
private:
void CheckParams();

View File

@ -44,7 +44,7 @@ class IsCloseCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper<
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
private:
template <typename T>

View File

@ -47,7 +47,7 @@ class LerpCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper<Ler
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
private:
template <typename T>

View File

@ -45,7 +45,7 @@ class LinSpaceCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
private:
template <typename T>

View File

@ -43,7 +43,7 @@ class LowerBoundCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelp
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
private:
template <typename I, typename O>

View File

@ -45,7 +45,7 @@ class MatrixBandPartCpuKernelMod : public NativeCpuKernelMod, public MatchKernel
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
private:
template <typename T, typename LU>

View File

@ -46,7 +46,7 @@ class PaddingCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper<
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
private:
template <typename T>

View File

@ -50,7 +50,7 @@ class ReLUV2CpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper<R
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
private:
template <typename T>

View File

@ -45,7 +45,7 @@ class ReLUV3CpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper<R
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
private:
template <typename T>

View File

@ -49,7 +49,7 @@ class ScatterNdArithmeticCpuKernelMod : public NativeCpuKernelMod,
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
private:
template <typename T, typename S>

View File

@ -45,7 +45,7 @@ class SeluCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper<Sel
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
private:
template <typename T>

View File

@ -47,7 +47,7 @@ class SoftShrinkCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelp
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
private:
template <typename T>

View File

@ -47,7 +47,7 @@ class SoftShrinkGradCpuKernelMod : public NativeCpuKernelMod, public MatchKernel
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
private:
template <typename T>

View File

@ -44,7 +44,7 @@ class SparseAddCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelpe
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
private:
template <typename T, typename S>

View File

@ -45,7 +45,7 @@ class SparseMatirxAddCpuKernelMod : public NativeCpuKernelMod, public MatchKerne
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
private:
template <typename T, typename S>

View File

@ -52,7 +52,7 @@ class UnsortedSegmentArithmeticCpuKernelMod : public NativeCpuKernelMod,
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
private:
template <typename T, typename S>

View File

@ -50,7 +50,7 @@ class ScatterNdFunctorGPUKernelMod : public NativeGpuKernelMod, public MatchKern
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
private:
template <typename T, typename S>

View File

@ -52,7 +52,7 @@ class TensorScatterArithmeticGpuKernelMod : public NativeGpuKernelMod,
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
private:
void FreeResource();

View File

@ -45,7 +45,7 @@ class CeluGpuKernelMod : public NativeGpuKernelMod, public MatchKernelHelper<Cel
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
private:
template <typename T>

View File

@ -48,7 +48,7 @@ class DropoutNDGpuKernelMod : public NativeGpuKernelMod, public MatchKernelHelpe
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
private:
void ResetResource() noexcept;

View File

@ -45,7 +45,7 @@ class SoftShrinkGpuKernelMod : public NativeGpuKernelMod, public MatchKernelHelp
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
private:
template <typename T>

View File

@ -45,7 +45,7 @@ class SoftShrinkGradGpuKernelMod : public NativeGpuKernelMod, public MatchKernel
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
private:
template <typename T>