!29502 add matrix_set_diag for the GPU backend

Merge pull request !29502 from zhuzhongrui/pub_master3

Commit: 6c1ccb092e
@@ -0,0 +1,57 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/gpu/arrays/matrix_set_diag_gpu_kernel.h"
#include <algorithm>
#include <tuple>
#include "runtime/device/cpu/cpu_device_address.h"
#include "common/thread_pool.h"
#include "backend/kernel_compiler/common_utils.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(MatrixSetDiag,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddOutputAttr(kNumberTypeInt32),
                      MatrixSetDiagGpuKernelMod, int)

MS_REG_GPU_KERNEL_ONE(MatrixSetDiag,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt64)
                        .AddInputAttr(kNumberTypeInt64)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddOutputAttr(kNumberTypeInt64),
                      MatrixSetDiagGpuKernelMod, int64_t)

MS_REG_GPU_KERNEL_ONE(MatrixSetDiag,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddOutputAttr(kNumberTypeFloat32),
                      MatrixSetDiagGpuKernelMod, float)

MS_REG_GPU_KERNEL_ONE(MatrixSetDiag,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat64)
                        .AddInputAttr(kNumberTypeFloat64)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddOutputAttr(kNumberTypeFloat64),
                      MatrixSetDiagGpuKernelMod, double)
}  // namespace kernel
}  // namespace mindspore
@@ -0,0 +1,184 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_MATRIX_SET_DIAG_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_MATRIX_SET_DIAG_KERNEL_H_

#include <vector>
#include <memory>
#include <string>
#include <algorithm>
#include <utility>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
#include "backend/kernel_compiler/gpu/cuda_impl/matrix_set_diag_impl.cuh"

namespace mindspore {
namespace kernel {
template <typename T>
class MatrixSetDiagGpuKernelMod : public NativeGpuKernelMod {
 public:
  MatrixSetDiagGpuKernelMod() = default;
  ~MatrixSetDiagGpuKernelMod() override = default;

  bool Init(const CNodePtr &kernel_node) override {
    MS_EXCEPTION_IF_NULL(kernel_node);
    kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
    constexpr size_t required_input_nums = 3;
    constexpr size_t required_output_nums = 1;
    if (AnfAlgo::GetInputNum(kernel_node) != required_input_nums ||
        AnfAlgo::GetOutputTensorNum(kernel_node) != required_output_nums) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_
                        << "', the number of inputs must be 3 (input, diagonal, k) "
                           "and the number of outputs must be 1.";
    }

    // An invalid alignment attribute will throw an exception.
    auto alignment = AnfAlgo::GetNodeAttr<std::string>(kernel_node, kAlignment);
    alignment_ = GetAlignments(alignment);
    constexpr int input_index = 0;
    constexpr int diag_index = 1;
    constexpr int diag_k_index = 2;
    constexpr int output_index = 0;
    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, input_index);
    auto diag_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, diag_index);
    auto diag_k_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, diag_k_index);
    auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, output_index);

    is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name_, "input_shape") ||
                     CHECK_SHAPE_NULL(diag_shape, kernel_name_, "diag_shape") ||
                     CHECK_SHAPE_NULL(diag_k_shape, kernel_name_, "diag_k_shape") ||
                     CHECK_SHAPE_NULL(output_shape, kernel_name_, "output_shape");
    if (is_null_input_) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_
                        << "', the input shapes contain an empty shape, which is invalid; please check the inputs.";
    }

    constexpr int temporary_2d_dim = 2;
    constexpr int temporary_1d_dim = 1;
    if (SizeToInt(input_shape.size()) < temporary_2d_dim || SizeToInt(diag_shape.size()) < temporary_1d_dim ||
        input_shape != output_shape) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_
                        << "', the dimensions are invalid: the input shape must be at least 2-D, the diag shape "
                           "must be at least 1-D, and the input shape must equal the output shape.";
    }
    if (SizeToInt(diag_k_shape.size()) != temporary_1d_dim) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_
                        << "', k must be a 1-D tensor giving the diagonal range (k[0], k[1]).";
    }
    int input_rank = SizeToInt(input_shape.size());
    for (int i = 0; i < input_rank - temporary_2d_dim; ++i) {
      outer_batch_ *= SizeToInt(input_shape.at(i));
    }
    inner_rows_ = SizeToInt(input_shape.at(input_rank - temporary_2d_dim));
    inner_cols_ = SizeToInt(input_shape.at(input_rank - temporary_1d_dim));

    expected_num_diags_ =
      SizeToInt(diag_shape.size()) == input_rank ? SizeToInt(diag_shape.at(input_rank - temporary_2d_dim)) : 1;
    for (const auto &diag_sh : diag_shape) {
      diagonal_count_ *= diag_sh;
    }
    for (const auto &k_sh : diag_k_shape) {
      k_count_ *= k_sh;
    }
    InitSizeLists();
    return true;
  }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    auto input = inputs.at(kDim0);
    auto diag = inputs.at(kDim1);
    constexpr int diag_k_index = 2;
    auto k = inputs.at(diag_k_index);
    auto output = outputs.at(kDim0);

    T *input_addr = reinterpret_cast<T *>(input->addr);
    T *diag_addr = reinterpret_cast<T *>(diag->addr);
    int *diag_k_addr = reinterpret_cast<int *>(k->addr);
    T *output_addr = reinterpret_cast<T *>(output->addr);
    std::vector<int> host_k_vec(diag_k_index, 0);
    CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
                              cudaMemcpyAsync(host_k_vec.data(), diag_k_addr, k->size, cudaMemcpyDeviceToHost,
                                              reinterpret_cast<cudaStream_t>(stream_ptr)),
                              "matrix_set_diag cuda memcpy device to host failed");
    lower_diag_index_ = host_k_vec.at(kDim0);
    upper_diag_index_ = host_k_vec.at(kDim1);
    is_single_diag_ = (lower_diag_index_ == upper_diag_index_);
    if (lower_diag_index_ <= -inner_rows_ || lower_diag_index_ >= inner_cols_) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_
                        << "', lower_diag_index is invalid: it must be between "
                        << -inner_rows_ << " and " << inner_cols_;
    }
    if (upper_diag_index_ <= -inner_rows_ || upper_diag_index_ >= inner_cols_) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_
                        << "', upper_diag_index is invalid: it must be between "
                        << -inner_rows_ << " and " << inner_cols_;
    }
    if (lower_diag_index_ > upper_diag_index_) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_
                        << "', lower_diag_index must not be greater than upper_diag_index, but got "
                        << lower_diag_index_ << " > " << upper_diag_index_;
    }
    num_diags_ = upper_diag_index_ - lower_diag_index_ + 1;
    if (lower_diag_index_ != upper_diag_index_ && expected_num_diags_ != num_diags_) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_
                        << "', the band defined by lower_diag_index and upper_diag_index is not consistent "
                           "with the shape of the diagonal input.";
    }
    max_diag_len_ =
      std::min(inner_rows_ + std::min(upper_diag_index_, 0), inner_cols_ - std::max(lower_diag_index_, 0));

    // Copy the input to the output first, then overwrite the diagonal region of the output.
    CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
                              cudaMemcpyAsync(output_addr, input_addr, input->size, cudaMemcpyDeviceToDevice,
                                              reinterpret_cast<cudaStream_t>(stream_ptr)),
                              "matrix_set_diag cuda memcpy input to output failed");

    bool right_align_super_diagonal = (alignment_.first == MatrixDiag::RIGHT);
    bool right_align_sub_diagonal = (alignment_.second == MatrixDiag::RIGHT);
    MatrixSetDiag(outer_batch_, inner_rows_, inner_cols_, num_diags_, max_diag_len_, lower_diag_index_,
                  upper_diag_index_, right_align_super_diagonal, right_align_sub_diagonal, is_single_diag_, diag_addr,
                  output_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
    return true;
  }

 private:
  void InitSizeLists() override {
    input_size_list_.emplace_back(outer_batch_ * inner_rows_ * inner_cols_ * sizeof(T));
    input_size_list_.emplace_back(diagonal_count_ * sizeof(T));
    input_size_list_.emplace_back(k_count_ * sizeof(int));
    output_size_list_.emplace_back(outer_batch_ * inner_rows_ * inner_cols_ * sizeof(T));
  }
  int lower_diag_index_{0};
  int upper_diag_index_{0};
  int inner_rows_{0};
  int inner_cols_{0};
  int num_diags_{0};
  int expected_num_diags_{0};
  int max_diag_len_{0};
  int outer_batch_{1};
  size_t diagonal_count_{1};
  size_t k_count_{1};
  bool is_single_diag_{true};
  bool is_null_input_{true};
  // <super_matrix_diag_align, sub_matrix_diag_align>
  std::pair<MatrixDiag::Alignment, MatrixDiag::Alignment> alignment_{MatrixDiag::RIGHT, MatrixDiag::LEFT};
};
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_MATRIX_SET_DIAG_KERNEL_H_
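For readers following the shape bookkeeping in Init and the banded update in Launch, below is a minimal NumPy sketch of the semantics the kernel above implements. The function name, argument layout, and default alignment (super-diagonals right-aligned, sub-diagonals left-aligned, matching the header's default alignment_) are chosen for this sketch only; it mirrors the index arithmetic of the C++/CUDA code but is not part of the MindSpore API.

# Reference semantics of MatrixSetDiag, for illustration only (not the MindSpore implementation).
import numpy as np

def matrix_set_diag_reference(x, diag, k, right_align_super=True, right_align_sub=False):
    """x: (..., r, c); diag: (..., num_diags, max_diag_len), or (..., max_diag_len) for a single diagonal."""
    k_lo, k_hi = (k, k) if np.isscalar(k) else (int(k[0]), int(k[1]))
    r, c = x.shape[-2], x.shape[-1]
    out = x.copy()
    out_flat = out.reshape(-1, r, c)                 # collapse the leading batch dimensions
    diag_flat = diag.reshape(out_flat.shape[0], -1)  # per-batch diagonal storage, row-major
    max_diag_len = min(r + min(k_hi, 0), c - max(k_lo, 0))
    for b in range(out_flat.shape[0]):
        for row in range(r):
            for col in range(c):
                d = col - row                        # which diagonal this element sits on
                if d < k_lo or d > k_hi:
                    continue
                if k_lo == k_hi:                     # single-diagonal path: no alignment offset
                    out_flat[b, row, col] = diag_flat[b, col - max(d, 0)]
                    continue
                diag_len = min(r + min(0, d), c - max(0, d))
                right = (d >= 0 and right_align_super) or (d <= 0 and right_align_sub)
                offset = (max_diag_len - diag_len) if right else 0
                diag_index = k_hi - d                # row of `diag` holding this diagonal
                out_flat[b, row, col] = diag_flat[b, diag_index * max_diag_len + col - max(d, 0) + offset]
    return out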
@@ -0,0 +1,92 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "matrix_set_diag_impl.cuh"
#include <stdint.h>
#include <cuda_runtime.h>
#include <algorithm>

__inline__ __device__ int CalDiagOffset(int d, int max_diag_len, const int inner_row, const int inner_col,
                                        const bool right_align_super_diagonal, const bool right_align_sub_diagonal) {
  const bool right_align = (d >= 0 && right_align_super_diagonal) || (d <= 0 && right_align_sub_diagonal);
  const int diag_len = std::min(inner_row + std::min(0, d), inner_col - std::max(0, d));
  const int offset = (right_align) ? (max_diag_len - diag_len) : 0;
  return offset;
}

template <typename T>
__global__ void MatrixSetDiagKernel(const int outer_batch, const int inner_row, const int inner_col,
                                    const int num_diags, const int max_diag_len, const int lower_index,
                                    const int upper_index, const bool right_align_super_diagonal,
                                    const bool right_align_sub_diagonal, const bool is_single_diag, const T *diag_addr,
                                    T *output_addr) {
  int count = outer_batch * inner_row * inner_col;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x) {
    int batch = i / (inner_row * inner_col);
    int row = (i - batch * inner_row * inner_col) / inner_col;
    int col = (i - batch * inner_row * inner_col) % inner_col;
    int d = static_cast<int>(col - row);
    if (is_single_diag) {
      if (d == upper_index) {
        output_addr[i] = diag_addr[batch * max_diag_len + col - std::max(upper_index, 0)];
      }
    } else {
      int diag_index = upper_index - d;
      int offset =
        CalDiagOffset(d, max_diag_len, inner_row, inner_col, right_align_super_diagonal, right_align_sub_diagonal);
      int index_in_diag = col - std::max(d, 0) + offset;
      if (d >= lower_index && d <= upper_index) {
        output_addr[i] = diag_addr[batch * num_diags * max_diag_len + diag_index * max_diag_len + index_in_diag];
      }
    }
  }
}

template <typename T>
void MatrixSetDiag(const int outer_batch, const int inner_row, const int inner_col, const int num_diags,
                   const int max_diag_len, const int lower_index, const int upper_index,
                   const bool right_align_super_diagonal, const bool right_align_sub_diagonal,
                   const bool is_single_diag, const T *diag_addr, T *output_addr, cudaStream_t cuda_stream) {
  int count = outer_batch * inner_row * inner_col;
  MatrixSetDiagKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(
    outer_batch, inner_row, inner_col, num_diags, max_diag_len, lower_index, upper_index, right_align_super_diagonal,
    right_align_sub_diagonal, is_single_diag, diag_addr, output_addr);
  return;
}

template void MatrixSetDiag<int>(const int outer_batch, const int inner_row, const int inner_col, const int num_diags,
                                 const int max_diag_len, const int lower_index, const int upper_index,
                                 const bool right_align_super_diagonal, const bool right_align_sub_diagonal,
                                 const bool is_single_diag, const int *diag_addr, int *output_addr,
                                 cudaStream_t cuda_stream);

template void MatrixSetDiag<int64_t>(const int outer_batch, const int inner_row, const int inner_col,
                                     const int num_diags, const int max_diag_len, const int lower_index,
                                     const int upper_index, const bool right_align_super_diagonal,
                                     const bool right_align_sub_diagonal, const bool is_single_diag,
                                     const int64_t *diag_addr, int64_t *output_addr, cudaStream_t cuda_stream);

template void MatrixSetDiag<float>(const int outer_batch, const int inner_row, const int inner_col, const int num_diags,
                                   const int max_diag_len, const int lower_index, const int upper_index,
                                   const bool right_align_super_diagonal, const bool right_align_sub_diagonal,
                                   const bool is_single_diag, const float *diag_addr, float *output_addr,
                                   cudaStream_t cuda_stream);

template void MatrixSetDiag<double>(const int outer_batch, const int inner_row, const int inner_col,
                                    const int num_diags, const int max_diag_len, const int lower_index,
                                    const int upper_index, const bool right_align_super_diagonal,
                                    const bool right_align_sub_diagonal, const bool is_single_diag,
                                    const double *diag_addr, double *output_addr, cudaStream_t cuda_stream);
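The right-alignment padding in CalDiagOffset is the only subtle part of the indexing. The short check below (plain Python, with a hypothetical 3x4 matrix and k = (-1, 1); the values are illustrative only) reproduces the same rule.

# Mirror of CalDiagOffset, for a quick sanity check of the padding rule.
def cal_diag_offset(d, max_diag_len, rows, cols, right_super=True, right_sub=False):
    right_align = (d >= 0 and right_super) or (d <= 0 and right_sub)
    diag_len = min(rows + min(0, d), cols - max(0, d))
    return (max_diag_len - diag_len) if right_align else 0

rows, cols, k_lo, k_hi = 3, 4, -1, 1
max_diag_len = min(rows + min(k_hi, 0), cols - max(k_lo, 0))  # = 3
for d in range(k_lo, k_hi + 1):
    print(d, cal_diag_offset(d, max_diag_len, rows, cols, right_sub=True))
# d = -1 has length 2, so right-aligning the sub-diagonal pads it by 1;
# d = 0 and d = 1 both have length 3, so their offset is 0.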
@@ -0,0 +1,27 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_MATRIX_SET_DIAG_IMPL_CUH_
#define MINDSPORE_MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_MATRIX_SET_DIAG_IMPL_CUH_

#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void MatrixSetDiag(const int outer_batch, const int inner_row, const int inner_col, const int num_diags,
                   const int max_diag_len, const int lower_index, const int upper_index,
                   const bool right_align_super_diagonal, const bool right_align_sub_diagonal,
                   const bool is_single_diag, const T *diag_addr, T *output_addr, cudaStream_t cuda_stream);

#endif  // MINDSPORE_MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_MATRIX_SET_DIAG_IMPL_CUH_
@@ -48,6 +48,9 @@ static constexpr char kClean[] = "clean";
// Used by cholesky
static constexpr char kSplitDim[] = "split_dim";

// Used by MatrixSetDiag
static constexpr char kAlignment[] = "alignment";

// Used by MaxPool pad: The minimum value of float32
static constexpr float kSignedMinFloat = -3.402823466e+38F;

@@ -16,6 +16,7 @@
from .. import numpy as mnp
from .ops import MatrixSetDiag
from ..common import dtype as mstype
from .utils import _to_tensor
from .utils_const import _raise_value_error

@@ -68,5 +69,6 @@ def matrix_set_diag(input_x, diagonal, k=0, alignment="RIGHT_LEFT"):
        k_vec = k
    else:
        _raise_value_error("input k to indicate the diagonal region is invalid.")
    k_vec = _to_tensor(k_vec, dtype=mstype.int32)
    output = matrix_set_diag_net(input_x, diagonal, k_vec)
    return output

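As a usage sketch for the wrapper above: the import path below is an assumption (the tests in this PR reach the function through an ops_wrapper module), and the tensor values are examples only.

# Illustrative call pattern for matrix_set_diag; import path and data are assumptions.
import numpy as onp
from mindspore import Tensor
from mindspore.scipy import ops_wrapper  # assumed import path, see note above

input_mat = Tensor(onp.zeros((3, 3), onp.float32))

# k as a single int selects one diagonal; the diagonal tensor holds its values.
main_diag = Tensor(onp.array([1.0, 2.0, 3.0], onp.float32))
out_single = ops_wrapper.matrix_set_diag(input_mat, main_diag, k=0)

# k as a pair (k_lo, k_hi) selects a band; the diagonal tensor then has shape
# (num_diags, max_diag_len), and alignment controls how shorter diagonals are padded.
band = Tensor(onp.ones((2, 3), onp.float32))
out_band = ops_wrapper.matrix_set_diag(input_mat, band, k=(-1, 0), alignment="RIGHT_LEFT")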
@@ -293,13 +293,14 @@ def fat_cases(align=None):


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('data_type', [onp.int])
def test_matrix_set_diag(data_type):
    """
    Feature: ALL TO ALL
    Description: test general matrix cases for matrix_set_diag in pynative mode
    Description: test general matrix cases for matrix_set_diag in pynative or graph mode
    Expectation: the result matches expected_diag_matrix.
    """
    onp.random.seed(0)

@@ -312,20 +313,8 @@
            expected_diag_matrix = input_mat * mask + banded_mat[0]
            output = ops_wrapper.matrix_set_diag(
                Tensor(input_mat), Tensor(diagonal[0]), k=k_vec, alignment=align)
            match_matrix(output, Tensor(expected_diag_matrix))
            match_matrix(output.astype(onp.float64), Tensor(expected_diag_matrix))


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('data_type', [onp.int])
def test_graph_matrix_set_diag(data_type):
    """
    Feature: ALL TO ALL
    Description: test general matrix cases for matrix_set_diag in graph mode
    Expectation: the result matches expected_diag_matrix.
    """
    onp.random.seed(0)
    context.set_context(mode=context.GRAPH_MODE)
    for align in ALIGNMENT_LIST:
        for _, tests in [square_cases(align, data_type), tall_cases(align), fat_cases(align)]:

@@ -335,7 +324,7 @@ def test_graph_matrix_set_diag(data_type):
            expected_diag_matrix = input_mat * mask + banded_mat[0]
            output = ops_wrapper.matrix_set_diag(
                Tensor(input_mat), Tensor(diagonal[0]), k=k_vec, alignment=align)
            match_matrix(output, Tensor(expected_diag_matrix))
            match_matrix(output.astype(onp.float64), Tensor(expected_diag_matrix))


@pytest.mark.level0