fix cpu lu kernel codex && pclint-plus

This commit is contained in:
z00512249 2022-02-24 20:36:01 +08:00
parent 5de0d89eba
commit 0c28100825
2 changed files with 11 additions and 11 deletions

View File

@ -98,10 +98,10 @@ T LUCpuKernelMod<T>::GetPermutatedValue(const T *lu_value, const std::vector<int
} }
template <typename T> template <typename T>
bool LUCpuKernelMod<T>::UpdateMajorPermutation(T *lu_value, std::vector<int> *const per_value, int *pivots, size_t k, bool LUCpuKernelMod<T>::UpdateMajorPermutation(T *lu_value, std::vector<int> *per_value, int *pivots, size_t k,
size_t rows) { size_t rows) {
T max_major_value = static_cast<T>(kZeroThreshold); T max_major_value = static_cast<T>(kZeroThreshold);
int max_major_index = 0; size_t max_major_index = 0;
for (size_t i = k; i < rows; ++i) { for (size_t i = k; i < rows; ++i) {
T value = GetPermutatedValue(lu_value, *per_value, i, k); T value = GetPermutatedValue(lu_value, *per_value, i, k);
T abs_value = std::abs(value); T abs_value = std::abs(value);
@ -113,7 +113,7 @@ bool LUCpuKernelMod<T>::UpdateMajorPermutation(T *lu_value, std::vector<int> *co
int per_k = per_value->at(k); int per_k = per_value->at(k);
(*per_value)[k] = per_value->at(max_major_index); (*per_value)[k] = per_value->at(max_major_index);
(*per_value)[max_major_index] = per_k; (*per_value)[max_major_index] = per_k;
pivots[k] = max_major_index; pivots[k] = SizeToInt(max_major_index);
return max_major_value != static_cast<T>(kZeroThreshold); return max_major_value != static_cast<T>(kZeroThreshold);
} }
@ -145,14 +145,14 @@ bool LUCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
// init pivots // init pivots
std::vector<int> per_value(pivots_col_, 0); std::vector<int> per_value(pivots_col_, 0);
for (size_t i = 0; i < pivots_col_; ++i) { for (size_t i = 0; i < pivots_col_; ++i) {
per_value[i] = i; per_value[i] = SizeToInt(i);
} }
// 1. memcpy input to output, do full lu inplace. // 1. memcpy input to output, do full lu inplace.
(void)memcpy_s(lu_value, lu_row_ * lu_col_ * sizeof(T), a_value, a_row_ * a_col_ * sizeof(T)); (void)memcpy_s(lu_value, lu_row_ * lu_col_ * sizeof(T), a_value, a_row_ * a_col_ * sizeof(T));
int s = std::min(a_row_, a_col_); size_t s = std::min(a_row_, a_col_);
// 2. do lu decompose inplace // 2. do lu decompose inplace
for (int k = 0; k < s; ++k) { for (size_t k = 0; k < s; ++k) {
// 2.1 choose major element of current col if return false means current col elements are all zero, just continue. // 2.1 choose major element of current col if return false means current col elements are all zero, just continue.
if (!UpdateMajorPermutation(lu_value, &per_value, pivots, k, lu_row_)) { if (!UpdateMajorPermutation(lu_value, &per_value, pivots, k, lu_row_)) {
continue; continue;
@ -179,8 +179,8 @@ bool LUCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
// 3. calculate final lu by permutation list // 3. calculate final lu by permutation list
std::unordered_map<int, std::pair<int, bool>> pivots_map; std::unordered_map<int, std::pair<int, bool>> pivots_map;
for (int i = 0; i < static_cast<int>(lu_row_); ++i) { for (size_t i = 0; i < lu_row_; ++i) {
pivots_map[per_value[i]] = {i, false}; pivots_map[per_value[i]] = {SizeToInt(i), false};
} }
int pivots_count = 0; int pivots_count = 0;
for (const auto &pivot : pivots_map) { for (const auto &pivot : pivots_map) {
@ -192,8 +192,8 @@ bool LUCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
continue; continue;
} }
T *lu_ori_row = lu_value + index * lu_col_; T *lu_ori_row = lu_value + index * SizeToInt(lu_col_);
T *lu_trans_row = lu_value + key * lu_col_; T *lu_trans_row = lu_value + key * SizeToInt(lu_col_);
// copy ori data to trans lu // copy ori data to trans lu
(void)memcpy_s(lu_trans_wk, lu_col_ * sizeof(T), lu_ori_row, lu_col_ * sizeof(T)); (void)memcpy_s(lu_trans_wk, lu_col_ * sizeof(T), lu_ori_row, lu_col_ * sizeof(T));
// copy new data to ori data ptr // copy new data to ori data ptr

View File

@ -37,7 +37,7 @@ class LUCpuKernelMod : public NativeCpuKernelMod {
void InitPivotVecInfo(const std::vector<size_t> &shape, size_t *row, size_t *col); void InitPivotVecInfo(const std::vector<size_t> &shape, size_t *row, size_t *col);
void InitInputOutputSize(const CNodePtr &kernel_node) override; void InitInputOutputSize(const CNodePtr &kernel_node) override;
T GetPermutatedValue(const T *lu_value, const std::vector<int> &per_value, size_t i, size_t j); T GetPermutatedValue(const T *lu_value, const std::vector<int> &per_value, size_t i, size_t j);
bool UpdateMajorPermutation(T *lu_value, std::vector<int> *const per_value, int *pivots, size_t k, size_t rows); bool UpdateMajorPermutation(T *lu_value, std::vector<int> *per_value, int *pivots, size_t k, size_t rows);
void SetPermutatedValue(T *lu_value, const std::vector<int> &per_value, size_t i, size_t j, const T &value); void SetPermutatedValue(T *lu_value, const std::vector<int> &per_value, size_t i, size_t j, const T &value);
size_t batch_size_{1}; size_t batch_size_{1};
size_t a_row_{1}; size_t a_row_{1};