!29831 opt kernel code for code warning

Merge pull request !29831 from zhuzhongrui/pub_master
This commit is contained in:
i-robot 2022-03-18 08:06:37 +00:00 committed by Gitee
commit b8a0fe51ca
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
6 changed files with 12 additions and 21 deletions

View File

@ -431,29 +431,29 @@ template <typename T>
void ArithmeticCpuTypeFunc<T>::Pow(const T *input1, const T *input2, T *out) {
if constexpr (std::is_same_v<T, float>) {
auto is_power_single = [this]() {
bool is_power_single = false;
bool is_power_single_inner = false;
if (input_shape1_.size() == input_shape2_.size()) {
is_power_single = true;
is_power_single_inner = true;
for (size_t i = 0; i < input_shape1_.size(); ++i) {
if (input_shape1_[i] != input_shape2_[i]) {
is_power_single = false;
is_power_single_inner = false;
break;
}
}
}
return is_power_single;
return is_power_single_inner;
};
if (op_para_.in_elements_num1_ == 1) {
auto task = [&](size_t start, size_t end) {
(void)Power(input1 + start, input2, out + start, end - start, 1, 0, true);
(void)Power(input1 + start, input2, out + start, SizeToInt(end - start), 1.0, 0.0, true);
};
ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_);
return;
}
if (is_power_single()) {
auto task = [&](size_t start, size_t end) {
(void)Power(input1 + start, input2 + start, out + start, end - start, 1, 0, false);
(void)Power(input1 + start, input2 + start, out + start, SizeToInt(end - start), 1.0, 0.0, false);
};
ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_);
return;

View File

@ -344,10 +344,7 @@ void ParallelLaunch(const std::vector<common::Task> &tasks, Content content) {
}
size_t task_num = tasks.size();
auto func = [&](void *, int task_id, float, float) {
tasks[task_id]();
return common::SUCCESS;
};
auto func = [&](void *, int task_id, float, float) { return tasks[task_id](); };
(void)thread_pool->ParallelLaunch(func, content, task_num);
}

View File

@ -46,7 +46,7 @@ void LUCpuKernelMod::InitMatrixInfo(const std::vector<size_t> &shape, size_t *ro
*col = shape.at(shape.size() - 1);
batch_size_ = lu_min_dim;
for (int batch = 0; batch < static_cast<int>(shape.size() - lu_reverse_row_dim); ++batch) {
batch_size_ *= shape.at(batch);
batch_size_ *= shape.at(SizeToInt(batch));
}
}
@ -100,7 +100,7 @@ void LUCpuKernelMod::InitIOSize(const CNodePtr &kernel_node) {
template <typename T>
T LUCpuKernelMod::GetPermutatedValue(const T *lu_value, const std::vector<int> &per_value, size_t i, size_t j) {
const T *pered_lu_value = lu_value + per_value[i] * lu_col_ + j;
const T *pered_lu_value = lu_value + per_value[i] * SizeToInt(lu_col_) + SizeToInt(j);
return *pered_lu_value;
}
@ -127,7 +127,7 @@ bool LUCpuKernelMod::UpdateMajorPermutation(T *lu_value, std::vector<int> *per_v
template <typename T>
void LUCpuKernelMod::SetPermutatedValue(T *lu_value, const std::vector<int> &per_value, size_t i, size_t j,
const T &value) {
T *per_lu_value = lu_value + per_value[i] * lu_col_ + j;
T *per_lu_value = lu_value + per_value[i] * SizeToInt(lu_col_) + SizeToInt(j);
*per_lu_value = value;
}
@ -233,7 +233,7 @@ bool LUCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
(void)memset_s(reinterpret_cast<void *>(permutation_value), count, 0, count);
for (size_t i = 0; i < pivots_col_; ++i) {
int position = per_value[i];
int *per_addr = permutation_value + position * permutation_row_ + i;
int *per_addr = permutation_value + position * SizeToInt(permutation_row_) + SizeToInt(i);
*per_addr = 1;
}
}

View File

@ -21,7 +21,6 @@
#include <utility>
#include "utils/ms_utils.h"
#include "plugin/device/cpu/kernel/eigen/eigen_common_utils.h"
#include "Eigen/Dense"
#include "Eigen/LU"
namespace mindspore {
namespace kernel {
@ -103,11 +102,6 @@ bool LUSolverCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &i
} else {
MS_LOG_EXCEPTION << kernel_name_ << " trans_ flag is invalid: " << trans_;
}
if (output_lu.RowsAtCompileTime == 0 || output_lu.ColsAtCompileTime == 0) {
MS_LOG_EXCEPTION << kernel_name_ << " output lu shape invalid.";
}
return true;
}

View File

@ -22,7 +22,6 @@
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kAMatrixDimNum = 2;
constexpr size_t kQRInputsNum = 1;

View File

@ -46,6 +46,7 @@ class ScatterUpdateCpuKernelMod : public NativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
virtual void *ScatterUpdateRealData(const std::vector<AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) = 0;