forked from mindspore-Ecosystem/mindspore
clean code
This commit is contained in:
parent
4c9b6b06cb
commit
a610b338e1
|
@ -34,7 +34,6 @@ class AdamCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override {
|
||||
static std::vector<KernelAttr> support_list = {KernelAttr().AddSkipCheckAttr(true).AddOutInRef(0, 0)};
|
||||
return support_list;
|
||||
|
|
|
@ -32,7 +32,6 @@ class AdamDeltaCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -36,7 +36,6 @@ class AddcdivCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -36,7 +36,6 @@ class AddcmulCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -33,7 +33,6 @@ class AdjustContrastv2CpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs);
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -44,7 +44,6 @@ class AdjustHueCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -33,11 +33,10 @@ class AdjustSaturationCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs);
|
||||
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
TypeId input_type_{kTypeUnknown};
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -35,7 +35,6 @@ class AllGatherCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -35,7 +35,6 @@ class AllReduceCPUKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
};
|
||||
} // namespace kernel
|
||||
|
|
|
@ -37,7 +37,6 @@ class AngleCpuKernelMod : public NativeCpuKernelMod {
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -44,7 +44,6 @@ class ApplyAdaMaxCpuKernelMod : public NativeCpuKernelMod {
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override {
|
||||
static std::vector<KernelAttr> support_list = {KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
|
|
|
@ -43,7 +43,6 @@ class ApplyAdadeltaCpuKernelMod : public NativeCpuKernelMod {
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override {
|
||||
static std::vector<KernelAttr> support_list = {KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
|
@ -62,6 +61,7 @@ class ApplyAdadeltaCpuKernelMod : public NativeCpuKernelMod {
|
|||
return support_list;
|
||||
}
|
||||
|
||||
protected:
|
||||
int CheckInputShape(const std::vector<KernelTensorPtr> &inputs);
|
||||
int CheckShapeSize(std::vector<int64_t> var_shape, std::vector<int64_t> lr_shape);
|
||||
|
||||
|
|
|
@ -35,7 +35,6 @@ class ApplyAdagradCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -42,7 +42,6 @@ class ApplyAdagradDACpuKernelMod : public NativeCpuKernelMod {
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -46,7 +46,6 @@ class ApplyAdagradV2CpuKernelMod : public NativeCpuKernelMod {
|
|||
return kernel_func_(this, inputs, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -40,7 +40,6 @@ class ApplyAdamWithAmsgradCpuKernelMod : public NativeCpuKernelMod {
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -39,7 +39,6 @@ class ApplyFtrlCpuKernelMod : public NativeCpuKernelMod {
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -38,7 +38,6 @@ class ApplyGradientDescentCpuKernelMod : public NativeCpuKernelMod {
|
|||
template <typename T>
|
||||
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override {
|
||||
static std::vector<KernelAttr> support_list = {KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
|
|
|
@ -35,7 +35,6 @@ class BACKEND_EXPORT ApplyMomentumCpuKernelMod : public DeprecatedNativeCpuKerne
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
};
|
||||
} // namespace kernel
|
||||
|
|
|
@ -43,7 +43,6 @@ class ApplyProximalAdagradCpuKernelMod : public NativeCpuKernelMod {
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override {
|
||||
static std::vector<KernelAttr> support_list = {KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
|
|
|
@ -38,7 +38,6 @@ class ApproximateEqualCpuKernelMod : public NativeCpuKernelMod {
|
|||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -65,7 +65,7 @@ bool ArgmaxCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inp
|
|||
for (int64_t j = 0; j < num_after_axis_; j++) {
|
||||
int64_t src_index_j = src_index_i + j;
|
||||
for (int64_t k = 0; k < dim_axis_; k++) {
|
||||
int64_t src_index_k = k * num_after_axis_ + src_index_j;
|
||||
auto src_index_k = LongToSize(k * num_after_axis_ + src_index_j);
|
||||
array_axis[k] = static_cast<float>(input[src_index_k]);
|
||||
}
|
||||
auto max_ops = std::max_element(array_axis.begin(), array_axis.end());
|
||||
|
|
|
@ -42,7 +42,6 @@ class ArgminCpuKernelMod : public NativeCpuKernelMod {
|
|||
return kernel_func_(this, inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -37,11 +37,22 @@ namespace {
|
|||
constexpr float kMaxSubSerialSize = 10000.0;
|
||||
constexpr float kMaxPowSerialSize = 700.0;
|
||||
|
||||
constexpr auto kAdd = "Add";
|
||||
constexpr auto kAddV2 = "AddV2";
|
||||
constexpr auto kSub = "Sub";
|
||||
constexpr auto kMul = "Mul";
|
||||
constexpr auto kRealDiv = "RealDiv";
|
||||
constexpr auto kAssignAdd = "AssignAdd";
|
||||
constexpr auto kAssignSub = "AssignSub";
|
||||
constexpr auto kDiv = "Div";
|
||||
constexpr auto kDivNoNan = "DivNoNan";
|
||||
constexpr auto kPow = "Pow";
|
||||
constexpr auto kFloorDiv = "FloorDiv";
|
||||
constexpr auto kMod = "Mod";
|
||||
constexpr auto kFloorMod = "FloorMod";
|
||||
constexpr auto kSquaredDifference = "SquaredDifference";
|
||||
constexpr auto kXlogy = "Xlogy";
|
||||
constexpr auto kAtan2 = "Atan2";
|
||||
|
||||
template <typename T>
|
||||
void ElementRealDiv(const T *input1, const T *input2, T *out, size_t size, size_t delta_1, size_t delta_2) {
|
||||
|
@ -157,9 +168,9 @@ class ArithmeticCpuTypeFunc : public DeprecatedCpuKernelFunc {
|
|||
MS_LOG(WARNING) << kernel_name_ << " output shape contain 0, output_shape: " << output_shape_;
|
||||
return true;
|
||||
}
|
||||
if (kernel_name_ == prim::kPrimAssignAdd->name()) {
|
||||
if (kernel_name_ == kAssignAdd) {
|
||||
AssignAdd(input1, input2, output);
|
||||
} else if (kernel_name_ == prim::kPrimAssignSub->name()) {
|
||||
} else if (kernel_name_ == kAssignSub) {
|
||||
AssignSub(input1, input2, output);
|
||||
} else {
|
||||
compute_func_(this, input1, input2, output);
|
||||
|
@ -169,39 +180,38 @@ class ArithmeticCpuTypeFunc : public DeprecatedCpuKernelFunc {
|
|||
|
||||
private:
|
||||
void InitComputeFunc() {
|
||||
if (kernel_name_ == prim::kPrimAssignAdd->name() || kernel_name_ == prim::kPrimAssignSub->name()) {
|
||||
if (kernel_name_ == kAssignAdd || kernel_name_ == kAssignSub) {
|
||||
return;
|
||||
}
|
||||
string dtype_desc;
|
||||
static std::unordered_map<std::string, TypeComputeFunc> arithmeticMathFuncMap;
|
||||
if constexpr (!((std::is_same_v<T, complex64>) || (std::is_same_v<T, complex128>))) {
|
||||
dtype_desc = "real data";
|
||||
arithmeticMathFuncMap = {{prim::kPrimAdd->name(), &ArithmeticCpuTypeFunc<T>::Add},
|
||||
{prim::kPrimAddV2->name(), &ArithmeticCpuTypeFunc<T>::AddV2},
|
||||
{prim::kPrimSub->name(), &ArithmeticCpuTypeFunc<T>::Sub},
|
||||
{prim::kPrimMul->name(), &ArithmeticCpuTypeFunc<T>::Mul},
|
||||
{prim::kPrimDiv->name(), &ArithmeticCpuTypeFunc<T>::Div},
|
||||
{prim::kPrimDivNoNan->name(), &ArithmeticCpuTypeFunc<T>::DivNoNan},
|
||||
{prim::kPrimMod->name(), &ArithmeticCpuTypeFunc<T>::Mod},
|
||||
{prim::kPrimFloorMod->name(), &ArithmeticCpuTypeFunc<T>::FloorMod},
|
||||
{prim::kPrimPow->name(), &ArithmeticCpuTypeFunc<T>::Pow},
|
||||
{prim::kPrimFloorDiv->name(), &ArithmeticCpuTypeFunc<T>::FloorDiv},
|
||||
{prim::kPrimAtan2->name(), &ArithmeticCpuTypeFunc<T>::Atan2},
|
||||
{prim::kPrimRealDiv->name(), &ArithmeticCpuTypeFunc<T>::RealDiv},
|
||||
{prim::kPrimSquaredDifference->name(), &ArithmeticCpuTypeFunc<T>::SquaredDifference},
|
||||
{prim::kPrimXlogy->name(), &ArithmeticCpuTypeFunc<T>::Xlogy}};
|
||||
arithmeticMathFuncMap = {{kAdd, &ArithmeticCpuTypeFunc<T>::Add},
|
||||
{kAddV2, &ArithmeticCpuTypeFunc<T>::AddV2},
|
||||
{kSub, &ArithmeticCpuTypeFunc<T>::Sub},
|
||||
{kMul, &ArithmeticCpuTypeFunc<T>::Mul},
|
||||
{kDiv, &ArithmeticCpuTypeFunc<T>::Div},
|
||||
{kDivNoNan, &ArithmeticCpuTypeFunc<T>::DivNoNan},
|
||||
{kMod, &ArithmeticCpuTypeFunc<T>::Mod},
|
||||
{kFloorMod, &ArithmeticCpuTypeFunc<T>::FloorMod},
|
||||
{kPow, &ArithmeticCpuTypeFunc<T>::Pow},
|
||||
{kFloorDiv, &ArithmeticCpuTypeFunc<T>::FloorDiv},
|
||||
{kAtan2, &ArithmeticCpuTypeFunc<T>::Atan2},
|
||||
{kRealDiv, &ArithmeticCpuTypeFunc<T>::RealDiv},
|
||||
{kSquaredDifference, &ArithmeticCpuTypeFunc<T>::SquaredDifference},
|
||||
{kXlogy, &ArithmeticCpuTypeFunc<T>::Xlogy}};
|
||||
} else {
|
||||
dtype_desc = "complex data";
|
||||
arithmeticMathFuncMap = {
|
||||
{prim::kPrimSquaredDifference->name(), &ArithmeticCpuTypeFunc<T>::SquaredDifferenceComplex},
|
||||
{prim::kPrimSub->name(), &ArithmeticCpuTypeFunc<T>::Sub},
|
||||
{prim::kPrimDiv->name(), &ArithmeticCpuTypeFunc<T>::DivComplex},
|
||||
{prim::kPrimRealDiv->name(), &ArithmeticCpuTypeFunc<T>::RealDivComplex},
|
||||
{prim::kPrimMul->name(), &ArithmeticCpuTypeFunc<T>::Mul},
|
||||
{prim::kPrimDivNoNan->name(), &ArithmeticCpuTypeFunc<T>::DivNoNan},
|
||||
{prim::kPrimAddV2->name(), &ArithmeticCpuTypeFunc<T>::AddV2},
|
||||
{prim::kPrimPow->name(), &ArithmeticCpuTypeFunc<T>::PowComplex},
|
||||
{prim::kPrimXlogy->name(), &ArithmeticCpuTypeFunc<T>::Xlogy}};
|
||||
arithmeticMathFuncMap = {{kSquaredDifference, &ArithmeticCpuTypeFunc<T>::SquaredDifferenceComplex},
|
||||
{kSub, &ArithmeticCpuTypeFunc<T>::Sub},
|
||||
{kDiv, &ArithmeticCpuTypeFunc<T>::DivComplex},
|
||||
{kRealDiv, &ArithmeticCpuTypeFunc<T>::RealDivComplex},
|
||||
{kMul, &ArithmeticCpuTypeFunc<T>::Mul},
|
||||
{kDivNoNan, &ArithmeticCpuTypeFunc<T>::DivNoNan},
|
||||
{kAddV2, &ArithmeticCpuTypeFunc<T>::AddV2},
|
||||
{kPow, &ArithmeticCpuTypeFunc<T>::PowComplex},
|
||||
{kXlogy, &ArithmeticCpuTypeFunc<T>::Xlogy}};
|
||||
}
|
||||
if (arithmeticMathFuncMap.find(kernel_name_) == arithmeticMathFuncMap.end()) {
|
||||
MS_LOG(EXCEPTION) << "For 'Arithmetic', only supports operators in " << Map2Str(arithmeticMathFuncMap)
|
||||
|
@ -810,7 +820,7 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, ArithmeticCpuFunc
|
|||
SpecializeArithFunc<complex128>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithFunc<bool>}}},
|
||||
{prim::kPrimDiv->name(),
|
||||
{kDiv,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
||||
SpecializeArithFunc<int8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
|
||||
|
@ -839,7 +849,7 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, ArithmeticCpuFunc
|
|||
.AddInputAttr(kNumberTypeComplex128)
|
||||
.AddOutputAttr(kNumberTypeComplex128),
|
||||
SpecializeArithFunc<complex128>}}},
|
||||
{prim::kPrimDivNoNan->name(),
|
||||
{kDivNoNan,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
SpecializeArithFunc<float16>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
|
@ -856,7 +866,7 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, ArithmeticCpuFunc
|
|||
.AddInputAttr(kNumberTypeComplex128)
|
||||
.AddOutputAttr(kNumberTypeComplex128),
|
||||
SpecializeArithFunc<complex128>}}},
|
||||
{prim::kPrimPow->name(),
|
||||
{kPow,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
SpecializeArithFunc<int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
|
@ -910,7 +920,7 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, ArithmeticCpuFunc
|
|||
.AddInputAttr(kNumberTypeComplex128)
|
||||
.AddOutputAttr(kNumberTypeComplex128),
|
||||
SpecializeArithFunc<complex128>}}},
|
||||
{prim::kPrimFloorDiv->name(),
|
||||
{kFloorDiv,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
||||
SpecializeArithFunc<int8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
|
@ -927,14 +937,14 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, ArithmeticCpuFunc
|
|||
SpecializeArithFunc<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
SpecializeArithFunc<double>}}},
|
||||
{prim::kPrimMod->name(),
|
||||
{kMod,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
SpecializeArithFunc<int>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
SpecializeArithFunc<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
SpecializeArithFunc<int64_t>}}},
|
||||
{prim::kPrimFloorMod->name(),
|
||||
{kFloorMod,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
SpecializeArithFunc<int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
|
@ -1013,7 +1023,7 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, ArithmeticCpuFunc
|
|||
.AddOutputAttr(kNumberTypeFloat64)
|
||||
.AddOutInRef(0, 0),
|
||||
SpecializeArithFunc<double>}}},
|
||||
{prim::kPrimSquaredDifference->name(),
|
||||
{kSquaredDifference,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
SpecializeArithFunc<int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
|
@ -1034,7 +1044,7 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, ArithmeticCpuFunc
|
|||
.AddInputAttr(kNumberTypeComplex128)
|
||||
.AddOutputAttr(kNumberTypeComplex128),
|
||||
SpecializeArithFunc<complex128>}}},
|
||||
{prim::kPrimXlogy->name(),
|
||||
{kXlogy,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
SpecializeArithFunc<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
|
@ -1051,12 +1061,12 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, ArithmeticCpuFunc
|
|||
.AddInputAttr(kNumberTypeComplex128)
|
||||
.AddOutputAttr(kNumberTypeComplex128),
|
||||
SpecializeArithFunc<complex128>}}},
|
||||
{prim::kPrimAtan2->name(),
|
||||
{kAtan2,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
SpecializeArithFunc<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
SpecializeArithFunc<double>}}},
|
||||
{prim::kPrimAddV2->name(),
|
||||
{kAddV2,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
||||
SpecializeArithFunc<int8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
|
||||
|
@ -1118,35 +1128,31 @@ MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Sub,
|
|||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Mul,
|
||||
[]() { return std::make_shared<ArithmeticCpuKernelMod>(kMul); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Div,
|
||||
[]() { return std::make_shared<ArithmeticCpuKernelMod>(prim::kPrimDiv->name()); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, DivNoNan, []() {
|
||||
return std::make_shared<ArithmeticCpuKernelMod>(prim::kPrimDivNoNan->name());
|
||||
});
|
||||
[]() { return std::make_shared<ArithmeticCpuKernelMod>(kDiv); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, DivNoNan,
|
||||
[]() { return std::make_shared<ArithmeticCpuKernelMod>(kDivNoNan); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Pow,
|
||||
[]() { return std::make_shared<ArithmeticCpuKernelMod>(prim::kPrimPow->name()); });
|
||||
[]() { return std::make_shared<ArithmeticCpuKernelMod>(kPow); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, RealDiv,
|
||||
[]() { return std::make_shared<ArithmeticCpuKernelMod>(kRealDiv); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, FloorDiv, []() {
|
||||
return std::make_shared<ArithmeticCpuKernelMod>(prim::kPrimFloorDiv->name());
|
||||
});
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, FloorDiv,
|
||||
[]() { return std::make_shared<ArithmeticCpuKernelMod>(kFloorDiv); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Mod,
|
||||
[]() { return std::make_shared<ArithmeticCpuKernelMod>(prim::kPrimMod->name()); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, FloorMod, []() {
|
||||
return std::make_shared<ArithmeticCpuKernelMod>(prim::kPrimFloorMod->name());
|
||||
});
|
||||
[]() { return std::make_shared<ArithmeticCpuKernelMod>(kMod); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, FloorMod,
|
||||
[]() { return std::make_shared<ArithmeticCpuKernelMod>(kFloorMod); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, AssignAdd,
|
||||
[]() { return std::make_shared<ArithmeticCpuKernelMod>(kAssignAdd); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, AssignSub,
|
||||
[]() { return std::make_shared<ArithmeticCpuKernelMod>(kAssignSub); });
|
||||
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, SquaredDifference, []() {
|
||||
return std::make_shared<ArithmeticCpuKernelMod>(prim::kPrimSquaredDifference->name());
|
||||
});
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, SquaredDifference,
|
||||
[]() { return std::make_shared<ArithmeticCpuKernelMod>(kSquaredDifference); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Xlogy,
|
||||
[]() { return std::make_shared<ArithmeticCpuKernelMod>(prim::kPrimXlogy->name()); });
|
||||
[]() { return std::make_shared<ArithmeticCpuKernelMod>(kXlogy); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Atan2,
|
||||
[]() { return std::make_shared<ArithmeticCpuKernelMod>(prim::kPrimAtan2->name()); });
|
||||
[]() { return std::make_shared<ArithmeticCpuKernelMod>(kAtan2); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, AddV2,
|
||||
[]() { return std::make_shared<ArithmeticCpuKernelMod>(prim::kPrimAddV2->name()); });
|
||||
[]() { return std::make_shared<ArithmeticCpuKernelMod>(kAddV2); });
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -47,7 +47,6 @@ class ArithmeticCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
return func_obj_->RunFunc(inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -34,6 +34,13 @@ using complex64 = std::complex<float>;
|
|||
using complex128 = std::complex<double>;
|
||||
|
||||
constexpr size_t kMaxLessSerialSize = 15000;
|
||||
constexpr auto kLess = "Less";
|
||||
constexpr auto kLessEqual = "LessEqual";
|
||||
constexpr auto kGreater = "Greater";
|
||||
constexpr auto kGreaterEqual = "GreaterEqual";
|
||||
constexpr auto kLogicalAnd = "LogicalAnd";
|
||||
constexpr auto kLogicalOr = "LogicalOr";
|
||||
constexpr auto kLogicalXor = "LogicalXor";
|
||||
constexpr auto kEqual = "Equal";
|
||||
constexpr auto kNotEqual = "NotEqual";
|
||||
|
||||
|
@ -85,16 +92,16 @@ class ArithLogicCpuTypeFunc : public DeprecatedCpuKernelFunc {
|
|||
}
|
||||
static std::unordered_map<std::string, TypeComputeFunc> arithmetic_logic_func_map;
|
||||
if constexpr (!((std::is_same_v<T, complex64>) || (std::is_same_v<T, complex128>))) {
|
||||
arithmetic_logic_func_map = {{prim::kPrimGreater->name(), &ArithLogicCpuTypeFunc<T>::Greater},
|
||||
{prim::kPrimGreaterEqual->name(), &ArithLogicCpuTypeFunc<T>::GreaterEqual},
|
||||
{prim::kPrimLogicalAnd->name(), &ArithLogicCpuTypeFunc<T>::LogicalAnd},
|
||||
{prim::kPrimLessEqual->name(), &ArithLogicCpuTypeFunc<T>::LessEqual},
|
||||
{prim::kPrimLogicalOr->name(), &ArithLogicCpuTypeFunc<T>::LogicalOr},
|
||||
{prim::kPrimLogicalXor->name(), &ArithLogicCpuTypeFunc<T>::LogicalXor},
|
||||
{prim::kPrimLess->name(), &ArithLogicCpuTypeFunc<T>::Less},
|
||||
{prim::kPrimNotEqual->name(), &ArithLogicCpuTypeFunc<T>::NotEqual}};
|
||||
arithmetic_logic_func_map = {{kGreater, &ArithLogicCpuTypeFunc<T>::Greater},
|
||||
{kGreaterEqual, &ArithLogicCpuTypeFunc<T>::GreaterEqual},
|
||||
{kLogicalAnd, &ArithLogicCpuTypeFunc<T>::LogicalAnd},
|
||||
{kLessEqual, &ArithLogicCpuTypeFunc<T>::LessEqual},
|
||||
{kLogicalOr, &ArithLogicCpuTypeFunc<T>::LogicalOr},
|
||||
{kLogicalXor, &ArithLogicCpuTypeFunc<T>::LogicalXor},
|
||||
{kLess, &ArithLogicCpuTypeFunc<T>::Less},
|
||||
{kNotEqual, &ArithLogicCpuTypeFunc<T>::NotEqual}};
|
||||
} else {
|
||||
arithmetic_logic_func_map = {{prim::kPrimNotEqual->name(), &ArithLogicCpuTypeFunc<T>::NotEqual}};
|
||||
arithmetic_logic_func_map = {{kNotEqual, &ArithLogicCpuTypeFunc<T>::NotEqual}};
|
||||
}
|
||||
if (arithmetic_logic_func_map.find(kernel_name_) == arithmetic_logic_func_map.end()) {
|
||||
MS_LOG(EXCEPTION) << "For 'ArithmeticLogic', only supports operators in " << Map2Str(arithmetic_logic_func_map)
|
||||
|
@ -114,7 +121,7 @@ class ArithLogicCpuTypeFunc : public DeprecatedCpuKernelFunc {
|
|||
void LessEqual(const T *input1, const T *input2, bool *out);
|
||||
void LogicalAnd(const T *input1, const T *input2, bool *out);
|
||||
void LogicalOr(const T *input1, const T *input2, bool *out);
|
||||
void LogicalXor(const T *input1, const T *input2, bool *out);
|
||||
void LogicalXor(const T *input1, const T *input2, bool *out) const;
|
||||
|
||||
using TypeComputeFunc = std::function<void(ArithLogicCpuTypeFunc *, const T *, const T *, bool *)>;
|
||||
TypeComputeFunc compute_func_{nullptr};
|
||||
|
@ -178,7 +185,7 @@ class ArithComplexLogicCpuTypeFunc : public DeprecatedCpuKernelFunc {
|
|||
<< dtype_ << ", and the type of 'input2': " << dtype_1;
|
||||
}
|
||||
static const std::unordered_map<std::string, ComplexTypeComputeFunc> arithmetic_logic_func_map{
|
||||
{prim::kPrimEqual->name(), &ArithComplexLogicCpuTypeFunc<T>::Equal}};
|
||||
{kEqual, &ArithComplexLogicCpuTypeFunc<T>::Equal}};
|
||||
if (arithmetic_logic_func_map.find(kernel_name_) == arithmetic_logic_func_map.end()) {
|
||||
MS_LOG(EXCEPTION) << "For 'ArithmeticLogic', only supports operators in " << Map2Str(arithmetic_logic_func_map)
|
||||
<< ", but got " << kernel_name_;
|
||||
|
@ -329,7 +336,7 @@ void ArithLogicCpuTypeFunc<T>::LogicalOr(const T *input1, const T *input2, bool
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void ArithLogicCpuTypeFunc<T>::LogicalXor(const T *input1, const T *input2, bool *out) {
|
||||
void ArithLogicCpuTypeFunc<T>::LogicalXor(const T *input1, const T *input2, bool *out) const {
|
||||
BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_);
|
||||
auto task = [input1, input2, out, &base_iter](size_t start, size_t end) {
|
||||
auto iter = base_iter;
|
||||
|
@ -367,7 +374,7 @@ std::shared_ptr<DeprecatedCpuKernelFunc> SpecializeArithLogComplexFunc() {
|
|||
}
|
||||
using ArithLogicCpuFuncCreator = std::function<std::shared_ptr<DeprecatedCpuKernelFunc>()>;
|
||||
static std::map<std::string, std::vector<std::pair<KernelAttr, ArithLogicCpuFuncCreator>>> kernel_attr_lists = {
|
||||
{prim::kPrimLess->name(),
|
||||
{kLess,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<int>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
|
||||
|
@ -438,7 +445,7 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, ArithLogicCpuFunc
|
|||
.AddInputAttr(kNumberTypeComplex128)
|
||||
.AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<complex128>}}},
|
||||
{prim::kPrimGreater->name(),
|
||||
{kGreater,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<int>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
|
||||
|
@ -447,7 +454,7 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, ArithLogicCpuFunc
|
|||
SpecializeArithLogFunc<double>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<int64_t>}}},
|
||||
{prim::kPrimGreaterEqual->name(),
|
||||
{kGreaterEqual,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<int>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool),
|
||||
|
@ -462,7 +469,7 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, ArithLogicCpuFunc
|
|||
SpecializeArithLogFunc<int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<double>}}},
|
||||
{prim::kPrimLessEqual->name(),
|
||||
{kLessEqual,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<int>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
|
||||
|
@ -471,13 +478,13 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, ArithLogicCpuFunc
|
|||
SpecializeArithLogFunc<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<double>}}},
|
||||
{prim::kPrimLogicalAnd->name(),
|
||||
{kLogicalAnd,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<bool>}}},
|
||||
{prim::kPrimLogicalOr->name(),
|
||||
{kLogicalOr,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<bool>}}},
|
||||
{prim::kPrimLogicalXor->name(),
|
||||
{kLogicalXor,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<bool>}}}};
|
||||
} // namespace
|
||||
|
@ -540,30 +547,23 @@ std::vector<KernelAttr> ArithmeticComplexLogicCpuKernelMod::GetOpSupport() {
|
|||
return support_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Less, []() {
|
||||
return std::make_shared<ArithmeticLogicCpuKernelMod>(prim::kPrimLess->name());
|
||||
});
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Less,
|
||||
[]() { return std::make_shared<ArithmeticLogicCpuKernelMod>(kLess); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Equal,
|
||||
[]() { return std::make_shared<ArithmeticComplexLogicCpuKernelMod>(kEqual); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, NotEqual,
|
||||
[]() { return std::make_shared<ArithmeticLogicCpuKernelMod>(kNotEqual); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Greater, []() {
|
||||
return std::make_shared<ArithmeticLogicCpuKernelMod>(prim::kPrimGreater->name());
|
||||
});
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, GreaterEqual, []() {
|
||||
return std::make_shared<ArithmeticLogicCpuKernelMod>(prim::kPrimGreaterEqual->name());
|
||||
});
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, LessEqual, []() {
|
||||
return std::make_shared<ArithmeticLogicCpuKernelMod>(prim::kPrimLessEqual->name());
|
||||
});
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, LogicalAnd, []() {
|
||||
return std::make_shared<ArithmeticLogicCpuKernelMod>(prim::kPrimLogicalAnd->name());
|
||||
});
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, LogicalOr, []() {
|
||||
return std::make_shared<ArithmeticLogicCpuKernelMod>(prim::kPrimLogicalOr->name());
|
||||
});
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, LogicalXor, []() {
|
||||
return std::make_shared<ArithmeticLogicCpuKernelMod>(prim::kPrimLogicalXor->name());
|
||||
});
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Greater,
|
||||
[]() { return std::make_shared<ArithmeticLogicCpuKernelMod>(kGreater); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, GreaterEqual,
|
||||
[]() { return std::make_shared<ArithmeticLogicCpuKernelMod>(kGreaterEqual); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, LessEqual,
|
||||
[]() { return std::make_shared<ArithmeticLogicCpuKernelMod>(kLessEqual); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, LogicalAnd,
|
||||
[]() { return std::make_shared<ArithmeticLogicCpuKernelMod>(kLogicalAnd); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, LogicalOr,
|
||||
[]() { return std::make_shared<ArithmeticLogicCpuKernelMod>(kLogicalOr); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, LogicalXor,
|
||||
[]() { return std::make_shared<ArithmeticLogicCpuKernelMod>(kLogicalXor); });
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -44,13 +44,13 @@ class ArithmeticLogicCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
return func_obj_->RunFunc(inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
std::shared_ptr<DeprecatedCpuKernelFunc> func_obj_;
|
||||
std::string kernel_type_{"Unknown"};
|
||||
};
|
||||
|
||||
class ArithmeticComplexLogicCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
||||
public:
|
||||
ArithmeticComplexLogicCpuKernelMod() = default;
|
||||
|
@ -68,7 +68,6 @@ class ArithmeticComplexLogicCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
return func_obj_->RunFunc(inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -115,7 +115,7 @@ void Sign(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void Neg(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
|
||||
void Neg(ArithmeticSelfCpuKernelFunc *, const T *in, T *out, size_t size) {
|
||||
auto task = [&in, &out](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
out[i] = -in[i];
|
||||
|
@ -511,7 +511,7 @@ void Rsqrt(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t siz
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void Softsign(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
|
||||
void Softsign(ArithmeticSelfCpuKernelFunc *, const T *in, T *out, size_t size) {
|
||||
if constexpr ((std::is_same_v<T, uint8_t>) || (std::is_same_v<T, uint16_t>) || (std::is_same_v<T, uint32_t>) ||
|
||||
(std::is_same_v<T, uint64_t>)) {
|
||||
MS_LOG(EXCEPTION) << "'Softsign' cannot be instantiated.";
|
||||
|
|
|
@ -48,7 +48,6 @@ class ArithmeticSelfCpuKernelMod : public NativeCpuKernelMod {
|
|||
return func_obj_->RunFunc(inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
@ -76,7 +75,6 @@ class IdentityCpuKernelMod : public NativeCpuKernelMod {
|
|||
return kernel_func_(inputs, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -36,7 +36,6 @@ class AssignCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -48,7 +48,6 @@ class BatchToSpaceNDCpuKernelMod : public NativeCpuKernelMod, public MatchKernel
|
|||
|
||||
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -43,7 +43,6 @@ class BesselI0CpuKernelMod : public NativeCpuKernelMod {
|
|||
template <typename T>
|
||||
static void BesselI0Func(const T *input, T *output, size_t start, size_t end);
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
@ -78,7 +77,6 @@ class BesselI0eCpuKernelMod : public NativeCpuKernelMod {
|
|||
template <typename T>
|
||||
static void BesselI0eFunc(const T *input, T *output, size_t start, size_t end);
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -44,7 +44,6 @@ class BesselI1CpuKernelMod : public NativeCpuKernelMod {
|
|||
template <typename T>
|
||||
static void BesselI1Func(const T *input, T *output, size_t start, size_t end);
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
@ -78,7 +77,6 @@ class BesselI1eCpuKernelMod : public NativeCpuKernelMod {
|
|||
template <typename T>
|
||||
static void BesselI1eFunc(const T *input, T *output, size_t start, size_t end);
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -46,7 +46,6 @@ class BesselJ0CpuKernelMod : public NativeCpuKernelMod {
|
|||
template <typename T>
|
||||
static void BesselJ0Func(const T *input, T *output, size_t start, size_t end);
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -46,7 +46,6 @@ class BesselJ1CpuKernelMod : public NativeCpuKernelMod {
|
|||
template <typename T>
|
||||
static void BesselJ1Func(const T *input, T *output, size_t start, size_t end);
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -43,7 +43,6 @@ class BesselK0CpuKernelMod : public NativeCpuKernelMod {
|
|||
template <typename T>
|
||||
static void BesselK0Func(const T *input, T *output, size_t start, size_t end);
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
@ -77,7 +76,6 @@ class BesselK0eCpuKernelMod : public NativeCpuKernelMod {
|
|||
template <typename T>
|
||||
static void BesselK0eFunc(const T *input, T *output, size_t start, size_t end);
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -43,7 +43,6 @@ class BesselK1CpuKernelMod : public NativeCpuKernelMod {
|
|||
template <typename T>
|
||||
static void BesselK1Func(const T *input, T *output, size_t start, size_t end);
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
@ -77,7 +76,6 @@ class BesselK1eCpuKernelMod : public NativeCpuKernelMod {
|
|||
template <typename T>
|
||||
static void BesselK1eFunc(const T *input, T *output, size_t start, size_t end);
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -45,7 +45,6 @@ class BesselY0CpuKernelMod : public NativeCpuKernelMod {
|
|||
template <typename T>
|
||||
static void BesselY0Func(const T *input, T *output, size_t start, size_t end);
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -45,7 +45,6 @@ class BesselY1CpuKernelMod : public NativeCpuKernelMod {
|
|||
template <typename T>
|
||||
static void BesselY1Func(const T *input, T *output, size_t start, size_t end);
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -36,7 +36,6 @@ class BinaryCrossEntropyCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -35,7 +35,6 @@ class BinaryCrossEntropyGradCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -35,7 +35,6 @@ class BlackmanWindowCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
return kernel_func_(this, inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -40,7 +40,6 @@ class BoundingBoxDecodeCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
return kernel_func_(this, inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -39,7 +39,6 @@ class BoundingBoxEncodeCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
return kernel_func_(this, inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -47,7 +47,6 @@ class BroadcastToCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
|
||||
void CheckArgs();
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -35,7 +35,6 @@ class BucketizeCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
template <typename T>
|
||||
bool BucketizeCompute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -37,7 +37,6 @@ class CheckNumericsCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override {
|
||||
static std::vector<KernelAttr> support_list = {
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
|
|
|
@ -38,7 +38,6 @@ class CheckValidCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
return kernel_func_(this, inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -36,7 +36,6 @@ class CholeskyInverseCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
return kernel_func_(inputs, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -39,7 +39,6 @@ class CoalesceCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -40,7 +40,6 @@ class ComplexCpuKernelMod : public NativeCpuKernelMod {
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -42,7 +42,6 @@ class ConcatCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
return kernel_func_(this, inputs, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -52,7 +52,7 @@ void ConcatOffsetCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
|||
kernel_func_ = func_list_[index].second;
|
||||
}
|
||||
template <typename T>
|
||||
bool ConcatOffsetCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
|
||||
bool ConcatOffsetCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kConcatOffsetOutputNum, kernel_name_);
|
||||
auto node_ = cnode_ptr_.lock();
|
||||
|
|
|
@ -37,12 +37,11 @@ class ConcatOffsetCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
return kernel_func_(this, inputs, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
|
||||
bool LaunchKernel(const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &outputs);
|
||||
using ConcatOffsetFunc = std::function<bool(ConcatOffsetCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &)>;
|
||||
static std::vector<std::pair<KernelAttr, ConcatOffsetFunc>> func_list_;
|
||||
|
|
|
@ -48,7 +48,7 @@ std::vector<KernelAttr> NativeCpuKernelMod::GetAllSupportedList(const std::strin
|
|||
return support_map_[kernel_name];
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> NativeCpuKernelMod::GetSupportFromOpLib(const std::string &kernel_name) {
|
||||
std::vector<KernelAttr> NativeCpuKernelMod::GetSupportFromOpLib(const std::string &kernel_name) const {
|
||||
static std::set<std::string> same_op_name = {"Concat", "Pack", "Stack", "Split", "Transpose",
|
||||
"Unpack", "AddN", "ConcatOffset", "DynamicStitch"};
|
||||
std::vector<KernelAttr> support_kernel_attrs;
|
||||
|
@ -134,7 +134,7 @@ void DeprecatedNativeCpuKernelMod::Init(const CNodePtr &kernel_node) {
|
|||
InitInputOutputSize(kernel_node);
|
||||
}
|
||||
|
||||
std::vector<TypeId> DeprecatedNativeCpuKernelMod::GetInputDtypes(const CNodePtr &kernel_node) {
|
||||
std::vector<TypeId> DeprecatedNativeCpuKernelMod::GetInputDtypes(const CNodePtr &kernel_node) const {
|
||||
std::vector<TypeId> input_types;
|
||||
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
for (size_t input_index = 0; input_index < input_num; ++input_index) {
|
||||
|
@ -144,7 +144,7 @@ std::vector<TypeId> DeprecatedNativeCpuKernelMod::GetInputDtypes(const CNodePtr
|
|||
return input_types;
|
||||
}
|
||||
|
||||
std::vector<TypeId> DeprecatedNativeCpuKernelMod::GetOutputDtypes(const CNodePtr &kernel_node) {
|
||||
std::vector<TypeId> DeprecatedNativeCpuKernelMod::GetOutputDtypes(const CNodePtr &kernel_node) const {
|
||||
std::vector<TypeId> output_types;
|
||||
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
for (size_t output_index = 0; output_index < output_num; ++output_index) {
|
||||
|
|
|
@ -167,7 +167,7 @@ class BACKEND_EXPORT NativeCpuKernelMod : public CpuKernelMod {
|
|||
|
||||
private:
|
||||
std::vector<KernelAttr> GetAllSupportedList(const std::string &kernel_name);
|
||||
std::vector<KernelAttr> GetSupportFromOpLib(const std::string &kernel_name);
|
||||
std::vector<KernelAttr> GetSupportFromOpLib(const std::string &kernel_name) const;
|
||||
static mindspore::HashMap<std::string, std::vector<KernelAttr>> support_map_;
|
||||
};
|
||||
|
||||
|
@ -208,8 +208,8 @@ class BACKEND_EXPORT DeprecatedNativeCpuKernelMod : public NativeCpuKernelMod {
|
|||
}
|
||||
|
||||
private:
|
||||
std::vector<TypeId> GetInputDtypes(const CNodePtr &kernel_node);
|
||||
std::vector<TypeId> GetOutputDtypes(const CNodePtr &kernel_node);
|
||||
std::vector<TypeId> GetInputDtypes(const CNodePtr &kernel_node) const;
|
||||
std::vector<TypeId> GetOutputDtypes(const CNodePtr &kernel_node) const;
|
||||
};
|
||||
|
||||
class DeprecatedCpuKernelFunc {
|
||||
|
|
|
@ -53,7 +53,6 @@ class CropAndResizeCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
return kernel_func_(this, inputs, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -65,7 +65,6 @@ class CropAndResizeGradBoxesCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
|
||||
void OutputZeroing(const std::vector<AddressPtr> &outputs);
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -36,7 +36,6 @@ class CrossCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -48,7 +48,6 @@ class CTCGreedyDecoderCpuKernelMod : public NativeCpuKernelMod, public MatchKern
|
|||
|
||||
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
|
||||
|
||||
private:
|
||||
|
|
|
@ -38,7 +38,6 @@ class CTCLossCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -48,7 +48,6 @@ class DataFormatDimMapCpuKernelMod : public NativeCpuKernelMod, public MatchKern
|
|||
|
||||
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
|
||||
|
||||
private:
|
||||
|
|
|
@ -39,7 +39,6 @@ class DataFormatVecPermuteCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
return kernel_func_(this, inputs, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -19,7 +19,6 @@
|
|||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
@ -36,7 +35,6 @@ class DebugCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
};
|
||||
} // namespace kernel
|
||||
|
|
|
@ -67,7 +67,6 @@ class DeformableOffsetsGradCpuKernelMod : public NativeCpuKernelMod,
|
|||
|
||||
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::OpSupport(); }
|
||||
|
||||
private:
|
||||
|
|
|
@ -38,7 +38,6 @@ class DenseToCSRSparseMatrixCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -38,7 +38,6 @@ class DenseToDenseSetOperationCpuKernelMod : public DeprecatedNativeCpuKernelMod
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -36,7 +36,6 @@ class DepthToSpaceCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
return kernel_func_(this, inputs, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -58,7 +58,7 @@ bool DynamicAssignCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &in
|
|||
|
||||
template <typename T>
|
||||
void DynamicAssignCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
const std::vector<kernel::AddressPtr> &) {
|
||||
auto node = node_wpt_.lock();
|
||||
if (!node) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', node_wpt_(kernel_node) is expired. Error no: " << node;
|
||||
|
|
|
@ -36,12 +36,11 @@ class DynamicAssignCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
|
||||
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &);
|
||||
|
||||
size_t batch_size_{1};
|
||||
TypeId input_x_dtype_{kTypeUnknown};
|
||||
|
|
|
@ -34,7 +34,6 @@ class CholeskyCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
return kernel_func_(this, inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -37,7 +37,6 @@ class CholeskySolveCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -39,7 +39,6 @@ class EigCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
return kernel_func_(this, inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -39,7 +39,6 @@ class EighCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
}
|
||||
void InitInputOutputSize(const CNodePtr &kernel_node) override { init_io_func_(this, kernel_node); }
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -43,7 +43,6 @@ class ExpandCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
template <size_t RANK, typename T>
|
||||
bool ExpandCalculate(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -50,7 +50,7 @@ void LUCpuKernelMod::InitMatrixInfo(const std::vector<size_t> &shape, size_t *ro
|
|||
}
|
||||
}
|
||||
|
||||
void LUCpuKernelMod::InitPivotVecInfo(const std::vector<size_t> &shape, size_t *row, size_t *col) {
|
||||
void LUCpuKernelMod::InitPivotVecInfo(const std::vector<size_t> &shape, size_t *row, size_t *col) const {
|
||||
constexpr size_t pivot_min_dim = 1;
|
||||
if (shape.size() < pivot_min_dim) {
|
||||
MS_LOG_EXCEPTION << kernel_name_ << "pivots shape is " << shape.size() << " which is invalid.";
|
||||
|
@ -99,14 +99,14 @@ void LUCpuKernelMod::InitIOSize(const CNodePtr &kernel_node) {
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
T LUCpuKernelMod::GetPermutatedValue(const T *lu_value, const std::vector<int> &per_value, size_t i, size_t j) {
|
||||
T LUCpuKernelMod::GetPermutatedValue(const T *lu_value, const std::vector<int> &per_value, size_t i, size_t j) const {
|
||||
const T *pered_lu_value = lu_value + per_value[i] * SizeToInt(lu_col_) + SizeToInt(j);
|
||||
return *pered_lu_value;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool LUCpuKernelMod::UpdateMajorPermutation(T *lu_value, std::vector<int> *per_value, int *pivots, size_t k,
|
||||
size_t rows) {
|
||||
size_t rows) const {
|
||||
T max_major_value = static_cast<T>(kZeroThreshold);
|
||||
size_t max_major_index = 0;
|
||||
for (size_t i = k; i < rows; ++i) {
|
||||
|
@ -126,7 +126,7 @@ bool LUCpuKernelMod::UpdateMajorPermutation(T *lu_value, std::vector<int> *per_v
|
|||
|
||||
template <typename T>
|
||||
void LUCpuKernelMod::SetPermutatedValue(T *lu_value, const std::vector<int> &per_value, size_t i, size_t j,
|
||||
const T &value) {
|
||||
const T &value) const {
|
||||
T *per_lu_value = lu_value + per_value[i] * SizeToInt(lu_col_) + SizeToInt(j);
|
||||
*per_lu_value = value;
|
||||
}
|
||||
|
|
|
@ -34,16 +34,15 @@ class LUCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
return kernel_func_(this, inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
T GetPermutatedValue(const T *lu_value, const std::vector<int> &per_value, size_t i, size_t j);
|
||||
T GetPermutatedValue(const T *lu_value, const std::vector<int> &per_value, size_t i, size_t j) const;
|
||||
template <typename T>
|
||||
bool UpdateMajorPermutation(T *lu_value, std::vector<int> *per_value, int *pivots, size_t k, size_t rows);
|
||||
bool UpdateMajorPermutation(T *lu_value, std::vector<int> *per_value, int *pivots, size_t k, size_t rows) const;
|
||||
template <typename T>
|
||||
void SetPermutatedValue(T *lu_value, const std::vector<int> &per_value, size_t i, size_t j, const T &value);
|
||||
void SetPermutatedValue(T *lu_value, const std::vector<int> &per_value, size_t i, size_t j, const T &value) const;
|
||||
template <typename T>
|
||||
void InitIOSize(const CNodePtr &kernel_node);
|
||||
template <typename T>
|
||||
|
@ -59,7 +58,7 @@ class LUCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
void InitInputOutputSize(const CNodePtr &kernel_node) override { init_io_func_(this, kernel_node); }
|
||||
|
||||
void InitMatrixInfo(const std::vector<size_t> &shape, size_t *row, size_t *col);
|
||||
void InitPivotVecInfo(const std::vector<size_t> &shape, size_t *row, size_t *col);
|
||||
void InitPivotVecInfo(const std::vector<size_t> &shape, size_t *row, size_t *col) const;
|
||||
size_t batch_size_{1};
|
||||
size_t a_row_{1};
|
||||
size_t a_col_{1};
|
||||
|
|
|
@ -35,7 +35,6 @@ class LUSolverCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
return kernel_func_(this, inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -35,7 +35,6 @@ class MatrixTriangularSolveCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
return kernel_func_(this, inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -34,7 +34,6 @@ class QRCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
return kernel_func_(this, inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -40,7 +40,6 @@ class RandomPoissonCpuKernelMod : public NativeCpuKernelMod, public MatchKernelH
|
|||
}
|
||||
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
|
||||
|
||||
private:
|
||||
|
|
|
@ -62,7 +62,6 @@ class EltWiseGradCpuKernelMod : public NativeCpuKernelMod {
|
|||
return func_obj_->RunFunc(inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -36,7 +36,6 @@ class EluGradCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
template <typename T>
|
||||
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -34,7 +34,6 @@ class EmbeddingLookUpCommGradCpuKernelMod : public DeprecatedNativeCpuKernelMod
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override {
|
||||
static std::vector<KernelAttr> support_list = {
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)};
|
||||
|
|
|
@ -35,7 +35,6 @@ class BACKEND_EXPORT EmbeddingLookUpCpuKernelMod : public DeprecatedNativeCpuKer
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override {
|
||||
static std::vector<KernelAttr> support_list = {
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
|
||||
|
@ -44,6 +43,7 @@ class BACKEND_EXPORT EmbeddingLookUpCpuKernelMod : public DeprecatedNativeCpuKer
|
|||
return support_list;
|
||||
}
|
||||
|
||||
protected:
|
||||
template <typename T>
|
||||
void LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
|
||||
|
||||
|
|
|
@ -32,7 +32,6 @@ class EnvironCreateCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
const std::vector<AddressPtr> &outputs) override;
|
||||
void InitKernel(const CNodePtr &node) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override {
|
||||
static std::vector<KernelAttr> support_list = {KernelAttr().AddOutputAttr(kNumberTypeInt64)};
|
||||
return support_list;
|
||||
|
|
|
@ -32,7 +32,6 @@ class EnvironDestroyAllCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
const std::vector<AddressPtr> &outputs) override;
|
||||
void InitKernel(const CNodePtr &node);
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override {
|
||||
static std::vector<KernelAttr> support_list = {KernelAttr().AddOutputAttr(kNumberTypeBool)};
|
||||
return support_list;
|
||||
|
|
|
@ -32,7 +32,6 @@ class EnvironGetCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
const std::vector<AddressPtr> &outputs) override;
|
||||
void InitKernel(const CNodePtr &node) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override {
|
||||
static std::vector<KernelAttr> support_list = {KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
|
|
|
@ -33,7 +33,6 @@ class EnvironSetCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
const std::vector<AddressPtr> &outputs) override;
|
||||
void InitKernel(const CNodePtr &node) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override {
|
||||
static std::vector<KernelAttr> support_list = {KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
|
|
|
@ -61,7 +61,7 @@ bool Expm1CpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, co
|
|||
|
||||
template <typename T>
|
||||
void Expm1CpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
const std::vector<kernel::AddressPtr> &outputs) const {
|
||||
const auto *input = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto *output = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
size_t elem_num = inputs[0]->size / sizeof(T);
|
||||
|
|
|
@ -37,7 +37,6 @@ class Expm1CpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override {
|
||||
static std::vector<KernelAttr> support_list = {
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
|
@ -50,7 +49,7 @@ class Expm1CpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
|
||||
private:
|
||||
template <typename T>
|
||||
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
|
||||
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs) const;
|
||||
|
||||
TypeId input_dtype_{kTypeUnknown};
|
||||
};
|
||||
|
|
|
@ -38,7 +38,6 @@ class EyeCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
return kernel_func_(this, inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -34,7 +34,6 @@ class FillDiagonalCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -41,7 +41,6 @@ class FillV2CpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
template <typename T>
|
||||
void CalculateDims(const AddressPtr &input, std::vector<int64_t> *dims);
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -21,7 +21,7 @@ namespace kernel {
|
|||
constexpr int iv_vec_len = 16;
|
||||
constexpr int salt_len = 32;
|
||||
|
||||
bool ExchangeKeysKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
bool ExchangeKeysKernelMod::Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &) {
|
||||
MS_LOG(INFO) << "Launching client ExchangeKeysKernelMod";
|
||||
if (!BuildExchangeKeysReq(fbb_)) {
|
||||
|
@ -81,7 +81,7 @@ void ExchangeKeysKernelMod::Init(const CNodePtr &kernel_node) {
|
|||
MS_LOG(INFO) << "Initialize ExchangeKeys kernel successfully.";
|
||||
}
|
||||
|
||||
void ExchangeKeysKernelMod::InitKernel(const CNodePtr &kernel_node) { return; }
|
||||
void ExchangeKeysKernelMod::InitKernel(const CNodePtr &) { return; }
|
||||
|
||||
bool ExchangeKeysKernelMod::BuildExchangeKeysReq(const std::shared_ptr<fl::FBBuilder> &fbb) {
|
||||
MS_EXCEPTION_IF_NULL(fbb);
|
||||
|
@ -128,7 +128,7 @@ bool ExchangeKeysKernelMod::BuildExchangeKeysReq(const std::shared_ptr<fl::FBBui
|
|||
return true;
|
||||
}
|
||||
|
||||
std::vector<uint8_t> ExchangeKeysKernelMod::GetPubicKeyBytes() {
|
||||
std::vector<uint8_t> ExchangeKeysKernelMod::GetPubicKeyBytes() const {
|
||||
// generate private key of secret
|
||||
armour::PrivateKey *sPriKeyPtr = armour::KeyAgreement::GeneratePrivKey();
|
||||
fl::worker::FLWorker::GetInstance().set_secret_pk(sPriKeyPtr);
|
||||
|
@ -164,7 +164,7 @@ std::vector<uint8_t> ExchangeKeysKernelMod::GetPubicKeyBytes() {
|
|||
}
|
||||
|
||||
std::vector<KernelAttr> ExchangeKeysKernelMod::GetOpSupport() {
|
||||
static std::vector<KernelAttr> support_list = {KernelAttr().AddOutputAttr(kNumberTypeFloat32)};
|
||||
const std::vector<KernelAttr> support_list = {KernelAttr().AddOutputAttr(kNumberTypeFloat32)};
|
||||
return support_list;
|
||||
}
|
||||
|
||||
|
|
|
@ -32,26 +32,25 @@ class ExchangeKeysKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
ExchangeKeysKernelMod() = default;
|
||||
~ExchangeKeysKernelMod() override = default;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
bool Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &) override;
|
||||
|
||||
void Init(const CNodePtr &kernel_node) override;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
void InitKernel(const CNodePtr &) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
bool BuildExchangeKeysReq(const std::shared_ptr<fl::FBBuilder> &fbb);
|
||||
std::vector<uint8_t> GetPubicKeyBytes();
|
||||
std::vector<uint8_t> GetPubicKeyBytes() const;
|
||||
|
||||
uint32_t rank_id_;
|
||||
uint32_t server_num_;
|
||||
uint32_t target_server_rank_;
|
||||
uint32_t rank_id_{0};
|
||||
uint32_t server_num_{0};
|
||||
uint32_t target_server_rank_{0};
|
||||
std::string fl_id_;
|
||||
std::shared_ptr<fl::FBBuilder> fbb_;
|
||||
armour::PrivateKey *secret_prikey_;
|
||||
armour::PrivateKey *secret_prikey_{nullptr};
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -144,11 +144,12 @@ class FusedPullWeightKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
init_func_(this, kernel_node);
|
||||
}
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override { return; }
|
||||
void InitKernel(const CNodePtr &) override { return; }
|
||||
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
protected:
|
||||
void InitSizeLists() { return; }
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
void InitSizeLists() const { return; }
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
|
|
|
@ -121,11 +121,12 @@ class FusedPushWeightKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
init_func_(this, kernel_node);
|
||||
}
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override { return; }
|
||||
void InitKernel(const CNodePtr &) override { return; }
|
||||
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
protected:
|
||||
void InitSizeLists() { return; }
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
bool BuildPushWeightReq(std::shared_ptr<fl::FBBuilder> fbb, const std::vector<AddressPtr> &weights) {
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
bool GetKeysKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
bool GetKeysKernelMod::Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &) {
|
||||
MS_LOG(INFO) << "Launching client GetKeysKernelMod";
|
||||
BuildGetKeysReq(fbb_);
|
||||
|
@ -80,7 +80,7 @@ void GetKeysKernelMod::Init(const CNodePtr &kernel_node) {
|
|||
MS_LOG(INFO) << "Initialize GetKeys kernel successfully.";
|
||||
}
|
||||
|
||||
void GetKeysKernelMod::InitKernel(const CNodePtr &kernel_node) { return; }
|
||||
void GetKeysKernelMod::InitKernel(const CNodePtr &) { return; }
|
||||
|
||||
void GetKeysKernelMod::BuildGetKeysReq(const std::shared_ptr<fl::FBBuilder> &fbb) {
|
||||
MS_EXCEPTION_IF_NULL(fbb);
|
||||
|
@ -95,7 +95,7 @@ void GetKeysKernelMod::BuildGetKeysReq(const std::shared_ptr<fl::FBBuilder> &fbb
|
|||
}
|
||||
|
||||
bool GetKeysKernelMod::SavePublicKeyList(
|
||||
const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::ClientPublicKeys>> *remote_public_key) {
|
||||
const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::ClientPublicKeys>> *remote_public_key) const {
|
||||
if (remote_public_key == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Input remote_pubic_key is nullptr.";
|
||||
}
|
||||
|
|
|
@ -32,27 +32,26 @@ class GetKeysKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
GetKeysKernelMod() = default;
|
||||
~GetKeysKernelMod() override = default;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
bool Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &) override;
|
||||
|
||||
void Init(const CNodePtr &kernel_node) override;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
void InitKernel(const CNodePtr &) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override {
|
||||
static std::vector<KernelAttr> support_list = {KernelAttr().AddOutputAttr(kNumberTypeFloat32)};
|
||||
const std::vector<KernelAttr> support_list = {KernelAttr().AddOutputAttr(kNumberTypeFloat32)};
|
||||
return support_list;
|
||||
}
|
||||
|
||||
private:
|
||||
void BuildGetKeysReq(const std::shared_ptr<fl::FBBuilder> &fbb);
|
||||
bool SavePublicKeyList(
|
||||
const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::ClientPublicKeys>> *remote_public_key);
|
||||
const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::ClientPublicKeys>> *remote_public_key) const;
|
||||
|
||||
uint32_t rank_id_;
|
||||
uint32_t server_num_;
|
||||
uint32_t target_server_rank_;
|
||||
uint32_t rank_id_{0};
|
||||
uint32_t server_num_{0};
|
||||
uint32_t target_server_rank_{0};
|
||||
std::string fl_id_;
|
||||
std::shared_ptr<fl::FBBuilder> fbb_;
|
||||
};
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue