forked from mindspore-Ecosystem/mindspore
Add ops wrapper and test, add parallel launch
parent 5e6d6ee63e
commit 860f52e45d
@@ -51,15 +51,20 @@ bool MatrixBandPartCpuKernelMod<T>::Launch(const std::vector<AddressPtr> &inputs
     memcpy_s(out_value, matrix_size_ * sizeof(T), in_value, matrix_size_ * sizeof(T));
     return true;
   }
-  for (size_t i = 0; i < out_range_size_; i++) {
-    for (size_t j = 0; j < std::min(m_, l + n_); j++) {
-      const size_t s = j < l ? 0 : j - l;
-      // When j >= n - u, the end is n - 1, because end positions start from 0
-      const size_t e = j >= n_ - u ? n_ - 1 : j + u;
-      const size_t offset = i * m_ * n_ + j * n_;
-      memcpy_s(out_value + offset + s, matrix_size_ * sizeof(T), in_value + offset + s, (e - s + 1) * sizeof(T));
-    }
-  }
+  size_t diag_len = std::min(m_, l + n_);
+  CPUKernelUtils::ParallelFor(
+    [matrix_size = matrix_size_, m = m_, n = n_, diag_len, l, u, in_value, out_value](size_t spos, size_t epos) {
+      for (size_t t = spos; t < epos; t++) {
+        const size_t i = t / diag_len;
+        const size_t j = t % diag_len;
+        const size_t s = j < l ? 0 : j - l;
+        // When j >= n - u, the end is n - 1, because end positions start from 0
+        const size_t e = j >= n - u ? n - 1 : j + u;
+        const size_t offset = i * m * n + j * n;
+        (void)memcpy_s(out_value + offset + s, matrix_size * sizeof(T), in_value + offset + s, (e - s + 1) * sizeof(T));
+      }
+    },
+    out_range_size_ * diag_len);
   return true;
 }
 }  // namespace kernel
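For reference, the band-part rule that both the removed loop and the new parallel body implement can be sketched in NumPy. This is only an illustrative reference under the usual lower/upper convention (a negative value keeps that whole side, as the test below assumes); band_part_reference is a hypothetical name, not the shipped wrapper:

import numpy as onp

def band_part_reference(x, lower, upper):
    # Keep x[..., r, c] when (lower < 0 or r - c <= lower) and (upper < 0 or c - r <= upper).
    m, n = x.shape[-2], x.shape[-1]
    r = onp.arange(m).reshape(-1, 1)
    c = onp.arange(n).reshape(1, -1)
    keep = onp.ones((m, n), dtype=bool)
    if lower >= 0:
        keep = keep & ((r - c) <= lower)
    if upper >= 0:
        keep = keep & ((c - r) <= upper)
    return onp.where(keep, x, onp.zeros_like(x))

# The parallel launch flattens the (batch, row) pairs: task index t maps to
# i = t // diag_len (batch) and j = t % diag_len (row), and each task copies
# the kept slice [s, e] of its row, mirroring the memcpy_s in the lambda above.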
@@ -60,8 +60,7 @@ bool MatrixDiagPartCpuKernelMod<T>::Launch(const std::vector<AddressPtr> &inputs
   // New diagonal matrix m * n matrix, n dimension
   int64_t max_diag_len =
     std::min(m_ + std::min(u, static_cast<int64_t>(0)), n_ + std::min(-l, static_cast<int64_t>(0)));
-  MS_LOG(DEBUG) << "Num_diags:" << num_diags << ",max_diag_len:" << max_diag_len;
-  int64_t dest_inner_matrix_size = num_diags * max_diag_len;
+  int64_t dest_inner_matrix_len = num_diags * max_diag_len;
   out_shapes_ = shapes_;
   // Set dynamic shape and dtype
   if (!node_wpt_.expired()) {
@@ -75,33 +74,35 @@ bool MatrixDiagPartCpuKernelMod<T>::Launch(const std::vector<AddressPtr> &inputs
     auto dtype = AnfAlgo::GetOutputDeviceDataType(node_, 0);
     AnfAlgo::SetOutputInferTypeAndShape({dtype}, {out_shapes_}, node_.get());
   }
-  for (int64_t i = 0; i < out_range_size_; i++) {
-    // The j_index means the current dest row index
-    for (int64_t j = u; j >= l; j--) {
-      int64_t current_diag_len = j >= 0 ? std::min(n_ - j, m_) : std::min(m_ + j, n_);
-      int64_t current_pad_len = max_diag_len - current_diag_len;
-      // Pad left by default
-      bool pad_left = (alignment_.first == MatrixDiag::Alignment::RIGHT && j > 0) ||
-                      (alignment_.second == MatrixDiag::Alignment::RIGHT && j < 0);
-      // Set non-padding values, k is the current diag col index
-      for (int64_t k = 0; k < max_diag_len; k++) {
+  CPUKernelUtils::ParallelFor(
+    [m = m_, n = n_, max_diag_len, u, l, in_value, out_value, dest_inner_matrix_len, padding_value,
+     alignment = alignment_](size_t spos, size_t epos) {
+      for (size_t t = spos; t < epos; t++) {
+        const int64_t i = t / dest_inner_matrix_len;
+        const int64_t j = u - (t % dest_inner_matrix_len) / max_diag_len;
+        const int64_t k = (t % dest_inner_matrix_len) % max_diag_len;
+        int64_t current_diag_len = j >= 0 ? std::min(n - j, m) : std::min(m + j, n);
+        int64_t current_pad_len = max_diag_len - current_diag_len;
+        // Pad left by default
+        bool pad_left = (alignment.first == MatrixDiag::Alignment::RIGHT && j > 0) ||
+                        (alignment.second == MatrixDiag::Alignment::RIGHT && j < 0);
+        // Set non-padding values, k is the current diag col index
+        // The flattened index replaces the former k loop: for (int64_t k = 0; k < max_diag_len; k++)
         // Source pos, k offset, only effective when padding left
         int64_t k_offset = (pad_left && k >= current_pad_len) ? k - current_pad_len : k;
         // Calculate source row/col offset
         size_t row_index = j >= 0 ? j + k_offset : k_offset;
         size_t col_index = j >= 0 ? k_offset : k_offset - j;
-        size_t source_offset = i * m_ * n_ + col_index * n_ + row_index;
+        size_t source_offset = i * m * n + col_index * n + row_index;
         // If the current pos needs padding, the value is the pad value
         bool current_pad_flag = (pad_left && k < current_pad_len) || (!pad_left && k >= current_diag_len);
         T current_pad_value = current_pad_flag ? *padding_value : *(in_value + source_offset);
         int64_t j_index = u - j;
-        size_t dest_offset = dest_inner_matrix_size * i + j_index * max_diag_len + k;
-        MS_LOG(DEBUG) << "the diag j:" << j << ",k:" << k << ",k_offset:" << k_offset << ",row:" << row_index
-                      << ",col:" << col_index << ",j_index:" << j_index << ",current_pad_value:" << current_pad_value;
+        size_t dest_offset = dest_inner_matrix_len * i + j_index * max_diag_len + k;
         *(out_value + dest_offset) = current_pad_value;
       }
-    }
-  }
+    },
+    out_range_size_ * (u - l + 1) * max_diag_len);
   return true;
 }
 }  // namespace kernel
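As a reading aid for the flattened index used above: each task t decodes into (batch i, diagonal j, slot k). A minimal Python sketch of that decoding (illustrative only; decode_task_index is a hypothetical helper, not kernel code):

def decode_task_index(t, u, max_diag_len, dest_inner_matrix_len):
    # Mirrors the index math in the lambda: dest_inner_matrix_len = num_diags * max_diag_len.
    i = t // dest_inner_matrix_len                        # batch index
    j = u - (t % dest_inner_matrix_len) // max_diag_len   # diagonal index, walking from u down to l
    k = (t % dest_inner_matrix_len) % max_diag_len        # slot within the padded diagonal
    return i, j, k

# Example: u = 1, l = -1 (num_diags = 3), max_diag_len = 2, dest_inner_matrix_len = 6:
# t = 0..5 -> (0, 1, 0), (0, 1, 1), (0, 0, 0), (0, 0, 1), (0, -1, 0), (0, -1, 1)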
@@ -336,40 +336,24 @@ def test_matrix_set_diag(data_type):
                           ([2], 2, 7), ([2], 7, 1), ([2], 7, 2), ([2], 7, 7), ([1, 3, 2], 1, 1), ([1, 3, 2], 1, 2),
                           ([1, 3, 2], 1, 7), ([1, 3, 2], 2, 1), ([1, 3, 2], 2, 2), ([1, 3, 2], 2, 7), ([1, 3, 2], 7, 1),
                           ([1, 3, 2], 7, 2), ([1, 3, 2], 7, 7)])
-def test_matrix_band_part_net(band_inputs):
+@pytest.mark.parametrize('dtype', [onp.int32, onp.float64])
+def test_matrix_band_part_net(band_inputs, dtype):
     """
     Feature: ALL TO ALL
     Description: test general matrix cases for matrix_band_diag in graph mode
     Expectation: the result matches expected_diag_band_matrix.
     """
     context.set_context(mode=context.GRAPH_MODE)
     batch_shape, rows, cols = band_inputs
-    for dtype in [onp.int32, onp.float64]:
-        mat = onp.ones(batch_shape + [rows, cols]).astype(dtype)
-        for lower in -1, 0, 1, rows - 1:
-            for upper in -1, 0, 1, cols - 1:
-                band_np = mat
-                if lower >= 0:
-                    band_np = onp.triu(band_np, -lower)
-                if upper >= 0:
-                    band_np = onp.tril(band_np, upper)
-                if batch_shape:
-                    band_np = onp.tile(band_np, batch_shape + [1, 1])
-                band = ops_wrapper.matrix_band_part(Tensor(band_np), lower, upper)
-                match_array(band.asnumpy(), band_np)
-
-    context.set_context(mode=context.PYNATIVE_MODE)
-    batch_shape, rows, cols = band_inputs
-    for dtype in [onp.int32, onp.float64]:
-        mat = onp.ones(batch_shape + [rows, cols]).astype(dtype)
-        for lower in -1, 0, 1, rows - 1:
-            for upper in -1, 0, 1, cols - 1:
-                band_np = mat
-                if lower >= 0:
-                    band_np = onp.triu(band_np, -lower)
-                if upper >= 0:
-                    band_np = onp.tril(band_np, upper)
-                if batch_shape:
-                    band_np = onp.tile(band_np, batch_shape + [1, 1])
-                band = ops_wrapper.matrix_band_part(Tensor(band_np), lower, upper)
-                match_array(band.asnumpy(), band_np)
+    mat = onp.ones(batch_shape + [rows, cols]).astype(dtype)
+    for lower in -1, 0, 1, rows - 1:
+        for upper in -1, 0, 1, cols - 1:
+            band_np = mat
+            if lower >= 0:
+                band_np = onp.triu(band_np, -lower)
+            if upper >= 0:
+                band_np = onp.tril(band_np, upper)
+            if batch_shape:
+                band_np = onp.tile(band_np, batch_shape + [1, 1])
+            band = ops_wrapper.matrix_band_part(Tensor(band_np), lower, upper)
+            match_array(band.asnumpy(), band_np)
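The test change moves the dtype loop into @pytest.mark.parametrize, so each dtype is collected and reported as its own case. A minimal standalone illustration of the pattern (hypothetical test name, not part of the suite):

import numpy as onp
import pytest

@pytest.mark.parametrize('dtype', [onp.int32, onp.float64])
def test_dtype_parametrize_example(dtype):
    # pytest collects one case per dtype instead of hiding both inside a loop
    assert onp.ones((2, 2)).astype(dtype).dtype == dtype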