Merge pull request !39135 from Bert0108/clean_code_kldivlossgrad
This commit is contained in:
i-robot 2022-07-29 09:59:19 +00:00 committed by Gitee
commit 7be107efec
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 14 additions and 17 deletions

View File

@ -34,7 +34,6 @@ bool KLDivLossGradCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const
auto kernel_ptr = std::dynamic_pointer_cast<ops::KLDivLossGrad>(base_operator); auto kernel_ptr = std::dynamic_pointer_cast<ops::KLDivLossGrad>(base_operator);
if (!kernel_ptr) { if (!kernel_ptr) {
MS_LOG(EXCEPTION) << "cast KLDivLoss ops failed!"; MS_LOG(EXCEPTION) << "cast KLDivLoss ops failed!";
return false;
} }
kernel_name_ = kernel_ptr->name(); kernel_name_ = kernel_ptr->name();
reductionMode_ = kernel_ptr->get_reduction(); reductionMode_ = kernel_ptr->get_reduction();
@ -76,21 +75,21 @@ int KLDivLossGradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, cons
std::vector<int64_t> input_grad_shape = inputs[kIndex0]->GetShapeVector(); std::vector<int64_t> input_grad_shape = inputs[kIndex0]->GetShapeVector();
if (input_grad_shape.size() >= 1) { if (input_grad_shape.size() >= 1) {
for (size_t i = 0; i < input_grad_shape.size(); ++i) { for (size_t i = 0; i < input_grad_shape.size(); ++i) {
input_grad_shape_size_ *= input_grad_shape[i]; input_grad_shape_size_ *= LongToSize(input_grad_shape[i]);
} }
} }
std::vector<int64_t> input_x_shape = inputs[kIndex1]->GetShapeVector(); std::vector<int64_t> input_x_shape = inputs[kIndex1]->GetShapeVector();
if (input_x_shape.size() >= 1) { if (input_x_shape.size() >= 1) {
for (size_t i = 0; i < input_x_shape.size(); ++i) { for (size_t i = 0; i < input_x_shape.size(); ++i) {
input_x_shape_size_ *= input_x_shape[i]; input_x_shape_size_ *= LongToSize(input_x_shape[i]);
} }
} }
std::vector<int64_t> input_target_shape = inputs[kIndex2]->GetShapeVector(); std::vector<int64_t> input_target_shape = inputs[kIndex2]->GetShapeVector();
if (input_target_shape.size() >= 1) { if (input_target_shape.size() >= 1) {
for (size_t i = 0; i < input_target_shape.size(); ++i) { for (size_t i = 0; i < input_target_shape.size(); ++i) {
input_target_shape_size_ *= input_target_shape[i]; input_target_shape_size_ *= LongToSize(input_target_shape[i]);
} }
} }
@ -105,7 +104,7 @@ std::vector<KernelAttr> KLDivLossGradCpuKernelMod::GetOpSupport() {
return support_list; return support_list;
} }
bool KLDivLossGradCpuKernelMod::CheckParams() { bool KLDivLossGradCpuKernelMod::CheckParams() const {
// for kl div, shape size of input 1 and input 2 must be the same // for kl div, shape size of input 1 and input 2 must be the same
if (input_target_shape_size_ != input_x_shape_size_) { if (input_target_shape_size_ != input_x_shape_size_) {
MS_LOG(ERROR) << kernel_name_ << ": input x shape size = " << input_x_shape_size_ MS_LOG(ERROR) << kernel_name_ << ": input x shape size = " << input_x_shape_size_
@ -131,8 +130,7 @@ bool KLDivLossGradCpuKernelMod::CheckParams() {
} }
template <typename T> template <typename T>
bool KLDivLossGradCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, bool KLDivLossGradCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) { const std::vector<AddressPtr> &outputs) {
auto *input_grad = reinterpret_cast<T *>(inputs[kIndex0]->addr); auto *input_grad = reinterpret_cast<T *>(inputs[kIndex0]->addr);
auto *input_target = reinterpret_cast<T *>(inputs[kIndex2]->addr); auto *input_target = reinterpret_cast<T *>(inputs[kIndex2]->addr);
@ -142,7 +140,7 @@ bool KLDivLossGradCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inpu
Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>> array_target(input_target, input_target_shape_size_, 1); Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>> array_target(input_target, input_target_shape_size_, 1);
Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>> array_y(y, input_x_shape_size_, 1); Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>> array_y(y, input_x_shape_size_, 1);
float coefficient = -1.0; double coefficient = -1.0;
if (reductionMode_ == ops::kMean) { if (reductionMode_ == ops::kMean) {
coefficient /= input_x_shape_size_; coefficient /= input_x_shape_size_;
} else if (reductionMode_ == ops::kBatchMean) { } else if (reductionMode_ == ops::kBatchMean) {
@ -159,7 +157,7 @@ bool KLDivLossGradCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inpu
auto task = [&](size_t start, size_t end) { auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; ++i) { for (size_t i = start; i < end; ++i) {
if (static_cast<float>(array_target[i]) <= 0.0) { if (static_cast<double>(array_target[i]) <= 0.0) {
array_y[i] = static_cast<T>(0); array_y[i] = static_cast<T>(0);
} else { } else {
array_y[i] *= (static_cast<T>(coefficient) * static_cast<T>(bcast)); array_y[i] *= (static_cast<T>(coefficient) * static_cast<T>(bcast));

View File

@ -48,15 +48,14 @@ class KLDivLossGradCpuKernelMod : public NativeCpuKernelMod {
private: private:
template <typename T> template <typename T>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs); const std::vector<AddressPtr> &outputs);
bool CheckParams(); bool CheckParams() const;
using KLDivLossGradFunc = std::function<bool(KLDivLossGradCpuKernelMod *, const std::vector<AddressPtr> &, using KLDivLossGradFunc = std::function<bool(KLDivLossGradCpuKernelMod *, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &, const std::vector<AddressPtr> &)>; const std::vector<AddressPtr> &, const std::vector<AddressPtr> &)>;
private:
static std::vector<std::pair<KernelAttr, KLDivLossGradFunc>> func_list_; static std::vector<std::pair<KernelAttr, KLDivLossGradFunc>> func_list_;
KLDivLossGradFunc kernel_func_; KLDivLossGradFunc kernel_func_;
std::string reductionMode_ = "mean"; std::string reductionMode_ = "mean";

View File

@ -52,13 +52,13 @@ TypePtr KLDivLossGradInferType(const PrimitivePtr &prim, const std::vector<Abstr
auto input_grad_type = input_args[kInputIndex0]->BuildType(); auto input_grad_type = input_args[kInputIndex0]->BuildType();
auto input_x_type = input_args[kInputIndex1]->BuildType(); auto input_x_type = input_args[kInputIndex1]->BuildType();
auto input_target_type = input_args[kInputIndex2]->BuildType(); auto input_target_type = input_args[kInputIndex2]->BuildType();
CheckAndConvertUtils::CheckTensorTypeValid("x", input_x_type, valid_types, op_name); (void)CheckAndConvertUtils::CheckTensorTypeValid("x", input_x_type, valid_types, op_name);
std::map<std::string, TypePtr> types; std::map<std::string, TypePtr> types;
types.emplace("grad", input_grad_type); (void)types.emplace("grad", input_grad_type);
types.emplace("x", input_x_type); (void)types.emplace("x", input_x_type);
types.emplace("target", input_target_type); (void)types.emplace("target", input_target_type);
CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, op_name); (void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, op_name);
return input_x_type; return input_x_type;
} }