!30545 fix cpu lu kernel codex && pclint-plus

Merge pull request !30545 from zhuzhongrui/pub_master3
This commit is contained in:
i-robot 2022-02-25 06:09:55 +00:00 committed by Gitee
commit 55899ec0c5
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 11 additions and 11 deletions

View File

@ -98,10 +98,10 @@ T LUCpuKernelMod<T>::GetPermutatedValue(const T *lu_value, const std::vector<int
}
template <typename T>
bool LUCpuKernelMod<T>::UpdateMajorPermutation(T *lu_value, std::vector<int> *const per_value, int *pivots, size_t k,
bool LUCpuKernelMod<T>::UpdateMajorPermutation(T *lu_value, std::vector<int> *per_value, int *pivots, size_t k,
size_t rows) {
T max_major_value = static_cast<T>(kZeroThreshold);
int max_major_index = 0;
size_t max_major_index = 0;
for (size_t i = k; i < rows; ++i) {
T value = GetPermutatedValue(lu_value, *per_value, i, k);
T abs_value = std::abs(value);
@ -113,7 +113,7 @@ bool LUCpuKernelMod<T>::UpdateMajorPermutation(T *lu_value, std::vector<int> *co
int per_k = per_value->at(k);
(*per_value)[k] = per_value->at(max_major_index);
(*per_value)[max_major_index] = per_k;
pivots[k] = max_major_index;
pivots[k] = SizeToInt(max_major_index);
return max_major_value != static_cast<T>(kZeroThreshold);
}
@ -145,14 +145,14 @@ bool LUCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
// init pivots
std::vector<int> per_value(pivots_col_, 0);
for (size_t i = 0; i < pivots_col_; ++i) {
per_value[i] = i;
per_value[i] = SizeToInt(i);
}
// 1. memcpy input to output, do full lu inplace.
(void)memcpy_s(lu_value, lu_row_ * lu_col_ * sizeof(T), a_value, a_row_ * a_col_ * sizeof(T));
int s = std::min(a_row_, a_col_);
size_t s = std::min(a_row_, a_col_);
// 2. do lu decompose inplace
for (int k = 0; k < s; ++k) {
for (size_t k = 0; k < s; ++k) {
// 2.1 choose major element of current col if return false means current col elements are all zero, just continue.
if (!UpdateMajorPermutation(lu_value, &per_value, pivots, k, lu_row_)) {
continue;
@ -179,8 +179,8 @@ bool LUCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
// 3. calculate final lu by permutation list
std::unordered_map<int, std::pair<int, bool>> pivots_map;
for (int i = 0; i < static_cast<int>(lu_row_); ++i) {
pivots_map[per_value[i]] = {i, false};
for (size_t i = 0; i < lu_row_; ++i) {
pivots_map[per_value[i]] = {SizeToInt(i), false};
}
int pivots_count = 0;
for (const auto &pivot : pivots_map) {
@ -192,8 +192,8 @@ bool LUCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
continue;
}
T *lu_ori_row = lu_value + index * lu_col_;
T *lu_trans_row = lu_value + key * lu_col_;
T *lu_ori_row = lu_value + index * SizeToInt(lu_col_);
T *lu_trans_row = lu_value + key * SizeToInt(lu_col_);
// copy ori data to trans lu
(void)memcpy_s(lu_trans_wk, lu_col_ * sizeof(T), lu_ori_row, lu_col_ * sizeof(T));
// copy new data to ori data ptr

View File

@ -37,7 +37,7 @@ class LUCpuKernelMod : public NativeCpuKernelMod {
void InitPivotVecInfo(const std::vector<size_t> &shape, size_t *row, size_t *col);
void InitInputOutputSize(const CNodePtr &kernel_node) override;
T GetPermutatedValue(const T *lu_value, const std::vector<int> &per_value, size_t i, size_t j);
bool UpdateMajorPermutation(T *lu_value, std::vector<int> *const per_value, int *pivots, size_t k, size_t rows);
bool UpdateMajorPermutation(T *lu_value, std::vector<int> *per_value, int *pivots, size_t k, size_t rows);
void SetPermutatedValue(T *lu_value, const std::vector<int> &per_value, size_t i, size_t j, const T &value);
size_t batch_size_{1};
size_t a_row_{1};