clean code

This commit is contained in:
TronZhang 2022-07-30 11:31:32 +08:00
parent 4c9b6b06cb
commit a610b338e1
363 changed files with 375 additions and 630 deletions

View File

@ -34,7 +34,6 @@ class AdamCpuKernelMod : public DeprecatedNativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override {
static std::vector<KernelAttr> support_list = {KernelAttr().AddSkipCheckAttr(true).AddOutInRef(0, 0)};
return support_list;

View File

@ -32,7 +32,6 @@ class AdamDeltaCpuKernelMod : public DeprecatedNativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -36,7 +36,6 @@ class AddcdivCpuKernelMod : public DeprecatedNativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -36,7 +36,6 @@ class AddcmulCpuKernelMod : public DeprecatedNativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -33,7 +33,6 @@ class AdjustContrastv2CpuKernelMod : public DeprecatedNativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -44,7 +44,6 @@ class AdjustHueCpuKernelMod : public DeprecatedNativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -33,11 +33,10 @@ class AdjustSaturationCpuKernelMod : public DeprecatedNativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);
std::vector<KernelAttr> GetOpSupport() override;
private:
TypeId input_type_{kTypeUnknown};
protected:
std::vector<KernelAttr> GetOpSupport() override;
};
} // namespace kernel
} // namespace mindspore

View File

@ -35,7 +35,6 @@ class AllGatherCpuKernelMod : public DeprecatedNativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -35,7 +35,6 @@ class AllReduceCPUKernelMod : public DeprecatedNativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
};
} // namespace kernel

View File

@ -37,7 +37,6 @@ class AngleCpuKernelMod : public NativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -44,7 +44,6 @@ class ApplyAdaMaxCpuKernelMod : public NativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override {
static std::vector<KernelAttr> support_list = {KernelAttr()
.AddInputAttr(kNumberTypeFloat32)

View File

@ -43,7 +43,6 @@ class ApplyAdadeltaCpuKernelMod : public NativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override {
static std::vector<KernelAttr> support_list = {KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
@ -62,6 +61,7 @@ class ApplyAdadeltaCpuKernelMod : public NativeCpuKernelMod {
return support_list;
}
protected:
int CheckInputShape(const std::vector<KernelTensorPtr> &inputs);
int CheckShapeSize(std::vector<int64_t> var_shape, std::vector<int64_t> lr_shape);

View File

@ -35,7 +35,6 @@ class ApplyAdagradCpuKernelMod : public DeprecatedNativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -42,7 +42,6 @@ class ApplyAdagradDACpuKernelMod : public NativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -46,7 +46,6 @@ class ApplyAdagradV2CpuKernelMod : public NativeCpuKernelMod {
return kernel_func_(this, inputs, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -40,7 +40,6 @@ class ApplyAdamWithAmsgradCpuKernelMod : public NativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -39,7 +39,6 @@ class ApplyFtrlCpuKernelMod : public NativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -38,7 +38,6 @@ class ApplyGradientDescentCpuKernelMod : public NativeCpuKernelMod {
template <typename T>
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
protected:
std::vector<KernelAttr> GetOpSupport() override {
static std::vector<KernelAttr> support_list = {KernelAttr()
.AddInputAttr(kNumberTypeFloat16)

View File

@ -35,7 +35,6 @@ class BACKEND_EXPORT ApplyMomentumCpuKernelMod : public DeprecatedNativeCpuKerne
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
};
} // namespace kernel

View File

@ -43,7 +43,6 @@ class ApplyProximalAdagradCpuKernelMod : public NativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override {
static std::vector<KernelAttr> support_list = {KernelAttr()
.AddInputAttr(kNumberTypeFloat32)

View File

@ -38,7 +38,6 @@ class ApproximateEqualCpuKernelMod : public NativeCpuKernelMod {
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -65,7 +65,7 @@ bool ArgmaxCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inp
for (int64_t j = 0; j < num_after_axis_; j++) {
int64_t src_index_j = src_index_i + j;
for (int64_t k = 0; k < dim_axis_; k++) {
int64_t src_index_k = k * num_after_axis_ + src_index_j;
auto src_index_k = LongToSize(k * num_after_axis_ + src_index_j);
array_axis[k] = static_cast<float>(input[src_index_k]);
}
auto max_ops = std::max_element(array_axis.begin(), array_axis.end());

View File

@ -42,7 +42,6 @@ class ArgminCpuKernelMod : public NativeCpuKernelMod {
return kernel_func_(this, inputs, workspace, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -37,11 +37,22 @@ namespace {
constexpr float kMaxSubSerialSize = 10000.0;
constexpr float kMaxPowSerialSize = 700.0;
constexpr auto kAdd = "Add";
constexpr auto kAddV2 = "AddV2";
constexpr auto kSub = "Sub";
constexpr auto kMul = "Mul";
constexpr auto kRealDiv = "RealDiv";
constexpr auto kAssignAdd = "AssignAdd";
constexpr auto kAssignSub = "AssignSub";
constexpr auto kDiv = "Div";
constexpr auto kDivNoNan = "DivNoNan";
constexpr auto kPow = "Pow";
constexpr auto kFloorDiv = "FloorDiv";
constexpr auto kMod = "Mod";
constexpr auto kFloorMod = "FloorMod";
constexpr auto kSquaredDifference = "SquaredDifference";
constexpr auto kXlogy = "Xlogy";
constexpr auto kAtan2 = "Atan2";
template <typename T>
void ElementRealDiv(const T *input1, const T *input2, T *out, size_t size, size_t delta_1, size_t delta_2) {
@ -157,9 +168,9 @@ class ArithmeticCpuTypeFunc : public DeprecatedCpuKernelFunc {
MS_LOG(WARNING) << kernel_name_ << " output shape contain 0, output_shape: " << output_shape_;
return true;
}
if (kernel_name_ == prim::kPrimAssignAdd->name()) {
if (kernel_name_ == kAssignAdd) {
AssignAdd(input1, input2, output);
} else if (kernel_name_ == prim::kPrimAssignSub->name()) {
} else if (kernel_name_ == kAssignSub) {
AssignSub(input1, input2, output);
} else {
compute_func_(this, input1, input2, output);
@ -169,39 +180,38 @@ class ArithmeticCpuTypeFunc : public DeprecatedCpuKernelFunc {
private:
void InitComputeFunc() {
if (kernel_name_ == prim::kPrimAssignAdd->name() || kernel_name_ == prim::kPrimAssignSub->name()) {
if (kernel_name_ == kAssignAdd || kernel_name_ == kAssignSub) {
return;
}
string dtype_desc;
static std::unordered_map<std::string, TypeComputeFunc> arithmeticMathFuncMap;
if constexpr (!((std::is_same_v<T, complex64>) || (std::is_same_v<T, complex128>))) {
dtype_desc = "real data";
arithmeticMathFuncMap = {{prim::kPrimAdd->name(), &ArithmeticCpuTypeFunc<T>::Add},
{prim::kPrimAddV2->name(), &ArithmeticCpuTypeFunc<T>::AddV2},
{prim::kPrimSub->name(), &ArithmeticCpuTypeFunc<T>::Sub},
{prim::kPrimMul->name(), &ArithmeticCpuTypeFunc<T>::Mul},
{prim::kPrimDiv->name(), &ArithmeticCpuTypeFunc<T>::Div},
{prim::kPrimDivNoNan->name(), &ArithmeticCpuTypeFunc<T>::DivNoNan},
{prim::kPrimMod->name(), &ArithmeticCpuTypeFunc<T>::Mod},
{prim::kPrimFloorMod->name(), &ArithmeticCpuTypeFunc<T>::FloorMod},
{prim::kPrimPow->name(), &ArithmeticCpuTypeFunc<T>::Pow},
{prim::kPrimFloorDiv->name(), &ArithmeticCpuTypeFunc<T>::FloorDiv},
{prim::kPrimAtan2->name(), &ArithmeticCpuTypeFunc<T>::Atan2},
{prim::kPrimRealDiv->name(), &ArithmeticCpuTypeFunc<T>::RealDiv},
{prim::kPrimSquaredDifference->name(), &ArithmeticCpuTypeFunc<T>::SquaredDifference},
{prim::kPrimXlogy->name(), &ArithmeticCpuTypeFunc<T>::Xlogy}};
arithmeticMathFuncMap = {{kAdd, &ArithmeticCpuTypeFunc<T>::Add},
{kAddV2, &ArithmeticCpuTypeFunc<T>::AddV2},
{kSub, &ArithmeticCpuTypeFunc<T>::Sub},
{kMul, &ArithmeticCpuTypeFunc<T>::Mul},
{kDiv, &ArithmeticCpuTypeFunc<T>::Div},
{kDivNoNan, &ArithmeticCpuTypeFunc<T>::DivNoNan},
{kMod, &ArithmeticCpuTypeFunc<T>::Mod},
{kFloorMod, &ArithmeticCpuTypeFunc<T>::FloorMod},
{kPow, &ArithmeticCpuTypeFunc<T>::Pow},
{kFloorDiv, &ArithmeticCpuTypeFunc<T>::FloorDiv},
{kAtan2, &ArithmeticCpuTypeFunc<T>::Atan2},
{kRealDiv, &ArithmeticCpuTypeFunc<T>::RealDiv},
{kSquaredDifference, &ArithmeticCpuTypeFunc<T>::SquaredDifference},
{kXlogy, &ArithmeticCpuTypeFunc<T>::Xlogy}};
} else {
dtype_desc = "complex data";
arithmeticMathFuncMap = {
{prim::kPrimSquaredDifference->name(), &ArithmeticCpuTypeFunc<T>::SquaredDifferenceComplex},
{prim::kPrimSub->name(), &ArithmeticCpuTypeFunc<T>::Sub},
{prim::kPrimDiv->name(), &ArithmeticCpuTypeFunc<T>::DivComplex},
{prim::kPrimRealDiv->name(), &ArithmeticCpuTypeFunc<T>::RealDivComplex},
{prim::kPrimMul->name(), &ArithmeticCpuTypeFunc<T>::Mul},
{prim::kPrimDivNoNan->name(), &ArithmeticCpuTypeFunc<T>::DivNoNan},
{prim::kPrimAddV2->name(), &ArithmeticCpuTypeFunc<T>::AddV2},
{prim::kPrimPow->name(), &ArithmeticCpuTypeFunc<T>::PowComplex},
{prim::kPrimXlogy->name(), &ArithmeticCpuTypeFunc<T>::Xlogy}};
arithmeticMathFuncMap = {{kSquaredDifference, &ArithmeticCpuTypeFunc<T>::SquaredDifferenceComplex},
{kSub, &ArithmeticCpuTypeFunc<T>::Sub},
{kDiv, &ArithmeticCpuTypeFunc<T>::DivComplex},
{kRealDiv, &ArithmeticCpuTypeFunc<T>::RealDivComplex},
{kMul, &ArithmeticCpuTypeFunc<T>::Mul},
{kDivNoNan, &ArithmeticCpuTypeFunc<T>::DivNoNan},
{kAddV2, &ArithmeticCpuTypeFunc<T>::AddV2},
{kPow, &ArithmeticCpuTypeFunc<T>::PowComplex},
{kXlogy, &ArithmeticCpuTypeFunc<T>::Xlogy}};
}
if (arithmeticMathFuncMap.find(kernel_name_) == arithmeticMathFuncMap.end()) {
MS_LOG(EXCEPTION) << "For 'Arithmetic', only supports operators in " << Map2Str(arithmeticMathFuncMap)
@ -810,7 +820,7 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, ArithmeticCpuFunc
SpecializeArithFunc<complex128>},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
SpecializeArithFunc<bool>}}},
{prim::kPrimDiv->name(),
{kDiv,
{{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
SpecializeArithFunc<int8_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
@ -839,7 +849,7 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, ArithmeticCpuFunc
.AddInputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeComplex128),
SpecializeArithFunc<complex128>}}},
{prim::kPrimDivNoNan->name(),
{kDivNoNan,
{{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
SpecializeArithFunc<float16>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
@ -856,7 +866,7 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, ArithmeticCpuFunc
.AddInputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeComplex128),
SpecializeArithFunc<complex128>}}},
{prim::kPrimPow->name(),
{kPow,
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
SpecializeArithFunc<int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
@ -910,7 +920,7 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, ArithmeticCpuFunc
.AddInputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeComplex128),
SpecializeArithFunc<complex128>}}},
{prim::kPrimFloorDiv->name(),
{kFloorDiv,
{{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
SpecializeArithFunc<int8_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
@ -927,14 +937,14 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, ArithmeticCpuFunc
SpecializeArithFunc<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
SpecializeArithFunc<double>}}},
{prim::kPrimMod->name(),
{kMod,
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
SpecializeArithFunc<int>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
SpecializeArithFunc<float>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
SpecializeArithFunc<int64_t>}}},
{prim::kPrimFloorMod->name(),
{kFloorMod,
{{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
SpecializeArithFunc<int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
@ -1013,7 +1023,7 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, ArithmeticCpuFunc
.AddOutputAttr(kNumberTypeFloat64)
.AddOutInRef(0, 0),
SpecializeArithFunc<double>}}},
{prim::kPrimSquaredDifference->name(),
{kSquaredDifference,
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
SpecializeArithFunc<int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
@ -1034,7 +1044,7 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, ArithmeticCpuFunc
.AddInputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeComplex128),
SpecializeArithFunc<complex128>}}},
{prim::kPrimXlogy->name(),
{kXlogy,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
SpecializeArithFunc<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
@ -1051,12 +1061,12 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, ArithmeticCpuFunc
.AddInputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeComplex128),
SpecializeArithFunc<complex128>}}},
{prim::kPrimAtan2->name(),
{kAtan2,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
SpecializeArithFunc<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
SpecializeArithFunc<double>}}},
{prim::kPrimAddV2->name(),
{kAddV2,
{{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
SpecializeArithFunc<int8_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
@ -1118,35 +1128,31 @@ MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Sub,
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Mul,
[]() { return std::make_shared<ArithmeticCpuKernelMod>(kMul); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Div,
[]() { return std::make_shared<ArithmeticCpuKernelMod>(prim::kPrimDiv->name()); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, DivNoNan, []() {
return std::make_shared<ArithmeticCpuKernelMod>(prim::kPrimDivNoNan->name());
});
[]() { return std::make_shared<ArithmeticCpuKernelMod>(kDiv); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, DivNoNan,
[]() { return std::make_shared<ArithmeticCpuKernelMod>(kDivNoNan); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Pow,
[]() { return std::make_shared<ArithmeticCpuKernelMod>(prim::kPrimPow->name()); });
[]() { return std::make_shared<ArithmeticCpuKernelMod>(kPow); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, RealDiv,
[]() { return std::make_shared<ArithmeticCpuKernelMod>(kRealDiv); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, FloorDiv, []() {
return std::make_shared<ArithmeticCpuKernelMod>(prim::kPrimFloorDiv->name());
});
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, FloorDiv,
[]() { return std::make_shared<ArithmeticCpuKernelMod>(kFloorDiv); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Mod,
[]() { return std::make_shared<ArithmeticCpuKernelMod>(prim::kPrimMod->name()); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, FloorMod, []() {
return std::make_shared<ArithmeticCpuKernelMod>(prim::kPrimFloorMod->name());
});
[]() { return std::make_shared<ArithmeticCpuKernelMod>(kMod); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, FloorMod,
[]() { return std::make_shared<ArithmeticCpuKernelMod>(kFloorMod); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, AssignAdd,
[]() { return std::make_shared<ArithmeticCpuKernelMod>(kAssignAdd); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, AssignSub,
[]() { return std::make_shared<ArithmeticCpuKernelMod>(kAssignSub); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, SquaredDifference, []() {
return std::make_shared<ArithmeticCpuKernelMod>(prim::kPrimSquaredDifference->name());
});
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, SquaredDifference,
[]() { return std::make_shared<ArithmeticCpuKernelMod>(kSquaredDifference); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Xlogy,
[]() { return std::make_shared<ArithmeticCpuKernelMod>(prim::kPrimXlogy->name()); });
[]() { return std::make_shared<ArithmeticCpuKernelMod>(kXlogy); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Atan2,
[]() { return std::make_shared<ArithmeticCpuKernelMod>(prim::kPrimAtan2->name()); });
[]() { return std::make_shared<ArithmeticCpuKernelMod>(kAtan2); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, AddV2,
[]() { return std::make_shared<ArithmeticCpuKernelMod>(prim::kPrimAddV2->name()); });
[]() { return std::make_shared<ArithmeticCpuKernelMod>(kAddV2); });
} // namespace kernel
} // namespace mindspore

View File

@ -47,7 +47,6 @@ class ArithmeticCpuKernelMod : public DeprecatedNativeCpuKernelMod {
return func_obj_->RunFunc(inputs, workspace, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -34,6 +34,13 @@ using complex64 = std::complex<float>;
using complex128 = std::complex<double>;
constexpr size_t kMaxLessSerialSize = 15000;
constexpr auto kLess = "Less";
constexpr auto kLessEqual = "LessEqual";
constexpr auto kGreater = "Greater";
constexpr auto kGreaterEqual = "GreaterEqual";
constexpr auto kLogicalAnd = "LogicalAnd";
constexpr auto kLogicalOr = "LogicalOr";
constexpr auto kLogicalXor = "LogicalXor";
constexpr auto kEqual = "Equal";
constexpr auto kNotEqual = "NotEqual";
@ -85,16 +92,16 @@ class ArithLogicCpuTypeFunc : public DeprecatedCpuKernelFunc {
}
static std::unordered_map<std::string, TypeComputeFunc> arithmetic_logic_func_map;
if constexpr (!((std::is_same_v<T, complex64>) || (std::is_same_v<T, complex128>))) {
arithmetic_logic_func_map = {{prim::kPrimGreater->name(), &ArithLogicCpuTypeFunc<T>::Greater},
{prim::kPrimGreaterEqual->name(), &ArithLogicCpuTypeFunc<T>::GreaterEqual},
{prim::kPrimLogicalAnd->name(), &ArithLogicCpuTypeFunc<T>::LogicalAnd},
{prim::kPrimLessEqual->name(), &ArithLogicCpuTypeFunc<T>::LessEqual},
{prim::kPrimLogicalOr->name(), &ArithLogicCpuTypeFunc<T>::LogicalOr},
{prim::kPrimLogicalXor->name(), &ArithLogicCpuTypeFunc<T>::LogicalXor},
{prim::kPrimLess->name(), &ArithLogicCpuTypeFunc<T>::Less},
{prim::kPrimNotEqual->name(), &ArithLogicCpuTypeFunc<T>::NotEqual}};
arithmetic_logic_func_map = {{kGreater, &ArithLogicCpuTypeFunc<T>::Greater},
{kGreaterEqual, &ArithLogicCpuTypeFunc<T>::GreaterEqual},
{kLogicalAnd, &ArithLogicCpuTypeFunc<T>::LogicalAnd},
{kLessEqual, &ArithLogicCpuTypeFunc<T>::LessEqual},
{kLogicalOr, &ArithLogicCpuTypeFunc<T>::LogicalOr},
{kLogicalXor, &ArithLogicCpuTypeFunc<T>::LogicalXor},
{kLess, &ArithLogicCpuTypeFunc<T>::Less},
{kNotEqual, &ArithLogicCpuTypeFunc<T>::NotEqual}};
} else {
arithmetic_logic_func_map = {{prim::kPrimNotEqual->name(), &ArithLogicCpuTypeFunc<T>::NotEqual}};
arithmetic_logic_func_map = {{kNotEqual, &ArithLogicCpuTypeFunc<T>::NotEqual}};
}
if (arithmetic_logic_func_map.find(kernel_name_) == arithmetic_logic_func_map.end()) {
MS_LOG(EXCEPTION) << "For 'ArithmeticLogic', only supports operators in " << Map2Str(arithmetic_logic_func_map)
@ -114,7 +121,7 @@ class ArithLogicCpuTypeFunc : public DeprecatedCpuKernelFunc {
void LessEqual(const T *input1, const T *input2, bool *out);
void LogicalAnd(const T *input1, const T *input2, bool *out);
void LogicalOr(const T *input1, const T *input2, bool *out);
void LogicalXor(const T *input1, const T *input2, bool *out);
void LogicalXor(const T *input1, const T *input2, bool *out) const;
using TypeComputeFunc = std::function<void(ArithLogicCpuTypeFunc *, const T *, const T *, bool *)>;
TypeComputeFunc compute_func_{nullptr};
@ -178,7 +185,7 @@ class ArithComplexLogicCpuTypeFunc : public DeprecatedCpuKernelFunc {
<< dtype_ << ", and the type of 'input2': " << dtype_1;
}
static const std::unordered_map<std::string, ComplexTypeComputeFunc> arithmetic_logic_func_map{
{prim::kPrimEqual->name(), &ArithComplexLogicCpuTypeFunc<T>::Equal}};
{kEqual, &ArithComplexLogicCpuTypeFunc<T>::Equal}};
if (arithmetic_logic_func_map.find(kernel_name_) == arithmetic_logic_func_map.end()) {
MS_LOG(EXCEPTION) << "For 'ArithmeticLogic', only supports operators in " << Map2Str(arithmetic_logic_func_map)
<< ", but got " << kernel_name_;
@ -329,7 +336,7 @@ void ArithLogicCpuTypeFunc<T>::LogicalOr(const T *input1, const T *input2, bool
}
template <typename T>
void ArithLogicCpuTypeFunc<T>::LogicalXor(const T *input1, const T *input2, bool *out) {
void ArithLogicCpuTypeFunc<T>::LogicalXor(const T *input1, const T *input2, bool *out) const {
BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_);
auto task = [input1, input2, out, &base_iter](size_t start, size_t end) {
auto iter = base_iter;
@ -367,7 +374,7 @@ std::shared_ptr<DeprecatedCpuKernelFunc> SpecializeArithLogComplexFunc() {
}
using ArithLogicCpuFuncCreator = std::function<std::shared_ptr<DeprecatedCpuKernelFunc>()>;
static std::map<std::string, std::vector<std::pair<KernelAttr, ArithLogicCpuFuncCreator>>> kernel_attr_lists = {
{prim::kPrimLess->name(),
{kLess,
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<int>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
@ -438,7 +445,7 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, ArithLogicCpuFunc
.AddInputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<complex128>}}},
{prim::kPrimGreater->name(),
{kGreater,
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<int>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
@ -447,7 +454,7 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, ArithLogicCpuFunc
SpecializeArithLogFunc<double>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<int64_t>}}},
{prim::kPrimGreaterEqual->name(),
{kGreaterEqual,
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<int>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool),
@ -462,7 +469,7 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, ArithLogicCpuFunc
SpecializeArithLogFunc<int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<double>}}},
{prim::kPrimLessEqual->name(),
{kLessEqual,
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<int>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
@ -471,13 +478,13 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, ArithLogicCpuFunc
SpecializeArithLogFunc<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<double>}}},
{prim::kPrimLogicalAnd->name(),
{kLogicalAnd,
{{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<bool>}}},
{prim::kPrimLogicalOr->name(),
{kLogicalOr,
{{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<bool>}}},
{prim::kPrimLogicalXor->name(),
{kLogicalXor,
{{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<bool>}}}};
} // namespace
@ -540,30 +547,23 @@ std::vector<KernelAttr> ArithmeticComplexLogicCpuKernelMod::GetOpSupport() {
return support_list;
}
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Less, []() {
return std::make_shared<ArithmeticLogicCpuKernelMod>(prim::kPrimLess->name());
});
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Less,
[]() { return std::make_shared<ArithmeticLogicCpuKernelMod>(kLess); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Equal,
[]() { return std::make_shared<ArithmeticComplexLogicCpuKernelMod>(kEqual); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, NotEqual,
[]() { return std::make_shared<ArithmeticLogicCpuKernelMod>(kNotEqual); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Greater, []() {
return std::make_shared<ArithmeticLogicCpuKernelMod>(prim::kPrimGreater->name());
});
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, GreaterEqual, []() {
return std::make_shared<ArithmeticLogicCpuKernelMod>(prim::kPrimGreaterEqual->name());
});
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, LessEqual, []() {
return std::make_shared<ArithmeticLogicCpuKernelMod>(prim::kPrimLessEqual->name());
});
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, LogicalAnd, []() {
return std::make_shared<ArithmeticLogicCpuKernelMod>(prim::kPrimLogicalAnd->name());
});
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, LogicalOr, []() {
return std::make_shared<ArithmeticLogicCpuKernelMod>(prim::kPrimLogicalOr->name());
});
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, LogicalXor, []() {
return std::make_shared<ArithmeticLogicCpuKernelMod>(prim::kPrimLogicalXor->name());
});
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Greater,
[]() { return std::make_shared<ArithmeticLogicCpuKernelMod>(kGreater); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, GreaterEqual,
[]() { return std::make_shared<ArithmeticLogicCpuKernelMod>(kGreaterEqual); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, LessEqual,
[]() { return std::make_shared<ArithmeticLogicCpuKernelMod>(kLessEqual); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, LogicalAnd,
[]() { return std::make_shared<ArithmeticLogicCpuKernelMod>(kLogicalAnd); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, LogicalOr,
[]() { return std::make_shared<ArithmeticLogicCpuKernelMod>(kLogicalOr); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, LogicalXor,
[]() { return std::make_shared<ArithmeticLogicCpuKernelMod>(kLogicalXor); });
} // namespace kernel
} // namespace mindspore

View File

@ -44,13 +44,13 @@ class ArithmeticLogicCpuKernelMod : public DeprecatedNativeCpuKernelMod {
return func_obj_->RunFunc(inputs, workspace, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
std::shared_ptr<DeprecatedCpuKernelFunc> func_obj_;
std::string kernel_type_{"Unknown"};
};
class ArithmeticComplexLogicCpuKernelMod : public DeprecatedNativeCpuKernelMod {
public:
ArithmeticComplexLogicCpuKernelMod() = default;
@ -68,7 +68,6 @@ class ArithmeticComplexLogicCpuKernelMod : public DeprecatedNativeCpuKernelMod {
return func_obj_->RunFunc(inputs, workspace, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -115,7 +115,7 @@ void Sign(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size
}
template <typename T>
void Neg(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
void Neg(ArithmeticSelfCpuKernelFunc *, const T *in, T *out, size_t size) {
auto task = [&in, &out](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = -in[i];
@ -511,7 +511,7 @@ void Rsqrt(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t siz
}
template <typename T>
void Softsign(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
void Softsign(ArithmeticSelfCpuKernelFunc *, const T *in, T *out, size_t size) {
if constexpr ((std::is_same_v<T, uint8_t>) || (std::is_same_v<T, uint16_t>) || (std::is_same_v<T, uint32_t>) ||
(std::is_same_v<T, uint64_t>)) {
MS_LOG(EXCEPTION) << "'Softsign' cannot be instantiated.";

View File

@ -48,7 +48,6 @@ class ArithmeticSelfCpuKernelMod : public NativeCpuKernelMod {
return func_obj_->RunFunc(inputs, workspace, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
@ -76,7 +75,6 @@ class IdentityCpuKernelMod : public NativeCpuKernelMod {
return kernel_func_(inputs, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -36,7 +36,6 @@ class AssignCpuKernelMod : public DeprecatedNativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -48,7 +48,6 @@ class BatchToSpaceNDCpuKernelMod : public NativeCpuKernelMod, public MatchKernel
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -43,7 +43,6 @@ class BesselI0CpuKernelMod : public NativeCpuKernelMod {
template <typename T>
static void BesselI0Func(const T *input, T *output, size_t start, size_t end);
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
@ -78,7 +77,6 @@ class BesselI0eCpuKernelMod : public NativeCpuKernelMod {
template <typename T>
static void BesselI0eFunc(const T *input, T *output, size_t start, size_t end);
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -44,7 +44,6 @@ class BesselI1CpuKernelMod : public NativeCpuKernelMod {
template <typename T>
static void BesselI1Func(const T *input, T *output, size_t start, size_t end);
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
@ -78,7 +77,6 @@ class BesselI1eCpuKernelMod : public NativeCpuKernelMod {
template <typename T>
static void BesselI1eFunc(const T *input, T *output, size_t start, size_t end);
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -46,7 +46,6 @@ class BesselJ0CpuKernelMod : public NativeCpuKernelMod {
template <typename T>
static void BesselJ0Func(const T *input, T *output, size_t start, size_t end);
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -46,7 +46,6 @@ class BesselJ1CpuKernelMod : public NativeCpuKernelMod {
template <typename T>
static void BesselJ1Func(const T *input, T *output, size_t start, size_t end);
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -43,7 +43,6 @@ class BesselK0CpuKernelMod : public NativeCpuKernelMod {
template <typename T>
static void BesselK0Func(const T *input, T *output, size_t start, size_t end);
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
@ -77,7 +76,6 @@ class BesselK0eCpuKernelMod : public NativeCpuKernelMod {
template <typename T>
static void BesselK0eFunc(const T *input, T *output, size_t start, size_t end);
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -43,7 +43,6 @@ class BesselK1CpuKernelMod : public NativeCpuKernelMod {
template <typename T>
static void BesselK1Func(const T *input, T *output, size_t start, size_t end);
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
@ -77,7 +76,6 @@ class BesselK1eCpuKernelMod : public NativeCpuKernelMod {
template <typename T>
static void BesselK1eFunc(const T *input, T *output, size_t start, size_t end);
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -45,7 +45,6 @@ class BesselY0CpuKernelMod : public NativeCpuKernelMod {
template <typename T>
static void BesselY0Func(const T *input, T *output, size_t start, size_t end);
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -45,7 +45,6 @@ class BesselY1CpuKernelMod : public NativeCpuKernelMod {
template <typename T>
static void BesselY1Func(const T *input, T *output, size_t start, size_t end);
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -36,7 +36,6 @@ class BinaryCrossEntropyCpuKernelMod : public DeprecatedNativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -35,7 +35,6 @@ class BinaryCrossEntropyGradCpuKernelMod : public DeprecatedNativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -35,7 +35,6 @@ class BlackmanWindowCpuKernelMod : public DeprecatedNativeCpuKernelMod {
return kernel_func_(this, inputs, workspace, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -40,7 +40,6 @@ class BoundingBoxDecodeCpuKernelMod : public DeprecatedNativeCpuKernelMod {
return kernel_func_(this, inputs, workspace, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -39,7 +39,6 @@ class BoundingBoxEncodeCpuKernelMod : public DeprecatedNativeCpuKernelMod {
return kernel_func_(this, inputs, workspace, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -47,7 +47,6 @@ class BroadcastToCpuKernelMod : public DeprecatedNativeCpuKernelMod {
void CheckArgs();
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -35,7 +35,6 @@ class BucketizeCpuKernelMod : public DeprecatedNativeCpuKernelMod {
template <typename T>
bool BucketizeCompute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -37,7 +37,6 @@ class CheckNumericsCpuKernelMod : public DeprecatedNativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override {
static std::vector<KernelAttr> support_list = {
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),

View File

@ -38,7 +38,6 @@ class CheckValidCpuKernelMod : public DeprecatedNativeCpuKernelMod {
return kernel_func_(this, inputs, workspace, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -36,7 +36,6 @@ class CholeskyInverseCpuKernelMod : public DeprecatedNativeCpuKernelMod {
return kernel_func_(inputs, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -39,7 +39,6 @@ class CoalesceCpuKernelMod : public DeprecatedNativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -40,7 +40,6 @@ class ComplexCpuKernelMod : public NativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -42,7 +42,6 @@ class ConcatCpuKernelMod : public DeprecatedNativeCpuKernelMod {
return kernel_func_(this, inputs, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -52,7 +52,7 @@ void ConcatOffsetCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
kernel_func_ = func_list_[index].second;
}
template <typename T>
bool ConcatOffsetCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
bool ConcatOffsetCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kConcatOffsetOutputNum, kernel_name_);
auto node_ = cnode_ptr_.lock();

View File

@ -37,12 +37,11 @@ class ConcatOffsetCpuKernelMod : public DeprecatedNativeCpuKernelMod {
return kernel_func_(this, inputs, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
template <typename T>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
bool LaunchKernel(const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &outputs);
using ConcatOffsetFunc = std::function<bool(ConcatOffsetCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &)>;
static std::vector<std::pair<KernelAttr, ConcatOffsetFunc>> func_list_;

View File

@ -48,7 +48,7 @@ std::vector<KernelAttr> NativeCpuKernelMod::GetAllSupportedList(const std::strin
return support_map_[kernel_name];
}
std::vector<KernelAttr> NativeCpuKernelMod::GetSupportFromOpLib(const std::string &kernel_name) {
std::vector<KernelAttr> NativeCpuKernelMod::GetSupportFromOpLib(const std::string &kernel_name) const {
static std::set<std::string> same_op_name = {"Concat", "Pack", "Stack", "Split", "Transpose",
"Unpack", "AddN", "ConcatOffset", "DynamicStitch"};
std::vector<KernelAttr> support_kernel_attrs;
@ -134,7 +134,7 @@ void DeprecatedNativeCpuKernelMod::Init(const CNodePtr &kernel_node) {
InitInputOutputSize(kernel_node);
}
std::vector<TypeId> DeprecatedNativeCpuKernelMod::GetInputDtypes(const CNodePtr &kernel_node) {
std::vector<TypeId> DeprecatedNativeCpuKernelMod::GetInputDtypes(const CNodePtr &kernel_node) const {
std::vector<TypeId> input_types;
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
for (size_t input_index = 0; input_index < input_num; ++input_index) {
@ -144,7 +144,7 @@ std::vector<TypeId> DeprecatedNativeCpuKernelMod::GetInputDtypes(const CNodePtr
return input_types;
}
std::vector<TypeId> DeprecatedNativeCpuKernelMod::GetOutputDtypes(const CNodePtr &kernel_node) {
std::vector<TypeId> DeprecatedNativeCpuKernelMod::GetOutputDtypes(const CNodePtr &kernel_node) const {
std::vector<TypeId> output_types;
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
for (size_t output_index = 0; output_index < output_num; ++output_index) {

View File

@ -167,7 +167,7 @@ class BACKEND_EXPORT NativeCpuKernelMod : public CpuKernelMod {
private:
std::vector<KernelAttr> GetAllSupportedList(const std::string &kernel_name);
std::vector<KernelAttr> GetSupportFromOpLib(const std::string &kernel_name);
std::vector<KernelAttr> GetSupportFromOpLib(const std::string &kernel_name) const;
static mindspore::HashMap<std::string, std::vector<KernelAttr>> support_map_;
};
@ -208,8 +208,8 @@ class BACKEND_EXPORT DeprecatedNativeCpuKernelMod : public NativeCpuKernelMod {
}
private:
std::vector<TypeId> GetInputDtypes(const CNodePtr &kernel_node);
std::vector<TypeId> GetOutputDtypes(const CNodePtr &kernel_node);
std::vector<TypeId> GetInputDtypes(const CNodePtr &kernel_node) const;
std::vector<TypeId> GetOutputDtypes(const CNodePtr &kernel_node) const;
};
class DeprecatedCpuKernelFunc {

View File

@ -53,7 +53,6 @@ class CropAndResizeCpuKernelMod : public DeprecatedNativeCpuKernelMod {
return kernel_func_(this, inputs, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -65,7 +65,6 @@ class CropAndResizeGradBoxesCpuKernelMod : public DeprecatedNativeCpuKernelMod {
void OutputZeroing(const std::vector<AddressPtr> &outputs);
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -36,7 +36,6 @@ class CrossCpuKernelMod : public DeprecatedNativeCpuKernelMod {
template <typename T>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -48,7 +48,6 @@ class CTCGreedyDecoderCpuKernelMod : public NativeCpuKernelMod, public MatchKern
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
private:

View File

@ -38,7 +38,6 @@ class CTCLossCpuKernelMod : public DeprecatedNativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -48,7 +48,6 @@ class DataFormatDimMapCpuKernelMod : public NativeCpuKernelMod, public MatchKern
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
private:

View File

@ -39,7 +39,6 @@ class DataFormatVecPermuteCpuKernelMod : public DeprecatedNativeCpuKernelMod {
return kernel_func_(this, inputs, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -19,7 +19,6 @@
#include <vector>
#include <memory>
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
@ -36,7 +35,6 @@ class DebugCpuKernelMod : public DeprecatedNativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
};
} // namespace kernel

View File

@ -67,7 +67,6 @@ class DeformableOffsetsGradCpuKernelMod : public NativeCpuKernelMod,
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::OpSupport(); }
private:

View File

@ -38,7 +38,6 @@ class DenseToCSRSparseMatrixCpuKernelMod : public DeprecatedNativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -38,7 +38,6 @@ class DenseToDenseSetOperationCpuKernelMod : public DeprecatedNativeCpuKernelMod
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -36,7 +36,6 @@ class DepthToSpaceCpuKernelMod : public DeprecatedNativeCpuKernelMod {
return kernel_func_(this, inputs, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -58,7 +58,7 @@ bool DynamicAssignCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &in
template <typename T>
void DynamicAssignCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
const std::vector<kernel::AddressPtr> &) {
auto node = node_wpt_.lock();
if (!node) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', node_wpt_(kernel_node) is expired. Error no: " << node;

View File

@ -36,12 +36,11 @@ class DynamicAssignCpuKernelMod : public DeprecatedNativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
template <typename T>
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &);
size_t batch_size_{1};
TypeId input_x_dtype_{kTypeUnknown};

View File

@ -34,7 +34,6 @@ class CholeskyCpuKernelMod : public DeprecatedNativeCpuKernelMod {
return kernel_func_(this, inputs, workspace, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -37,7 +37,6 @@ class CholeskySolveCpuKernelMod : public DeprecatedNativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -39,7 +39,6 @@ class EigCpuKernelMod : public DeprecatedNativeCpuKernelMod {
return kernel_func_(this, inputs, workspace, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -39,7 +39,6 @@ class EighCpuKernelMod : public DeprecatedNativeCpuKernelMod {
}
void InitInputOutputSize(const CNodePtr &kernel_node) override { init_io_func_(this, kernel_node); }
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -43,7 +43,6 @@ class ExpandCpuKernelMod : public DeprecatedNativeCpuKernelMod {
template <size_t RANK, typename T>
bool ExpandCalculate(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -50,7 +50,7 @@ void LUCpuKernelMod::InitMatrixInfo(const std::vector<size_t> &shape, size_t *ro
}
}
void LUCpuKernelMod::InitPivotVecInfo(const std::vector<size_t> &shape, size_t *row, size_t *col) {
void LUCpuKernelMod::InitPivotVecInfo(const std::vector<size_t> &shape, size_t *row, size_t *col) const {
constexpr size_t pivot_min_dim = 1;
if (shape.size() < pivot_min_dim) {
MS_LOG_EXCEPTION << kernel_name_ << "pivots shape is " << shape.size() << " which is invalid.";
@ -99,14 +99,14 @@ void LUCpuKernelMod::InitIOSize(const CNodePtr &kernel_node) {
}
template <typename T>
T LUCpuKernelMod::GetPermutatedValue(const T *lu_value, const std::vector<int> &per_value, size_t i, size_t j) {
T LUCpuKernelMod::GetPermutatedValue(const T *lu_value, const std::vector<int> &per_value, size_t i, size_t j) const {
const T *pered_lu_value = lu_value + per_value[i] * SizeToInt(lu_col_) + SizeToInt(j);
return *pered_lu_value;
}
template <typename T>
bool LUCpuKernelMod::UpdateMajorPermutation(T *lu_value, std::vector<int> *per_value, int *pivots, size_t k,
size_t rows) {
size_t rows) const {
T max_major_value = static_cast<T>(kZeroThreshold);
size_t max_major_index = 0;
for (size_t i = k; i < rows; ++i) {
@ -126,7 +126,7 @@ bool LUCpuKernelMod::UpdateMajorPermutation(T *lu_value, std::vector<int> *per_v
template <typename T>
void LUCpuKernelMod::SetPermutatedValue(T *lu_value, const std::vector<int> &per_value, size_t i, size_t j,
const T &value) {
const T &value) const {
T *per_lu_value = lu_value + per_value[i] * SizeToInt(lu_col_) + SizeToInt(j);
*per_lu_value = value;
}

View File

@ -34,16 +34,15 @@ class LUCpuKernelMod : public DeprecatedNativeCpuKernelMod {
return kernel_func_(this, inputs, workspace, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
template <typename T>
T GetPermutatedValue(const T *lu_value, const std::vector<int> &per_value, size_t i, size_t j);
T GetPermutatedValue(const T *lu_value, const std::vector<int> &per_value, size_t i, size_t j) const;
template <typename T>
bool UpdateMajorPermutation(T *lu_value, std::vector<int> *per_value, int *pivots, size_t k, size_t rows);
bool UpdateMajorPermutation(T *lu_value, std::vector<int> *per_value, int *pivots, size_t k, size_t rows) const;
template <typename T>
void SetPermutatedValue(T *lu_value, const std::vector<int> &per_value, size_t i, size_t j, const T &value);
void SetPermutatedValue(T *lu_value, const std::vector<int> &per_value, size_t i, size_t j, const T &value) const;
template <typename T>
void InitIOSize(const CNodePtr &kernel_node);
template <typename T>
@ -59,7 +58,7 @@ class LUCpuKernelMod : public DeprecatedNativeCpuKernelMod {
void InitInputOutputSize(const CNodePtr &kernel_node) override { init_io_func_(this, kernel_node); }
void InitMatrixInfo(const std::vector<size_t> &shape, size_t *row, size_t *col);
void InitPivotVecInfo(const std::vector<size_t> &shape, size_t *row, size_t *col);
void InitPivotVecInfo(const std::vector<size_t> &shape, size_t *row, size_t *col) const;
size_t batch_size_{1};
size_t a_row_{1};
size_t a_col_{1};

View File

@ -35,7 +35,6 @@ class LUSolverCpuKernelMod : public DeprecatedNativeCpuKernelMod {
return kernel_func_(this, inputs, workspace, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -35,7 +35,6 @@ class MatrixTriangularSolveCpuKernelMod : public DeprecatedNativeCpuKernelMod {
return kernel_func_(this, inputs, workspace, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -34,7 +34,6 @@ class QRCpuKernelMod : public DeprecatedNativeCpuKernelMod {
return kernel_func_(this, inputs, workspace, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -40,7 +40,6 @@ class RandomPoissonCpuKernelMod : public NativeCpuKernelMod, public MatchKernelH
}
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
private:

View File

@ -62,7 +62,6 @@ class EltWiseGradCpuKernelMod : public NativeCpuKernelMod {
return func_obj_->RunFunc(inputs, workspace, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -36,7 +36,6 @@ class EluGradCpuKernelMod : public DeprecatedNativeCpuKernelMod {
template <typename T>
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -34,7 +34,6 @@ class EmbeddingLookUpCommGradCpuKernelMod : public DeprecatedNativeCpuKernelMod
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override {
static std::vector<KernelAttr> support_list = {
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)};

View File

@ -35,7 +35,6 @@ class BACKEND_EXPORT EmbeddingLookUpCpuKernelMod : public DeprecatedNativeCpuKer
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override {
static std::vector<KernelAttr> support_list = {
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
@ -44,6 +43,7 @@ class BACKEND_EXPORT EmbeddingLookUpCpuKernelMod : public DeprecatedNativeCpuKer
return support_list;
}
protected:
template <typename T>
void LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);

View File

@ -32,7 +32,6 @@ class EnvironCreateCpuKernelMod : public DeprecatedNativeCpuKernelMod {
const std::vector<AddressPtr> &outputs) override;
void InitKernel(const CNodePtr &node) override;
protected:
std::vector<KernelAttr> GetOpSupport() override {
static std::vector<KernelAttr> support_list = {KernelAttr().AddOutputAttr(kNumberTypeInt64)};
return support_list;

View File

@ -32,7 +32,6 @@ class EnvironDestroyAllCpuKernelMod : public DeprecatedNativeCpuKernelMod {
const std::vector<AddressPtr> &outputs) override;
void InitKernel(const CNodePtr &node);
protected:
std::vector<KernelAttr> GetOpSupport() override {
static std::vector<KernelAttr> support_list = {KernelAttr().AddOutputAttr(kNumberTypeBool)};
return support_list;

View File

@ -32,7 +32,6 @@ class EnvironGetCpuKernelMod : public DeprecatedNativeCpuKernelMod {
const std::vector<AddressPtr> &outputs) override;
void InitKernel(const CNodePtr &node) override;
protected:
std::vector<KernelAttr> GetOpSupport() override {
static std::vector<KernelAttr> support_list = {KernelAttr()
.AddInputAttr(kNumberTypeInt64)

View File

@ -33,7 +33,6 @@ class EnvironSetCpuKernelMod : public DeprecatedNativeCpuKernelMod {
const std::vector<AddressPtr> &outputs) override;
void InitKernel(const CNodePtr &node) override;
protected:
std::vector<KernelAttr> GetOpSupport() override {
static std::vector<KernelAttr> support_list = {KernelAttr()
.AddInputAttr(kNumberTypeInt64)

View File

@ -61,7 +61,7 @@ bool Expm1CpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, co
template <typename T>
void Expm1CpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
const std::vector<kernel::AddressPtr> &outputs) const {
const auto *input = reinterpret_cast<T *>(inputs[0]->addr);
auto *output = reinterpret_cast<T *>(outputs[0]->addr);
size_t elem_num = inputs[0]->size / sizeof(T);

View File

@ -37,7 +37,6 @@ class Expm1CpuKernelMod : public DeprecatedNativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override {
static std::vector<KernelAttr> support_list = {
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
@ -50,7 +49,7 @@ class Expm1CpuKernelMod : public DeprecatedNativeCpuKernelMod {
private:
template <typename T>
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs) const;
TypeId input_dtype_{kTypeUnknown};
};

View File

@ -38,7 +38,6 @@ class EyeCpuKernelMod : public DeprecatedNativeCpuKernelMod {
return kernel_func_(this, inputs, workspace, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -34,7 +34,6 @@ class FillDiagonalCpuKernelMod : public DeprecatedNativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -41,7 +41,6 @@ class FillV2CpuKernelMod : public DeprecatedNativeCpuKernelMod {
template <typename T>
void CalculateDims(const AddressPtr &input, std::vector<int64_t> *dims);
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:

View File

@ -21,7 +21,7 @@ namespace kernel {
constexpr int iv_vec_len = 16;
constexpr int salt_len = 32;
bool ExchangeKeysKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
bool ExchangeKeysKernelMod::Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &) {
MS_LOG(INFO) << "Launching client ExchangeKeysKernelMod";
if (!BuildExchangeKeysReq(fbb_)) {
@ -81,7 +81,7 @@ void ExchangeKeysKernelMod::Init(const CNodePtr &kernel_node) {
MS_LOG(INFO) << "Initialize ExchangeKeys kernel successfully.";
}
void ExchangeKeysKernelMod::InitKernel(const CNodePtr &kernel_node) { return; }
void ExchangeKeysKernelMod::InitKernel(const CNodePtr &) { return; }
bool ExchangeKeysKernelMod::BuildExchangeKeysReq(const std::shared_ptr<fl::FBBuilder> &fbb) {
MS_EXCEPTION_IF_NULL(fbb);
@ -128,7 +128,7 @@ bool ExchangeKeysKernelMod::BuildExchangeKeysReq(const std::shared_ptr<fl::FBBui
return true;
}
std::vector<uint8_t> ExchangeKeysKernelMod::GetPubicKeyBytes() {
std::vector<uint8_t> ExchangeKeysKernelMod::GetPubicKeyBytes() const {
// generate private key of secret
armour::PrivateKey *sPriKeyPtr = armour::KeyAgreement::GeneratePrivKey();
fl::worker::FLWorker::GetInstance().set_secret_pk(sPriKeyPtr);
@ -164,7 +164,7 @@ std::vector<uint8_t> ExchangeKeysKernelMod::GetPubicKeyBytes() {
}
std::vector<KernelAttr> ExchangeKeysKernelMod::GetOpSupport() {
static std::vector<KernelAttr> support_list = {KernelAttr().AddOutputAttr(kNumberTypeFloat32)};
const std::vector<KernelAttr> support_list = {KernelAttr().AddOutputAttr(kNumberTypeFloat32)};
return support_list;
}

View File

@ -32,26 +32,25 @@ class ExchangeKeysKernelMod : public DeprecatedNativeCpuKernelMod {
ExchangeKeysKernelMod() = default;
~ExchangeKeysKernelMod() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
bool Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &) override;
void Init(const CNodePtr &kernel_node) override;
void InitKernel(const CNodePtr &kernel_node) override;
void InitKernel(const CNodePtr &) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
bool BuildExchangeKeysReq(const std::shared_ptr<fl::FBBuilder> &fbb);
std::vector<uint8_t> GetPubicKeyBytes();
std::vector<uint8_t> GetPubicKeyBytes() const;
uint32_t rank_id_;
uint32_t server_num_;
uint32_t target_server_rank_;
uint32_t rank_id_{0};
uint32_t server_num_{0};
uint32_t target_server_rank_{0};
std::string fl_id_;
std::shared_ptr<fl::FBBuilder> fbb_;
armour::PrivateKey *secret_prikey_;
armour::PrivateKey *secret_prikey_{nullptr};
};
} // namespace kernel
} // namespace mindspore

View File

@ -144,11 +144,12 @@ class FusedPullWeightKernelMod : public DeprecatedNativeCpuKernelMod {
init_func_(this, kernel_node);
}
void InitKernel(const CNodePtr &kernel_node) override { return; }
void InitKernel(const CNodePtr &) override { return; }
std::vector<KernelAttr> GetOpSupport() override;
protected:
void InitSizeLists() { return; }
std::vector<KernelAttr> GetOpSupport() override;
void InitSizeLists() const { return; }
private:
template <typename T>

View File

@ -121,11 +121,12 @@ class FusedPushWeightKernelMod : public DeprecatedNativeCpuKernelMod {
init_func_(this, kernel_node);
}
void InitKernel(const CNodePtr &kernel_node) override { return; }
void InitKernel(const CNodePtr &) override { return; }
std::vector<KernelAttr> GetOpSupport() override;
protected:
void InitSizeLists() { return; }
std::vector<KernelAttr> GetOpSupport() override;
private:
bool BuildPushWeightReq(std::shared_ptr<fl::FBBuilder> fbb, const std::vector<AddressPtr> &weights) {

View File

@ -18,7 +18,7 @@
namespace mindspore {
namespace kernel {
bool GetKeysKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
bool GetKeysKernelMod::Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &) {
MS_LOG(INFO) << "Launching client GetKeysKernelMod";
BuildGetKeysReq(fbb_);
@ -80,7 +80,7 @@ void GetKeysKernelMod::Init(const CNodePtr &kernel_node) {
MS_LOG(INFO) << "Initialize GetKeys kernel successfully.";
}
void GetKeysKernelMod::InitKernel(const CNodePtr &kernel_node) { return; }
void GetKeysKernelMod::InitKernel(const CNodePtr &) { return; }
void GetKeysKernelMod::BuildGetKeysReq(const std::shared_ptr<fl::FBBuilder> &fbb) {
MS_EXCEPTION_IF_NULL(fbb);
@ -95,7 +95,7 @@ void GetKeysKernelMod::BuildGetKeysReq(const std::shared_ptr<fl::FBBuilder> &fbb
}
bool GetKeysKernelMod::SavePublicKeyList(
const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::ClientPublicKeys>> *remote_public_key) {
const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::ClientPublicKeys>> *remote_public_key) const {
if (remote_public_key == nullptr) {
MS_LOG(EXCEPTION) << "Input remote_pubic_key is nullptr.";
}

View File

@ -32,27 +32,26 @@ class GetKeysKernelMod : public DeprecatedNativeCpuKernelMod {
GetKeysKernelMod() = default;
~GetKeysKernelMod() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
bool Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &) override;
void Init(const CNodePtr &kernel_node) override;
void InitKernel(const CNodePtr &kernel_node) override;
void InitKernel(const CNodePtr &) override;
protected:
std::vector<KernelAttr> GetOpSupport() override {
static std::vector<KernelAttr> support_list = {KernelAttr().AddOutputAttr(kNumberTypeFloat32)};
const std::vector<KernelAttr> support_list = {KernelAttr().AddOutputAttr(kNumberTypeFloat32)};
return support_list;
}
private:
void BuildGetKeysReq(const std::shared_ptr<fl::FBBuilder> &fbb);
bool SavePublicKeyList(
const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::ClientPublicKeys>> *remote_public_key);
const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::ClientPublicKeys>> *remote_public_key) const;
uint32_t rank_id_;
uint32_t server_num_;
uint32_t target_server_rank_;
uint32_t rank_id_{0};
uint32_t server_num_{0};
uint32_t target_server_rank_{0};
std::string fl_id_;
std::shared_ptr<fl::FBBuilder> fbb_;
};

Some files were not shown because too many files have changed in this diff Show More