Add ops wrapper and test, add parallel launch

This commit is contained in:
wenbean 2022-01-22 16:16:34 +08:00
parent 5e6d6ee63e
commit 860f52e45d
3 changed files with 47 additions and 57 deletions

View File

@ -51,15 +51,20 @@ bool MatrixBandPartCpuKernelMod<T>::Launch(const std::vector<AddressPtr> &inputs
memcpy_s(out_value, matrix_size_ * sizeof(T), in_value, matrix_size_ * sizeof(T));
return true;
}
for (size_t i = 0; i < out_range_size_; i++) {
for (size_t j = 0; j < std::min(m_, l + n_); j++) {
const size_t s = j < l ? 0 : j - l;
// When i = n - u, end is n -1, because end pos is start from 0
const size_t e = j >= n_ - u ? n_ - 1 : j + u;
const size_t offset = i * m_ * n_ + j * n_;
memcpy_s(out_value + offset + s, matrix_size_ * sizeof(T), in_value + offset + s, (e - s + 1) * sizeof(T));
}
}
size_t diag_len = std::min(m_, l + n_);
CPUKernelUtils::ParallelFor(
[matrix_size = matrix_size_, m = m_, n = n_, diag_len, l, u, in_value, out_value](size_t spos, size_t epos) {
for (size_t t = spos; t < epos; t++) {
const size_t i = t / diag_len;
const size_t j = t % diag_len;
const size_t s = j < l ? 0 : j - l;
// When j >= n - u, clamp the end position to n - 1, the last valid index (positions are 0-based)
const size_t e = j >= n - u ? n - 1 : j + u;
const size_t offset = i * m * n + j * n;
(void)memcpy_s(out_value + offset + s, matrix_size * sizeof(T), in_value + offset + s, (e - s + 1) * sizeof(T));
}
},
out_range_size_ * diag_len);
return true;
}
} // namespace kernel

View File

@ -60,8 +60,7 @@ bool MatrixDiagPartCpuKernelMod<T>::Launch(const std::vector<AddressPtr> &inputs
// New diagonal matrix m * n matrix, n dimension
int64_t max_diag_len =
std::min(m_ + std::min(u, static_cast<int64_t>(0)), n_ + std::min(-l, static_cast<int64_t>(0)));
MS_LOG(DEBUG) << "Num_diags:" << num_diags << ",max_diag_len:" << max_diag_len;
int64_t dest_inner_matrix_size = num_diags * max_diag_len;
int64_t dest_inner_matrix_len = num_diags * max_diag_len;
out_shapes_ = shapes_;
// Set dynamic shape and dtype
if (!node_wpt_.expired()) {
@ -75,33 +74,35 @@ bool MatrixDiagPartCpuKernelMod<T>::Launch(const std::vector<AddressPtr> &inputs
auto dtype = AnfAlgo::GetOutputDeviceDataType(node_, 0);
AnfAlgo::SetOutputInferTypeAndShape({dtype}, {out_shapes_}, node_.get());
}
for (int64_t i = 0; i < out_range_size_; i++) {
// The j_index means current dest row index
for (int64_t j = u; j >= l; j--) {
int64_t current_diag_len = j >= 0 ? std::min(n_ - j, m_) : std::min(m_ + j, n_);
int64_t current_pad_len = max_diag_len - current_diag_len;
// Pad left by default
bool pad_left = (alignment_.first == MatrixDiag::Alignment::RIGHT && j > 0) ||
(alignment_.second == MatrixDiag::Alignment::RIGHT && j < 0);
// Set none-padding values, l means current diag col index
for (int64_t k = 0; k < max_diag_len; k++) {
CPUKernelUtils::ParallelFor(
[m = m_, n = n_, max_diag_len, u, l, in_value, out_value, dest_inner_matrix_len, padding_value,
alignment = alignment_](size_t spos, size_t epos) {
for (size_t t = spos; t < epos; t++) {
const int64_t i = t / dest_inner_matrix_len;
const int64_t j = u - (t % dest_inner_matrix_len) / max_diag_len;
const int64_t k = (t % dest_inner_matrix_len) % max_diag_len;
int64_t current_diag_len = j >= 0 ? std::min(n - j, m) : std::min(m + j, n);
int64_t current_pad_len = max_diag_len - current_diag_len;
// Pad left by default
bool pad_left = (alignment.first == MatrixDiag::Alignment::RIGHT && j > 0) ||
(alignment.second == MatrixDiag::Alignment::RIGHT && j < 0);
// Set non-padding values; k is the position within the current diagonal
// k takes the place of the former inner loop: for (int64_t k = 0; k < max_diag_len; k++) {
// Source pos, k offset, only effective when pad left
int64_t k_offset = (pad_left && k >= current_pad_len) ? k - current_pad_len : k;
// Calculate source offset row/col offset
size_t row_index = j >= 0 ? j + k_offset : k_offset;
size_t col_index = j >= 0 ? k_offset : k_offset - j;
size_t source_offset = i * m_ * n_ + col_index * n_ + row_index;
size_t source_offset = i * m * n + col_index * n + row_index;
// If current pos need pad, then the value is pad value
bool current_pad_flag = (pad_left && k < current_pad_len) || (!pad_left && k >= current_diag_len);
T current_pad_value = current_pad_flag ? *padding_value : *(in_value + source_offset);
int64_t j_index = u - j;
size_t dest_offset = dest_inner_matrix_size * i + j_index * max_diag_len + k;
MS_LOG(DEBUG) << "the diag j:" << j << ",k:" << k << ",k_offset:" << k_offset << ",row:" << row_index
<< ",col:" << col_index << ",j_index:" << j_index << ",current_pad_value:" << current_pad_value;
size_t dest_offset = dest_inner_matrix_len * i + j_index * max_diag_len + k;
*(out_value + dest_offset) = current_pad_value;
}
}
}
},
out_range_size_ * (u - l + 1) * max_diag_len);
return true;
}
} // namespace kernel

View File

@ -336,40 +336,24 @@ def test_matrix_set_diag(data_type):
([2], 2, 7), ([2], 7, 1), ([2], 7, 2), ([2], 7, 7), ([1, 3, 2], 1, 1), ([1, 3, 2], 1, 2),
([1, 3, 2], 1, 7), ([1, 3, 2], 2, 1), ([1, 3, 2], 2, 2), ([1, 3, 2], 2, 7), ([1, 3, 2], 7, 1),
([1, 3, 2], 7, 2), ([1, 3, 2], 7, 7)])
def test_matrix_band_part_net(band_inputs):
@pytest.mark.parametrize('dtype', [onp.int32, onp.float64])
def test_matrix_band_part_net(band_inputs, dtype):
"""
Feature: ALL TO ALL
Description: test general matrix cases for matrix_band_diag in graph mode
Expectation: the result match expected_diag_band_matrix.
"""
context.set_context(mode=context.GRAPH_MODE)
batch_shape, rows, cols = band_inputs
for dtype in [onp.int32, onp.float64]:
mat = onp.ones(batch_shape + [rows, cols]).astype(dtype)
for lower in -1, 0, 1, rows - 1:
for upper in -1, 0, 1, cols - 1:
band_np = mat
if lower >= 0:
band_np = onp.triu(band_np, -lower)
if upper >= 0:
band_np = onp.tril(band_np, upper)
if batch_shape:
band_np = onp.tile(band_np, batch_shape + [1, 1])
band = ops_wrapper.matrix_band_part(Tensor(band_np), lower, upper)
match_array(band.asnumpy(), band_np)
context.set_context(mode=context.PYNATIVE_MODE)
batch_shape, rows, cols = band_inputs
for dtype in [onp.int32, onp.float64]:
mat = onp.ones(batch_shape + [rows, cols]).astype(dtype)
for lower in -1, 0, 1, rows - 1:
for upper in -1, 0, 1, cols - 1:
band_np = mat
if lower >= 0:
band_np = onp.triu(band_np, -lower)
if upper >= 0:
band_np = onp.tril(band_np, upper)
if batch_shape:
band_np = onp.tile(band_np, batch_shape + [1, 1])
band = ops_wrapper.matrix_band_part(Tensor(band_np), lower, upper)
match_array(band.asnumpy(), band_np)
mat = onp.ones(batch_shape + [rows, cols]).astype(dtype)
for lower in -1, 0, 1, rows - 1:
for upper in -1, 0, 1, cols - 1:
band_np = mat
if lower >= 0:
band_np = onp.triu(band_np, -lower)
if upper >= 0:
band_np = onp.tril(band_np, upper)
if batch_shape:
band_np = onp.tile(band_np, batch_shape + [1, 1])
band = ops_wrapper.matrix_band_part(Tensor(band_np), lower, upper)
match_array(band.asnumpy(), band_np)