Merge pull request !39554 from YijieChen/master
This commit is contained in:
i-robot 2022-08-05 09:33:49 +00:00 committed by Gitee
commit dad41b6809
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 14 additions and 12 deletions

View File

@ -162,7 +162,7 @@ bool MatrixDiagPartV3CpuKernelMod::DoLaunch(const std::vector<kernel::AddressPtr
if (data_num_ >= kParallelArrayNumSameShape) {
auto task = [this, &output_data, &input_data, padding_value](size_t start, size_t end) {
int64_t out_begin_index = SizeToLong(start * output_elements_in_batch_);
int64_t out_begin_index = SizeToLong(start) * output_elements_in_batch_;
for (size_t index_array = start; index_array < end; index_array++) {
for (int64_t i = 0; i < num_diags_; i++) {
int64_t offset = 0;
@ -174,8 +174,9 @@ bool MatrixDiagPartV3CpuKernelMod::DoLaunch(const std::vector<kernel::AddressPtr
ComputeTwo(diag_index, max_diag_len_, num_rows_, num_cols_, align_superdiag_, align_subdiag_);
for (int64_t n = 0; n < diag_len; n++) {
output_data[LongToSize(out_begin_index + offset + n)] = input_data[LongToSize(
SizeToLong(index_array) * num_rows_ * num_cols_ + (n + col_offset) * num_cols_ + n + row_offset)];
output_data[LongToSize(out_begin_index + offset + n)] =
input_data[SizeToLong(index_array) * num_rows_ * num_cols_ + (n + col_offset) * num_cols_ + n +
row_offset];
}
const bool left_align = (offset == 0);
const int64_t padding_start = (left_align) ? diag_len : 0;

View File

@ -19,7 +19,6 @@
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
@ -53,7 +52,9 @@ void MatrixDiagV3CpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_LOG(EXCEPTION) << "Attr 'align' of 'MatrixDiagV3' is not in: 'LEFT_RIGHT', "
"'RIGHT_LEFT', 'LEFT_LEFT', 'RIGHT_RIGHT'.";
}
if (align_ == "") align_ = "RIGHT_LEFT";
if (align_ == "") {
align_ = "RIGHT_LEFT";
}
} else {
align_ = "RIGHT_LEFT";
}
@ -101,7 +102,7 @@ bool MatrixDiagV3CpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr
}
max_diag_len_ = diagonal_shape_[diag_rank - 1];
// k
auto *k_data = reinterpret_cast<int32_t *>(inputs[1]->addr);
auto *k_data = static_cast<int32_t *>(inputs[1]->addr);
MS_EXCEPTION_IF_NULL(k_data);
lower_diag_index_ = k_data[0];
upper_diag_index_ = lower_diag_index_;
@ -117,14 +118,14 @@ bool MatrixDiagV3CpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr
MS_LOG(EXCEPTION) << "For MatrixDiagV3, lower_diag_index must be smaller than upper_diag_index,received "
<< lower_diag_index_ << " is larger than " << upper_diag_index_;
}
const int64_t num_diags = upper_diag_index_ - lower_diag_index_ + 1;
const int64_t num_diags = IntToLong(upper_diag_index_) - IntToLong(lower_diag_index_) + 1;
// num_rows
size_t num_rows_num = static_cast<size_t>(inputs[kIndexNumRow]->size / sizeof(int32_t));
if (!(num_rows_num == 1)) {
MS_LOG(EXCEPTION) << "For MatrixDiagV3, num_rows must have only one element, received " << num_rows_num
<< " elements. ";
}
auto *num_rows_data = reinterpret_cast<int32_t *>(inputs[kIndexNumRow]->addr);
auto *num_rows_data = static_cast<int32_t *>(inputs[kIndexNumRow]->addr);
MS_EXCEPTION_IF_NULL(num_rows_data);
num_rows_ = num_rows_data[0];
// num_cols
@ -133,7 +134,7 @@ bool MatrixDiagV3CpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr
MS_LOG(EXCEPTION) << "For MatrixDiagV3, num_cols must have only one element, received " << num_cols_num
<< " elements. ";
}
auto *num_cols_data = reinterpret_cast<int32_t *>(inputs[kIndexNumCol]->addr);
auto *num_cols_data = static_cast<int32_t *>(inputs[kIndexNumCol]->addr);
MS_EXCEPTION_IF_NULL(num_cols_data);
num_cols_ = num_cols_data[0];
@ -178,13 +179,13 @@ bool MatrixDiagV3CpuKernelMod::DoLaunch(const std::vector<kernel::AddressPtr> &i
MS_LOG(EXCEPTION) << "For MatrixDiagV3, padding_value must have only one element, received " << padding_value_num
<< " elements. ";
}
auto *padding_value_data = reinterpret_cast<T *>(inputs[kIndexPaddingValue]->addr);
auto *padding_value_data = static_cast<T *>(inputs[kIndexPaddingValue]->addr);
MS_EXCEPTION_IF_NULL(padding_value_data);
T padding_value = padding_value_data[0];
auto *diagonal_data = reinterpret_cast<T *>(inputs[0]->addr);
auto *diagonal_data = static_cast<T *>(inputs[0]->addr);
MS_EXCEPTION_IF_NULL(diagonal_data);
auto *output_data = reinterpret_cast<T *>(outputs[0]->addr);
auto *output_data = static_cast<T *>(outputs[0]->addr);
MS_EXCEPTION_IF_NULL(output_data);
int64_t elem = 0;
for (int64_t index_array = 0; index_array < num_batches_; index_array++) {