forked from mindspore-Ecosystem/mindspore
commit
dad41b6809
|
@ -162,7 +162,7 @@ bool MatrixDiagPartV3CpuKernelMod::DoLaunch(const std::vector<kernel::AddressPtr
|
|||
|
||||
if (data_num_ >= kParallelArrayNumSameShape) {
|
||||
auto task = [this, &output_data, &input_data, padding_value](size_t start, size_t end) {
|
||||
int64_t out_begin_index = SizeToLong(start * output_elements_in_batch_);
|
||||
int64_t out_begin_index = SizeToLong(start) * output_elements_in_batch_;
|
||||
for (size_t index_array = start; index_array < end; index_array++) {
|
||||
for (int64_t i = 0; i < num_diags_; i++) {
|
||||
int64_t offset = 0;
|
||||
|
@ -174,8 +174,9 @@ bool MatrixDiagPartV3CpuKernelMod::DoLaunch(const std::vector<kernel::AddressPtr
|
|||
ComputeTwo(diag_index, max_diag_len_, num_rows_, num_cols_, align_superdiag_, align_subdiag_);
|
||||
|
||||
for (int64_t n = 0; n < diag_len; n++) {
|
||||
output_data[LongToSize(out_begin_index + offset + n)] = input_data[LongToSize(
|
||||
SizeToLong(index_array) * num_rows_ * num_cols_ + (n + col_offset) * num_cols_ + n + row_offset)];
|
||||
output_data[LongToSize(out_begin_index + offset + n)] =
|
||||
input_data[SizeToLong(index_array) * num_rows_ * num_cols_ + (n + col_offset) * num_cols_ + n +
|
||||
row_offset];
|
||||
}
|
||||
const bool left_align = (offset == 0);
|
||||
const int64_t padding_start = (left_align) ? diag_len : 0;
|
||||
|
|
|
@ -19,7 +19,6 @@
|
|||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||
|
||||
|
@ -53,7 +52,9 @@ void MatrixDiagV3CpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
|||
MS_LOG(EXCEPTION) << "Attr 'align' of 'MatrixDiagV3' is not in: 'LEFT_RIGHT', "
|
||||
"'RIGHT_LEFT', 'LEFT_LEFT', 'RIGHT_RIGHT'.";
|
||||
}
|
||||
if (align_ == "") align_ = "RIGHT_LEFT";
|
||||
if (align_ == "") {
|
||||
align_ = "RIGHT_LEFT";
|
||||
}
|
||||
} else {
|
||||
align_ = "RIGHT_LEFT";
|
||||
}
|
||||
|
@ -101,7 +102,7 @@ bool MatrixDiagV3CpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr
|
|||
}
|
||||
max_diag_len_ = diagonal_shape_[diag_rank - 1];
|
||||
// k
|
||||
auto *k_data = reinterpret_cast<int32_t *>(inputs[1]->addr);
|
||||
auto *k_data = static_cast<int32_t *>(inputs[1]->addr);
|
||||
MS_EXCEPTION_IF_NULL(k_data);
|
||||
lower_diag_index_ = k_data[0];
|
||||
upper_diag_index_ = lower_diag_index_;
|
||||
|
@ -117,14 +118,14 @@ bool MatrixDiagV3CpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr
|
|||
MS_LOG(EXCEPTION) << "For MatrixDiagV3, lower_diag_index must be smaller than upper_diag_index,received "
|
||||
<< lower_diag_index_ << " is larger than " << upper_diag_index_;
|
||||
}
|
||||
const int64_t num_diags = upper_diag_index_ - lower_diag_index_ + 1;
|
||||
const int64_t num_diags = IntToLong(upper_diag_index_) - IntToLong(lower_diag_index_) + 1;
|
||||
// num_rows
|
||||
size_t num_rows_num = static_cast<size_t>(inputs[kIndexNumRow]->size / sizeof(int32_t));
|
||||
if (!(num_rows_num == 1)) {
|
||||
MS_LOG(EXCEPTION) << "For MatrixDiagV3, num_rows must have only one element, received " << num_rows_num
|
||||
<< " elements. ";
|
||||
}
|
||||
auto *num_rows_data = reinterpret_cast<int32_t *>(inputs[kIndexNumRow]->addr);
|
||||
auto *num_rows_data = static_cast<int32_t *>(inputs[kIndexNumRow]->addr);
|
||||
MS_EXCEPTION_IF_NULL(num_rows_data);
|
||||
num_rows_ = num_rows_data[0];
|
||||
// num_cols
|
||||
|
@ -133,7 +134,7 @@ bool MatrixDiagV3CpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr
|
|||
MS_LOG(EXCEPTION) << "For MatrixDiagV3, num_cols must have only one element, received " << num_cols_num
|
||||
<< " elements. ";
|
||||
}
|
||||
auto *num_cols_data = reinterpret_cast<int32_t *>(inputs[kIndexNumCol]->addr);
|
||||
auto *num_cols_data = static_cast<int32_t *>(inputs[kIndexNumCol]->addr);
|
||||
MS_EXCEPTION_IF_NULL(num_cols_data);
|
||||
num_cols_ = num_cols_data[0];
|
||||
|
||||
|
@ -178,13 +179,13 @@ bool MatrixDiagV3CpuKernelMod::DoLaunch(const std::vector<kernel::AddressPtr> &i
|
|||
MS_LOG(EXCEPTION) << "For MatrixDiagV3, padding_value must have only one element, received " << padding_value_num
|
||||
<< " elements. ";
|
||||
}
|
||||
auto *padding_value_data = reinterpret_cast<T *>(inputs[kIndexPaddingValue]->addr);
|
||||
auto *padding_value_data = static_cast<T *>(inputs[kIndexPaddingValue]->addr);
|
||||
MS_EXCEPTION_IF_NULL(padding_value_data);
|
||||
T padding_value = padding_value_data[0];
|
||||
|
||||
auto *diagonal_data = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto *diagonal_data = static_cast<T *>(inputs[0]->addr);
|
||||
MS_EXCEPTION_IF_NULL(diagonal_data);
|
||||
auto *output_data = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
auto *output_data = static_cast<T *>(outputs[0]->addr);
|
||||
MS_EXCEPTION_IF_NULL(output_data);
|
||||
int64_t elem = 0;
|
||||
for (int64_t index_array = 0; index_array < num_batches_; index_array++) {
|
||||
|
|
Loading…
Reference in New Issue