forked from mindspore-Ecosystem/mindspore
!29657 opt test matrix_set_diag test codes
Merge pull request !29657 from zhuzhongrui/pub_master
commit 2d85475207
@@ -64,7 +64,7 @@ void CholeskyCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
   InitMatrixInfo(input_shape, &input_row_, &input_col_);
   auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, kOutputIndex);
   InitMatrixInfo(output_shape, &output_row_, &output_col_);
-  // if clean attribute exits, we will remain rand triangular data by clean flag, otherwise clean it to zero.
+  // If clean attribute exits, we will remain rand triangular data by clean flag, otherwise clean it to zero.
   if (AnfAlgo::HasNodeAttr(CLEAN, kernel_node)) {
     clean_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, CLEAN);
   }
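Note on the clean attribute: per the comment, clean_ decides whether the unused triangle of the factor is zeroed or left holding whatever data was already in the buffer. A minimal numpy sketch of the two behaviors (illustrative only, not the kernel code; the 9.0 fill just stands in for stale buffer contents):

import numpy as onp

a = onp.array([[4.0, 2.0], [2.0, 3.0]])   # symmetric positive-definite
stale = onp.full_like(a, 9.0)             # stand-in for leftover buffer data
factor = onp.linalg.cholesky(a)           # lower factor L with a = L @ L.T
not_cleaned = onp.tril(factor) + onp.triu(stale, k=1)  # clean_ false: upper triangle keeps stale data
cleaned = onp.tril(factor)                # clean_ true: unused triangle zeroed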
@@ -58,14 +58,14 @@ class CholeskyGpuKernelMod : public NativeGpuKernelMod {
     if (AnfAlgo::HasNodeAttr(kSplitDim, kernel_node)) {
       split_dim_ = static_cast<int>(GetAttr<int64_t>(kernel_node, kSplitDim));
     }
-    // cholesky input is sys_positive_matrix and saved by col_major in gpu backend.
+    // Cholesky input is sys_positive_matrix and saved by col_major in gpu backend.
     // so we reverse lower to upper, to fake transpose col_major input to row_major.
     if (lower_) {
       uplo_ = CUBLAS_FILL_MODE_UPPER;
     } else {
       uplo_ = CUBLAS_FILL_MODE_LOWER;
     }
-    // get CuSolver Dense matrix handler
+    // Get CuSolver Dense matrix handler
     handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCusolverDnHandle();
 
     auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kInputIndex);
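Why flipping lower_ to CUBLAS_FILL_MODE_UPPER fakes the transpose: a row-major buffer reinterpreted as column-major is the transpose of the matrix, and for a symmetric positive-definite input the upper Cholesky factor is exactly the transpose of the lower one. A small numpy sketch of that identity (illustrative, not the kernel code):

import numpy as onp

a = onp.array([[4.0, 2.0], [2.0, 3.0]])  # symmetric positive-definite
lower = onp.linalg.cholesky(a)           # L with a = L @ L.T
# Reading L's row-major bytes in column-major order yields L.T, the upper factor:
as_col_major = lower.ravel(order="C").reshape(lower.shape, order="F")
assert onp.allclose(as_col_major, lower.T)
assert onp.allclose(as_col_major.T @ as_col_major, a)  # U.T @ U == a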
@@ -83,7 +83,7 @@ class CholeskyGpuKernelMod : public NativeGpuKernelMod {
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
               const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
     CHECK_CUSOLVER_RET_WITH_ERROR(cusolverDnSetStream(handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
-                                  "cholesky bind cusolverDnSetStream failed");
+                                  "Cholesky bind cusolverDnSetStream failed");
     if (!use_split_matrix_) {
       return NoSplitLaunch(inputs, workspace, outputs, stream_ptr);
     }
@@ -182,14 +182,14 @@ class CholeskyGpuKernelMod : public NativeGpuKernelMod {
   bool NoSplitLaunch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                      const std::vector<AddressPtr> &outputs, void *stream_ptr) {
-    // here all addresses are malloc by cuda, so deal with them as device's address.
+    // Here all addresses are malloc by cuda, so deal with them as device's address.
     auto input1_addr = GetDeviceAddress<T>(inputs, kDim0);
     auto output_addr = GetDeviceAddress<T>(outputs, kDim0);
 
     auto d_array_addr = GetDeviceAddress<pointer>(workspace, kDim0);
     auto d_info_array_addr = GetDeviceAddress<int>(workspace, kDim1);
 
-    // copy input data to output, cholesky inplace output in gpu backend.
+    // Copy input data to output, cholesky inplace output in gpu backend.
     CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
                               cudaMemcpyAsync(output_addr, input1_addr, outer_batch_ * m_ * lda_ * unit_size_,
                                               cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
@@ -199,13 +199,13 @@ class CholeskyGpuKernelMod : public NativeGpuKernelMod {
       h_array_[i] = output_addr + i * lda_ * m_;
     }
 
-    // copy output's addr to d_array_addr
+    // Copy output's addr to d_array_addr
     CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
                               cudaMemcpyAsync(d_array_addr, h_array_.data(), sizeof(pointer) * outer_batch_,
                                               cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
                               "cuda memcopy Fail");
 
-    // solve to cholesky factorization according to cuSolver api, outputs have been written to input's matrix.
+    // Solve to cholesky factorization according to cuSolver api, outputs have been written to input's matrix.
     if constexpr (std::is_same_v<T, float>) {
       CHECK_CUSOLVER_RET_WITH_EXCEPT(
         kernel_node_, cusolverDnSpotrfBatched(handle_, uplo_, m_, d_array_addr, lda_, d_info_array_addr, outer_batch_),
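cusolverDnSpotrfBatched consumes the device array of per-matrix pointers built from h_array_ above and factors every batch member in place. A numpy sketch of the math it performs, one Cholesky factor per batch entry (illustrative only):

import numpy as onp

batch = onp.stack([onp.array([[4.0, 2.0], [2.0, 3.0]]),
                   onp.array([[9.0, 3.0], [3.0, 5.0]])])
factors = onp.linalg.cholesky(batch)  # stacked input: one factor per matrix
for a, l in zip(batch, factors):
    assert onp.allclose(l @ l.T, a)   # each factor reproduces its input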
@@ -49,8 +49,8 @@ class CholeskySolveGpuKernelMod : public NativeGpuKernelMod {
     if (AnfAlgo::HasNodeAttr(kLower, kernel_node)) {
       lower_ = static_cast<bool>(GetAttr<bool>(kernel_node, kLower));
     }
-    // gpu input is col major default, so need to change row major.
-    // in order to speedup it, just change lower to upper, because of cholesky input a is triangle matrix
+    // Gpu input is col major default, so need to change row major.
+    // In order to speedup it, just change lower to upper, because of cholesky input a is triangle matrix
     // when input b_col is not equal to one, maybe need a normal transpose op inplace.
     if (lower_) {
       uplo_ = CUBLAS_FILL_MODE_UPPER;
@@ -96,7 +96,7 @@ class CholeskySolveGpuKernelMod : public NativeGpuKernelMod {
                               cudaMemcpyAsync(d_b_array_addr, h_b_array_.data(), sizeof(pointer) * outer_batch_,
                                               cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
                               "cuda memcopy Fail");
-    // only support rhs = 1
+    // Only support rhs = 1
     if constexpr (std::is_same_v<T, float>) {
       CHECK_CUSOLVER_RET_WITH_EXCEPT(kernel_node_,
                                      cusolverDnSpotrsBatched(handle_, uplo_, m_, nrhs_, d_a_array_addr, lda_,
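cusolverDnSpotrsBatched back-solves A x = b from the factor computed by the batched potrf; as the comment notes, the batched variant is used here with a single right-hand side (nrhs_). A scipy sketch of the same solve for one matrix (illustrative; assumes scipy is available):

import numpy as onp
from scipy.linalg import cho_factor, cho_solve

a = onp.array([[4.0, 2.0], [2.0, 3.0]])
b = onp.array([1.0, 2.0])                    # single right-hand side, rhs = 1
x = cho_solve(cho_factor(a, lower=True), b)  # two triangular solves from one factorization
assert onp.allclose(a @ x, b)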
@@ -111,7 +111,7 @@ class CholeskySolveGpuKernelMod : public NativeGpuKernelMod {
       MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the data type only should be float or double, right now.";
     }
     size_t output_elements = outputs.at(kDim0)->size / unit_size_;
-    // copy results from written input's matrix to output's matrix.
+    // Copy results from written input's matrix to output's matrix.
     MatrixCopy(input_b_addr, output_addr, output_elements, reinterpret_cast<cudaStream_t>(stream_ptr));
     return true;
   }
@@ -13,62 +13,47 @@
 # limitations under the License.
 # ============================================================================
 """st for scipy.ops_wrapper."""
-import pytest
 import numpy as onp
+import pytest
 import mindspore.scipy.ops_wrapper as ops_wrapper
 from mindspore import context, Tensor
-
 from tests.st.scipy_st.utils import match_matrix, match_array
 
 DEFAULT_ALIGNMENT = "LEFT_LEFT"
 ALIGNMENT_LIST = ["RIGHT_LEFT", "LEFT_RIGHT", "LEFT_LEFT", "RIGHT_RIGHT"]
 
 
-def repack_diagonals(packed_diagonals,
-                     diag_index,
-                     num_rows,
-                     num_cols,
-                     align=None):
-    if align == DEFAULT_ALIGNMENT or align is None:
-        return packed_diagonals
-    align = align.split("_")
-    d_lower, d_upper = diag_index
-    batch_dims = packed_diagonals.ndim - (2 if d_lower < d_upper else 1)
-    max_diag_len = packed_diagonals.shape[-1]
-    index = (slice(None),) * batch_dims
-    repacked_diagonals = onp.zeros_like(packed_diagonals)
-    for d_index in range(d_lower, d_upper + 1):
-        diag_len = min(num_rows + min(0, d_index), num_cols - max(0, d_index))
-        row_index = d_upper - d_index
-        padding_len = max_diag_len - diag_len
-        left_align = (d_index >= 0 and
-                      align[0] == "LEFT") or (d_index <= 0 and
-                                              align[1] == "LEFT")
-        extra_dim = tuple() if d_lower == d_upper else (row_index,)
-        packed_last_dim = (slice(None),) if left_align else (slice(0, diag_len, 1),)
-        repacked_last_dim = (slice(None),) if left_align else (slice(
-            padding_len, max_diag_len, 1),)
-        packed_index = index + extra_dim + packed_last_dim
-        repacked_index = index + extra_dim + repacked_last_dim
-
-        repacked_diagonals[repacked_index] = packed_diagonals[packed_index]
-    return repacked_diagonals
+def pack_diagonals_in_matrix(matrix, num_rows, num_cols, alignment=None):
+    if alignment == DEFAULT_ALIGNMENT or alignment is None:
+        return matrix
+    packed_matrix = dict()
+    for diag_index, (diagonals, padded_diagonals) in matrix.items():
+        align = alignment.split("_")
+        d_lower, d_upper = diag_index
+        batch_dims = diagonals.ndim - (2 if d_lower < d_upper else 1)
+        max_diag_len = diagonals.shape[-1]
+        index = (slice(None),) * batch_dims
+        packed_diagonals = onp.zeros_like(diagonals)
+        for d_index in range(d_lower, d_upper + 1):
+            diag_len = min(num_rows + min(0, d_index), num_cols - max(0, d_index))
+            row_index = d_upper - d_index
+            padding_len = max_diag_len - diag_len
+            left_align = (d_index >= 0 and
+                          align[0] == "LEFT") or (d_index <= 0 and
+                                                  align[1] == "LEFT")
+            extra_dim = tuple() if d_lower == d_upper else (row_index,)
+            packed_last_dim = (slice(None),) if left_align else (slice(0, diag_len, 1),)
+            repacked_last_dim = (slice(None),) if left_align else (slice(
+                padding_len, max_diag_len, 1),)
+            packed_index = index + extra_dim + packed_last_dim
+            repacked_index = index + extra_dim + repacked_last_dim
+            packed_diagonals[repacked_index] = diagonals[packed_index]
+        packed_matrix[diag_index] = (packed_diagonals, padded_diagonals)
+    return packed_matrix
 
 
-def repack_diagonals_in_tests(tests, num_rows, num_cols, align=None):
-    # The original test cases are LEFT_LEFT aligned.
-    if align == DEFAULT_ALIGNMENT or align is None:
-        return tests
-    new_tests = dict()
-    # Loops through each case.
-    for diag_index, (packed_diagonals, padded_diagonals) in tests.items():
-        repacked_diagonals = repack_diagonals(
-            packed_diagonals, diag_index, num_rows, num_cols, align=align)
-        new_tests[diag_index] = (repacked_diagonals, padded_diagonals)
-
-    return new_tests
-
-
-def square_cases(align=None, dtype=None):
+def square_matrix(alignment=None, data_type=None):
     mat = onp.array([[[1, 2, 3, 4, 5],
                       [6, 7, 8, 9, 1],
                       [3, 4, 5, 6, 7],
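For reference, the alignment strings name how each packed diagonal row is padded: the first token applies to superdiagonals (d >= 0), the second to subdiagonals (d <= 0); LEFT pads on the right and RIGHT pads on the left, matching the left_align test in the code above. A hypothetical illustration for the k = (-1, 1) band of a 3x3 matrix:

import numpy as onp

mat = onp.array([[1, 2, 0],
                 [4, 5, 6],
                 [0, 8, 9]])
# Rows ordered d_upper..d_lower, each padded to max_diag_len = 3.
left_left = onp.array([[2, 6, 0],    # superdiagonal, padded right
                       [1, 5, 9],    # main diagonal, full length
                       [4, 8, 0]])   # subdiagonal, padded right
right_left = onp.array([[0, 2, 6],   # superdiagonal, padded left
                        [1, 5, 9],
                        [4, 8, 0]])  # subdiagonal still padded right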
@@ -78,12 +63,12 @@ def square_cases(align=None, dtype=None):
                      [5, 6, 7, 8, 9],
                      [1, 2, 3, 4, 5],
                      [6, 7, 8, 9, 1],
-                     [2, 3, 4, 5, 6]]], dtype=dtype)
+                     [2, 3, 4, 5, 6]]], dtype=data_type)
     num_rows, num_cols = mat.shape[-2:]
     tests = dict()
     # tests[d_lower, d_upper] = packed_diagonals
     tests[-1, -1] = (onp.array([[6, 4, 1, 7],
-                                [5, 2, 8, 5]], dtype=dtype),
+                                [5, 2, 8, 5]], dtype=data_type),
                      onp.array([[[0, 0, 0, 0, 0],
                                  [6, 0, 0, 0, 0],
                                  [0, 4, 0, 0, 0],
@@ -93,11 +78,11 @@ def square_cases(align=None, dtype=None):
                                 [5, 0, 0, 0, 0],
                                 [0, 2, 0, 0, 0],
                                 [0, 0, 8, 0, 0],
-                                [0, 0, 0, 5, 0]]], dtype=dtype))
+                                [0, 0, 0, 5, 0]]], dtype=data_type))
     tests[-4, -3] = (onp.array([[[8, 5],
                                  [4, 0]],
                                 [[6, 3],
-                                 [2, 0]]], dtype=dtype),
+                                 [2, 0]]], dtype=data_type),
                      onp.array([[[0, 0, 0, 0, 0],
                                  [0, 0, 0, 0, 0],
                                  [0, 0, 0, 0, 0],
@@ -107,7 +92,7 @@ def square_cases(align=None, dtype=None):
                                 [0, 0, 0, 0, 0],
                                 [0, 0, 0, 0, 0],
                                 [6, 0, 0, 0, 0],
-                                [2, 3, 0, 0, 0]]], dtype=dtype))
+                                [2, 3, 0, 0, 0]]], dtype=data_type))
     tests[-2, 1] = (onp.array([[[2, 8, 6, 3, 0],
                                 [1, 7, 5, 2, 8],
                                 [6, 4, 1, 7, 0],
@@ -115,7 +100,7 @@ def square_cases(align=None, dtype=None):
                                [[1, 7, 4, 1, 0],
                                 [9, 6, 3, 9, 6],
                                 [5, 2, 8, 5, 0],
-                                [1, 7, 4, 0, 0]]], dtype=dtype),
+                                [1, 7, 4, 0, 0]]], dtype=data_type),
                     onp.array([[[1, 2, 0, 0, 0],
                                 [6, 7, 8, 0, 0],
                                 [3, 4, 5, 6, 0],
@@ -125,13 +110,13 @@ def square_cases(align=None, dtype=None):
                                 [5, 6, 7, 0, 0],
                                 [1, 2, 3, 4, 0],
                                 [0, 7, 8, 9, 1],
-                                [0, 0, 4, 5, 6]]], dtype=dtype))
+                                [0, 0, 4, 5, 6]]], dtype=data_type))
     tests[2, 4] = (onp.array([[[5, 0, 0],
                                [4, 1, 0],
                                [3, 9, 7]],
                               [[4, 0, 0],
                                [3, 9, 0],
-                               [2, 8, 5]]], dtype=dtype),
+                               [2, 8, 5]]], dtype=data_type),
                    onp.array([[[0, 0, 3, 4, 5],
                                [0, 0, 0, 9, 1],
                                [0, 0, 0, 0, 7],
@@ -141,12 +126,12 @@ def square_cases(align=None, dtype=None):
                                 [0, 0, 0, 8, 9],
                                 [0, 0, 0, 0, 5],
                                 [0, 0, 0, 0, 0],
-                                [0, 0, 0, 0, 0]]], dtype=dtype))
+                                [0, 0, 0, 0, 0]]], dtype=data_type))
 
-    return mat, repack_diagonals_in_tests(tests, num_rows, num_cols, align)
+    return mat, pack_diagonals_in_matrix(tests, num_rows, num_cols, alignment)
 
 
-def tall_cases(align=None):
+def tall_matrix(alignment=None, data_type=None):
     mat = onp.array([[[1, 2, 3],
                       [4, 5, 6],
                       [7, 8, 9],
@@ -156,11 +141,11 @@ def tall_cases(align=None):
                      [1, 2, 3],
                      [4, 5, 6],
                      [7, 8, 9],
-                     [9, 8, 7]]])
+                     [9, 8, 7]]], dtype=data_type)
     num_rows, num_cols = mat.shape[-2:]
     tests = dict()
     tests[0, 0] = (onp.array([[1, 5, 9],
-                              [3, 2, 6]]),
+                              [3, 2, 6]], dtype=data_type),
                    onp.array([[[1, 0, 0],
                                [0, 5, 0],
                                [0, 0, 9],
@@ -168,11 +153,11 @@ def tall_cases(align=None):
                               [[3, 0, 0],
                                [0, 2, 0],
                                [0, 0, 6],
-                               [0, 0, 0]]]))
+                               [0, 0, 0]]], dtype=data_type))
     tests[-4, -3] = (onp.array([[[9, 5],
                                  [6, 0]],
                                 [[7, 8],
-                                 [9, 0]]]),
+                                 [9, 0]]], dtype=data_type),
                      onp.array([[[0, 0, 0],
                                  [0, 0, 0],
                                  [0, 0, 0],
@@ -182,11 +167,11 @@ def tall_cases(align=None):
                                 [0, 0, 0],
                                 [0, 0, 0],
                                 [7, 0, 0],
-                                [9, 8, 0]]]))
+                                [9, 8, 0]]], dtype=data_type))
     tests[-2, -1] = (onp.array([[[4, 8, 7],
                                  [7, 8, 4]],
                                 [[1, 5, 9],
-                                 [4, 8, 7]]]),
+                                 [4, 8, 7]]], dtype=data_type),
                      onp.array([[[0, 0, 0],
                                  [4, 0, 0],
                                  [7, 8, 0],
@@ -196,7 +181,7 @@ def tall_cases(align=None):
                                 [1, 0, 0],
                                 [4, 5, 0],
                                 [0, 8, 9],
-                                [0, 0, 7]]]))
+                                [0, 0, 7]]], dtype=data_type))
     tests[-2, 1] = (onp.array([[[2, 6, 0],
                                 [1, 5, 9],
                                 [4, 8, 7],
@@ -204,7 +189,7 @@ def tall_cases(align=None):
                                [[2, 3, 0],
                                 [3, 2, 6],
                                 [1, 5, 9],
-                                [4, 8, 7]]]),
+                                [4, 8, 7]]], dtype=data_type),
                     onp.array([[[1, 2, 0],
                                 [4, 5, 6],
                                 [7, 8, 9],
@@ -214,11 +199,11 @@ def tall_cases(align=None):
                                 [1, 2, 3],
                                 [4, 5, 6],
                                 [0, 8, 9],
-                                [0, 0, 7]]]))
+                                [0, 0, 7]]], dtype=data_type))
     tests[1, 2] = (onp.array([[[3, 0],
                                [2, 6]],
                               [[1, 0],
-                               [2, 3]]]),
+                               [2, 3]]], dtype=data_type),
                    onp.array([[[0, 2, 3],
                                [0, 0, 6],
                                [0, 0, 0],
@@ -228,52 +213,52 @@ def tall_cases(align=None):
                                 [0, 0, 3],
                                 [0, 0, 0],
                                 [0, 0, 0],
-                                [0, 0, 0]]]))
+                                [0, 0, 0]]], dtype=data_type))
 
-    return mat, repack_diagonals_in_tests(tests, num_rows, num_cols, align)
+    return mat, pack_diagonals_in_matrix(tests, num_rows, num_cols, alignment)
 
 
-def fat_cases(align=None):
+def fat_matrix(alignment=None, data_type=None):
     mat = onp.array([[[1, 2, 3, 4],
                       [5, 6, 7, 8],
                       [9, 1, 2, 3]],
                      [[4, 5, 6, 7],
                       [8, 9, 1, 2],
-                      [3, 4, 5, 6]]])
+                      [3, 4, 5, 6]]], dtype=data_type)
     num_rows, num_cols = mat.shape[-2:]
     tests = dict()
     tests[2, 2] = (onp.array([[3, 8],
-                              [6, 2]]),
+                              [6, 2]], dtype=data_type),
                    onp.array([[[0, 0, 3, 0],
                                [0, 0, 0, 8],
                                [0, 0, 0, 0]],
                               [[0, 0, 6, 0],
                                [0, 0, 0, 2],
-                               [0, 0, 0, 0]]]))
+                               [0, 0, 0, 0]]], dtype=data_type))
     tests[-2, 0] = (onp.array([[[1, 6, 2],
                                 [5, 1, 0],
                                 [9, 0, 0]],
                                [[4, 9, 5],
                                 [8, 4, 0],
-                                [3, 0, 0]]]),
+                                [3, 0, 0]]], dtype=data_type),
                     onp.array([[[1, 0, 0, 0],
                                 [5, 6, 0, 0],
                                 [9, 1, 2, 0]],
                                [[4, 0, 0, 0],
                                 [8, 9, 0, 0],
-                                [3, 4, 5, 0]]]))
+                                [3, 4, 5, 0]]], dtype=data_type))
     tests[-1, 1] = (onp.array([[[2, 7, 3],
                                 [1, 6, 2],
                                 [5, 1, 0]],
                                [[5, 1, 6],
                                 [4, 9, 5],
-                                [8, 4, 0]]]),
+                                [8, 4, 0]]], dtype=data_type),
                     onp.array([[[1, 2, 0, 0],
                                 [5, 6, 7, 0],
                                 [0, 1, 2, 3]],
                                [[4, 5, 0, 0],
                                 [8, 9, 1, 0],
-                                [0, 4, 5, 6]]]))
+                                [0, 4, 5, 6]]], dtype=data_type))
     tests[0, 3] = (onp.array([[[4, 0, 0],
                                [3, 8, 0],
                                [2, 7, 3],
@@ -281,21 +266,21 @@ def fat_cases(align=None):
                               [[7, 0, 0],
                                [6, 2, 0],
                                [5, 1, 6],
-                               [4, 9, 5]]]),
+                               [4, 9, 5]]], dtype=data_type),
                    onp.array([[[1, 2, 3, 4],
                                [0, 6, 7, 8],
                                [0, 0, 2, 3]],
                               [[4, 5, 6, 7],
                                [0, 9, 1, 2],
-                               [0, 0, 5, 6]]]))
-    return mat, repack_diagonals_in_tests(tests, num_rows, num_cols, align)
+                               [0, 0, 5, 6]]], dtype=data_type))
+    return mat, pack_diagonals_in_matrix(tests, num_rows, num_cols, alignment)
 
 
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.platform_x86_cpu
 @pytest.mark.env_onecard
-@pytest.mark.parametrize('data_type', [onp.int])
+@pytest.mark.parametrize('data_type', [onp.int32, onp.int64, onp.float32, onp.float64])
 def test_matrix_set_diag(data_type):
     """
     Feature: ALL TO ALL
@@ -305,10 +290,10 @@ def test_matrix_set_diag(data_type):
     onp.random.seed(0)
     context.set_context(mode=context.PYNATIVE_MODE)
     for align in ALIGNMENT_LIST:
-        for _, tests in [square_cases(align, data_type), tall_cases(align), fat_cases(align)]:
+        for _, tests in [square_matrix(align, data_type), tall_matrix(align, data_type), fat_matrix(align, data_type)]:
             for k_vec, (diagonal, banded_mat) in tests.items():
                 mask = banded_mat[0] == 0
-                input_mat = onp.random.randint(10, size=mask.shape)
+                input_mat = onp.random.randint(10, size=mask.shape).astype(dtype=data_type)
                 expected_diag_matrix = input_mat * mask + banded_mat[0]
                 output = ops_wrapper.matrix_set_diag(
                     Tensor(input_mat), Tensor(diagonal[0]), k=k_vec, alignment=align)
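The expected result is assembled from the banded mask: positions where the banded matrix is zero keep the random input, band positions take the packed diagonal values, hence input_mat * mask + banded_mat[0]. A tiny standalone check of that construction (illustrative only):

import numpy as onp

banded = onp.array([[1, 0], [0, 4]])    # band values, zeros elsewhere
input_mat = onp.array([[7, 8], [9, 5]])
mask = banded == 0
expected = input_mat * mask + banded    # [[1, 8], [9, 4]]
assert (expected == onp.array([[1, 8], [9, 4]])).all()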
@@ -316,10 +301,10 @@ def test_matrix_set_diag(data_type):
 
     context.set_context(mode=context.GRAPH_MODE)
     for align in ALIGNMENT_LIST:
-        for _, tests in [square_cases(align, data_type), tall_cases(align), fat_cases(align)]:
+        for _, tests in [square_matrix(align, data_type), tall_matrix(align, data_type), fat_matrix(align, data_type)]:
             for k_vec, (diagonal, banded_mat) in tests.items():
                 mask = banded_mat[0] == 0
-                input_mat = onp.random.randint(10, size=mask.shape)
+                input_mat = onp.random.randint(10, size=mask.shape).astype(dtype=data_type)
                 expected_diag_matrix = input_mat * mask + banded_mat[0]
                 output = ops_wrapper.matrix_set_diag(
                     Tensor(input_mat), Tensor(diagonal[0]), k=k_vec, alignment=align)