Merge pull request !39135 from Bert0108/clean_code_kldivlossgrad
i-robot 2022-07-29 09:59:19 +00:00 committed by Gitee
commit 7be107efec
3 changed files with 14 additions and 17 deletions

@@ -34,7 +34,6 @@ bool KLDivLossGradCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const
   auto kernel_ptr = std::dynamic_pointer_cast<ops::KLDivLossGrad>(base_operator);
   if (!kernel_ptr) {
     MS_LOG(EXCEPTION) << "cast KLDivLoss ops failed!";
-    return false;
   }
   kernel_name_ = kernel_ptr->name();
   reductionMode_ = kernel_ptr->get_reduction();
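Note on this hunk: the deleted `return false;` was unreachable, since MS_LOG(EXCEPTION) throws rather than returning. A minimal sketch of that pattern, with a hypothetical stand-in macro (LOG_EXCEPTION and Init here are illustrative, not MindSpore API):

#include <sstream>
#include <stdexcept>

// Simplified stand-in for MS_LOG(EXCEPTION): streams a message, then throws.
#define LOG_EXCEPTION(msg)               \
  do {                                   \
    std::ostringstream oss;              \
    oss << msg;                          \
    throw std::runtime_error(oss.str()); \
  } while (0)

bool Init(void *ptr) {
  if (ptr == nullptr) {
    LOG_EXCEPTION("cast KLDivLoss ops failed!");
    // return false;  // dead code: the throw above never falls through
  }
  return true;
}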
@@ -76,21 +75,21 @@ int KLDivLossGradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, cons
   std::vector<int64_t> input_grad_shape = inputs[kIndex0]->GetShapeVector();
   if (input_grad_shape.size() >= 1) {
     for (size_t i = 0; i < input_grad_shape.size(); ++i) {
-      input_grad_shape_size_ *= input_grad_shape[i];
+      input_grad_shape_size_ *= LongToSize(input_grad_shape[i]);
     }
   }
   std::vector<int64_t> input_x_shape = inputs[kIndex1]->GetShapeVector();
   if (input_x_shape.size() >= 1) {
     for (size_t i = 0; i < input_x_shape.size(); ++i) {
-      input_x_shape_size_ *= input_x_shape[i];
+      input_x_shape_size_ *= LongToSize(input_x_shape[i]);
     }
   }
   std::vector<int64_t> input_target_shape = inputs[kIndex2]->GetShapeVector();
   if (input_target_shape.size() >= 1) {
     for (size_t i = 0; i < input_target_shape.size(); ++i) {
-      input_target_shape_size_ *= input_target_shape[i];
+      input_target_shape_size_ *= LongToSize(input_target_shape[i]);
     }
   }
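Note on this hunk: wrapping each extent in LongToSize makes the int64_t-to-size_t conversion explicit instead of implicit, so a negative or corrupt dimension cannot silently wrap around to a huge unsigned value. A sketch of what such a checked helper could look like, assuming LongToSize is MindSpore's checked conversion utility (this version is only illustrative):

#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Illustrative checked conversion: rejects negative extents instead of
// letting static_cast wrap them to a huge size_t.
inline size_t LongToSize(int64_t v) {
  if (v < 0) {
    throw std::out_of_range("negative dimension cannot convert to size_t");
  }
  return static_cast<size_t>(v);
}

// Element count of a shape, accumulated in size_t as in the Resize() hunk above.
inline size_t ShapeSize(const std::vector<int64_t> &shape) {
  size_t size = 1;
  for (size_t i = 0; i < shape.size(); ++i) {
    size *= LongToSize(shape[i]);
  }
  return size;
}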
@@ -105,7 +104,7 @@ std::vector<KernelAttr> KLDivLossGradCpuKernelMod::GetOpSupport() {
   return support_list;
 }
 
-bool KLDivLossGradCpuKernelMod::CheckParams() {
+bool KLDivLossGradCpuKernelMod::CheckParams() const {
   // for kl div, shape size of input 1 and input 2 must be the same
   if (input_target_shape_size_ != input_x_shape_size_) {
     MS_LOG(ERROR) << kernel_name_ << ": input x shape size = " << input_x_shape_size_
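Note on this hunk: CheckParams() only reads the cached shape sizes, so marking it const both documents that and lets the compiler enforce it. A minimal sketch of the idea (Checker and its members are hypothetical; the names follow the hunk above):

#include <cstddef>

class Checker {
 public:
  // const: the compiler now rejects any accidental write to members here,
  // and the method can be called through a const object or reference.
  bool CheckParams() const { return input_target_shape_size_ == input_x_shape_size_; }

 private:
  size_t input_x_shape_size_ = 1;
  size_t input_target_shape_size_ = 1;
};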
@@ -131,8 +130,7 @@ bool KLDivLossGradCpuKernelMod::CheckParams() {
 }
 
 template <typename T>
-bool KLDivLossGradCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
-                                             const std::vector<AddressPtr> &workspace,
+bool KLDivLossGradCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
                                              const std::vector<AddressPtr> &outputs) {
   auto *input_grad = reinterpret_cast<T *>(inputs[kIndex0]->addr);
   auto *input_target = reinterpret_cast<T *>(inputs[kIndex2]->addr);
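Note on this hunk: the workspace parameter stays in the signature, so the KLDivLossGradFunc dispatch type is unaffected, but dropping the parameter name silences unused-parameter warnings. An illustrative sketch (AddressPtrStub is a hypothetical stand-in for the real AddressPtr):

#include <vector>

struct AddressPtrStub { void *addr = nullptr; };

// The middle argument keeps its slot so callers and the std::function
// dispatch table are unchanged, but with no name the compiler cannot
// complain that it goes unused.
bool LaunchKernel(const std::vector<AddressPtrStub> &inputs, const std::vector<AddressPtrStub> & /* workspace */,
                  const std::vector<AddressPtrStub> &outputs) {
  return !inputs.empty() && !outputs.empty();
}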
@@ -142,7 +140,7 @@ bool KLDivLossGradCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inpu
   Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>> array_target(input_target, input_target_shape_size_, 1);
   Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>> array_y(y, input_x_shape_size_, 1);
-  float coefficient = -1.0;
+  double coefficient = -1.0;
   if (reductionMode_ == ops::kMean) {
     coefficient /= input_x_shape_size_;
   } else if (reductionMode_ == ops::kBatchMean) {
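Note on this hunk: for the mean and batchmean reductions the coefficient is divided by an element count, and computing it in double rather than float keeps more precision for large tensors before the final cast to T. A small sketch of that scaling (ReductionCoefficient is a hypothetical helper, not MindSpore API):

#include <cstddef>

// Reduction scaling computed in double, as in the hunk above: for very
// large element counts, -1.0f / N loses significant digits sooner than
// its double counterpart does.
double ReductionCoefficient(size_t element_count, bool mean_reduction) {
  double coefficient = -1.0;
  if (mean_reduction && element_count > 0) {
    coefficient /= static_cast<double>(element_count);
  }
  return coefficient;
}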
@@ -159,7 +157,7 @@ bool KLDivLossGradCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inpu
   auto task = [&](size_t start, size_t end) {
     for (size_t i = start; i < end; ++i) {
-      if (static_cast<float>(array_target[i]) <= 0.0) {
+      if (static_cast<double>(array_target[i]) <= 0.0) {
         array_y[i] = static_cast<T>(0);
       } else {
         array_y[i] *= (static_cast<T>(coefficient) * static_cast<T>(bcast));
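Note on this file: the kernel computes the backward of the usual KLDivLoss form loss = target * (log(target) - x), whose derivative with respect to x is -target, so the output is the incoming gradient scaled by -target (with the reduction factor folded in) and zeroed where target <= 0. A scalar sketch of that computation (KLDivLossGradScalar is illustrative; the Eigen maps and the bcast factor from the diff are folded into coefficient here):

#include <cstddef>

// Per-element backward of loss = target * (log(target) - x):
// d(loss)/dx = -target, so grad_x = coefficient * target * grad_out,
// where coefficient carries the -1 and any mean/batchmean scaling.
// Positions with target <= 0 are zeroed, matching the kernel above.
template <typename T>
void KLDivLossGradScalar(const T *grad_out, const T *target, T *grad_x, size_t n, double coefficient) {
  for (size_t i = 0; i < n; ++i) {
    if (static_cast<double>(target[i]) <= 0.0) {
      grad_x[i] = static_cast<T>(0);
    } else {
      grad_x[i] = static_cast<T>(coefficient) * target[i] * grad_out[i];
    }
  }
}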

@@ -48,15 +48,14 @@ class KLDivLossGradCpuKernelMod : public NativeCpuKernelMod {
  private:
   template <typename T>
-  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
                     const std::vector<AddressPtr> &outputs);
-  bool CheckParams();
+  bool CheckParams() const;
   using KLDivLossGradFunc = std::function<bool(KLDivLossGradCpuKernelMod *, const std::vector<AddressPtr> &,
                                                const std::vector<AddressPtr> &, const std::vector<AddressPtr> &)>;
- private:
   static std::vector<std::pair<KernelAttr, KLDivLossGradFunc>> func_list_;
   KLDivLossGradFunc kernel_func_;
   std::string reductionMode_ = "mean";
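Note on this header: func_list_ and kernel_func_ follow the common kernel-dispatch pattern, where a static table maps each supported attribute set to a typed LaunchKernel instantiation and initialization selects one entry into kernel_func_. A stripped-down sketch of the pattern (KernelStub and its types are hypothetical simplifications, not the MindSpore declarations; requires C++17 for the inline static member):

#include <functional>
#include <string>
#include <utility>
#include <vector>

struct KernelAttrStub { std::string dtype; };

class KernelStub {
 public:
  // Select the typed implementation once, then call it on every launch.
  bool SelectKernel(const std::string &dtype) {
    for (const auto &entry : func_list_) {
      if (entry.first.dtype == dtype) {
        kernel_func_ = entry.second;
        return true;
      }
    }
    return false;
  }
  bool Launch() { return kernel_func_(this); }  // assumes SelectKernel succeeded

 private:
  template <typename T>
  bool LaunchKernel() { return true; }  // stands in for the real typed kernel

  using Func = std::function<bool(KernelStub *)>;
  // One table entry per supported dtype, each bound to an instantiation.
  inline static std::vector<std::pair<KernelAttrStub, Func>> func_list_ = {
      {{"float32"}, [](KernelStub *k) { return k->LaunchKernel<float>(); }},
      {{"float64"}, [](KernelStub *k) { return k->LaunchKernel<double>(); }},
  };
  Func kernel_func_;
};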

@@ -52,13 +52,13 @@ TypePtr KLDivLossGradInferType(const PrimitivePtr &prim, const std::vector<Abstr
   auto input_grad_type = input_args[kInputIndex0]->BuildType();
   auto input_x_type = input_args[kInputIndex1]->BuildType();
   auto input_target_type = input_args[kInputIndex2]->BuildType();
-  CheckAndConvertUtils::CheckTensorTypeValid("x", input_x_type, valid_types, op_name);
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("x", input_x_type, valid_types, op_name);
   std::map<std::string, TypePtr> types;
-  types.emplace("grad", input_grad_type);
-  types.emplace("x", input_x_type);
-  types.emplace("target", input_target_type);
-  CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, op_name);
+  (void)types.emplace("grad", input_grad_type);
+  (void)types.emplace("x", input_x_type);
+  (void)types.emplace("target", input_target_type);
+  (void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, op_name);
   return input_x_type;
 }
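Note on this hunk: the added (void) casts mark the return values of emplace and the check utilities as deliberately discarded, which satisfies static-analysis rules for ignored return values without changing behavior. For instance:

#include <map>
#include <string>

int main() {
  std::map<std::string, int> types;
  // std::map::emplace returns std::pair<iterator, bool>; casting to void
  // tells the compiler and linters the result is intentionally ignored.
  (void)types.emplace("grad", 0);
  (void)types.emplace("x", 1);
  return 0;
}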