From 815e49910ec29e6ba277ac2d067d3f436f79eea0 Mon Sep 17 00:00:00 2001 From: z00512249 Date: Mon, 17 Jan 2022 15:03:49 +0800 Subject: [PATCH] add kernel matrix_set_diag of backend cpu --- .../backend/kernel_compiler/common_utils.cc | 20 ++ .../backend/kernel_compiler/common_utils.h | 58 +++ .../backend/kernel_compiler/cpu/cpu_kernel.h | 3 +- .../cpu/eigen/cholesky_cpu_kernel.cc | 3 +- .../cpu/matrix_set_diag_cpu_kernel.cc | 200 +++++++++++ .../cpu/matrix_set_diag_cpu_kernel.h | 90 +++++ mindspore/python/mindspore/scipy/__init__.py | 2 +- mindspore/python/mindspore/scipy/ops.py | 22 ++ .../python/mindspore/scipy/ops_wrapper.py | 101 ++++++ tests/st/scipy_st/test_linalg.py | 88 ++--- tests/st/scipy_st/test_ops_wrapper.py | 337 ++++++++++++++++++ 11 files changed, 877 insertions(+), 47 deletions(-) create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/matrix_set_diag_cpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/matrix_set_diag_cpu_kernel.h create mode 100644 mindspore/python/mindspore/scipy/ops_wrapper.py create mode 100644 tests/st/scipy_st/test_ops_wrapper.py diff --git a/mindspore/ccsrc/backend/kernel_compiler/common_utils.cc b/mindspore/ccsrc/backend/kernel_compiler/common_utils.cc index 0abafbed5fa..fe71f46c342 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/common_utils.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/common_utils.cc @@ -102,6 +102,26 @@ const std::unordered_map fusion_type_name_maps = { {FusionType::DROPOUT_DOMASKV3D, "DropOutDoMaskV3D"}, {FusionType::UNKNOWN_FUSION_TYPE, ""}}; +std::pair GetAlignments(const std::string &alignment) { + auto alignment_iter = MatrixDiag::AlignmentMap.find(alignment); + if (alignment_iter == MatrixDiag::AlignmentMap.end()) { + MS_LOG(EXCEPTION) << "For current kernel, input alignment is invalid: " << alignment + << ". please limit it to {RIGHT_LEFT, LEFT_RIGHT, RIGHT_RIGHT, LEFT_LEFT}"; + } + return alignment_iter->second; +} + +int CalDiagOffset(int diag_index, int max_diag_len, int inner_rows, int inner_cols, + const std::pair &alignment) { + bool right_align_super_diagonal = (alignment.first == MatrixDiag::RIGHT); + bool right_align_sub_diagonal = (alignment.second == MatrixDiag::RIGHT); + const bool right_align = + (diag_index >= 0 && right_align_super_diagonal) || (diag_index <= 0 && right_align_sub_diagonal); + const int diag_len = std::min(inner_rows + std::min(0, diag_index), inner_cols - std::max(0, diag_index)); + const int offset = (right_align) ? (max_diag_len - diag_len) : 0; + return offset; +} + std::string GetFusionNameByType(const kernel::FusionType &type) { auto iter = fusion_type_name_maps.find(type); if (iter == fusion_type_name_maps.end()) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/common_utils.h b/mindspore/ccsrc/backend/kernel_compiler/common_utils.h index b607dc51bf3..2b5547ff2cf 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/common_utils.h +++ b/mindspore/ccsrc/backend/kernel_compiler/common_utils.h @@ -46,6 +46,17 @@ constexpr unsigned int AUTODIFF_COMPILE_OVERTIME = 600; const std::vector support_devices = {"aicore", "aicpu", "cuda"}; +// an enum to indicate a vector or matrix alignment direction. +// real_data: [1,2,3] left_align: [1,2,3,0] right_align:[0,1,2,3] +namespace MatrixDiag { +enum Alignment { RIGHT = 0, LEFT = 1 }; +static const mindspore::HashMap> AlignmentMap{ + {"RIGHT_LEFT", {MatrixDiag::RIGHT, MatrixDiag::LEFT}}, + {"LEFT_RIGHT", {MatrixDiag::LEFT, MatrixDiag::RIGHT}}, + {"RIGHT_RIGHT", {MatrixDiag::RIGHT, MatrixDiag::RIGHT}}, + {"LEFT_LEFT", {MatrixDiag::LEFT, MatrixDiag::LEFT}}}; +} // namespace MatrixDiag + struct KernelMetaInfo { uintptr_t func_stub_; uint32_t block_dim_; @@ -72,6 +83,53 @@ class KernelMeta { std::unordered_map kernel_meta_map_; }; +class MatrixInfo { + public: + explicit MatrixInfo(size_t max_index, const std::vector &matrix_shapes) + : max_index_(max_index), shapes_(matrix_shapes) { + current_indexes_.resize(shapes_.size(), 0); + } + ~MatrixInfo() = default; + bool SetIndex(size_t start, size_t end) { + // check data from start to end whether valid. + if (start < min_index || end > max_index_ || start >= end) { + return false; + } + // initial current indexes. + int last_rank = SizeToInt(current_indexes_.size()) - 1; + for (int i = last_rank; start != 0 && i >= 0; --i) { + current_indexes_[i] = start % shapes_.at(i); + start = start / shapes_.at(i); + } + return true; + } + std::vector IndexIterator() { + if (is_first_iterator_) { + is_first_iterator_ = false; + return current_indexes_; + } + size_t last_rank = current_indexes_.size() - 1; + current_indexes_[last_rank]++; + for (size_t i = last_rank; current_indexes_.at(i) >= shapes_.at(i) && i > 0; --i) { + current_indexes_[i] = 0; + current_indexes_[i - 1] += 1; + } + is_first_iterator_ = false; + return current_indexes_; + } + + private: + bool is_first_iterator_{true}; + size_t min_index{0}; + size_t max_index_{1}; + std::vector shapes_; + std::vector current_indexes_; +}; +using MatrixInfoPtr = std::shared_ptr; + +std::pair GetAlignments(const std::string &alignment); +int CalDiagOffset(int diag_index, int max_diag_len, int inner_rows, int inner_cols, + const std::pair &alignment); std::string GetCompilerCachePath(); bool CheckCache(const std::string &kernel_name); KernelPackPtr SearchCache(const std::string &kernel_name, const std::string &processor); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h index 2f58dd153c1..afe22b39902 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h @@ -99,6 +99,7 @@ constexpr char MODE[] = "mode"; constexpr char UNIT_DIAGONAL[] = "unit_diagonal"; constexpr char C_EIEH_VECTOR[] = "compute_eigenvectors"; constexpr char ADJOINT[] = "adjoint"; +constexpr char ALIGNMENT[] = "alignment"; struct ParallelSearchInfo { double min_cost_time{DBL_MAX}; @@ -111,7 +112,7 @@ struct ParallelSearchInfo { class CpuDynamicKernel : public device::DynamicKernel { public: explicit CpuDynamicKernel(const CNodePtr &cnode_ptr) : DynamicKernel(nullptr, cnode_ptr) {} - ~CpuDynamicKernel() = default; + ~CpuDynamicKernel() override = default; void UpdateArgs() override; void PostExecute() final { MS_LOG(EXCEPTION) << "`PostExecute()` should not invoked with cpu backend"; }; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/cholesky_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/cholesky_cpu_kernel.cc index b2a8d4b7d76..d60520dd645 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/cholesky_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/cholesky_cpu_kernel.cc @@ -45,7 +45,8 @@ void CholeskyCPUKernel::InitMatrixInfo(const std::vector &shape, size *col = shape.at(shape.size() - kColIndex); } if (*row != *col) { - MS_LOG_EXCEPTION << kernel_name_ << "input shape is invalid: " << *row << ", " << *col; + MS_LOG_EXCEPTION << kernel_name_ << "input shape is invalid. " + << "cholesky expects a square matrix. but input shape is:" << *row << ", " << *col; } } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/matrix_set_diag_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/matrix_set_diag_cpu_kernel.cc new file mode 100644 index 00000000000..b78bd86c447 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/matrix_set_diag_cpu_kernel.cc @@ -0,0 +1,200 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/cpu/matrix_set_diag_cpu_kernel.h" +#include +#include +#include "runtime/device/cpu/cpu_device_address.h" +#include "common/thread_pool.h" +#include "backend/kernel_compiler/common_utils.h" + +namespace mindspore { +namespace kernel { +void MatrixSetDiagCPUKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + kernel_name_ = AnfAlgo::GetCNodeName(kernel_node); + constexpr size_t required_input_nums = 3; + constexpr size_t required_output_nums = 1; + if (AnfAlgo::GetInputNum(kernel_node) != required_input_nums || + AnfAlgo::GetOutputTensorNum(kernel_node) != required_output_nums) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ + << "', the input nums are required to [input, diagonal, " + "k, alignment] for 3 and the output nums is require to 1."; + } + + // invalid alignment will throw an exception. + auto alignment = AnfAlgo::GetNodeAttr(kernel_node, ALIGNMENT); + alignment_ = GetAlignments(alignment); + constexpr int input_index = 0; + constexpr int diag_index = 1; + constexpr int diag_k_index = 2; + constexpr int output_index = 0; + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, input_index); + auto diag_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, diag_index); + auto diag_k_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, diag_k_index); + auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, output_index); + + constexpr int temporary_2d_dim = 2; + constexpr int temporary_1d_dim = 1; + if (SizeToInt(input_shape.size()) < temporary_2d_dim || SizeToInt(diag_shape.size()) < temporary_1d_dim || + input_shape != output_shape) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ + << "', the dimension of input is invalid for input shape greater than 2D, diag shape " + "greater than 1D, input shape should equal to output shape."; + } + if (SizeToInt(diag_k_shape.size()) != temporary_1d_dim) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ + << "', the dimension of diag_region's dim should be limited to range (k[0],k[1])."; + } + int input_rank = SizeToInt(input_shape.size()); + for (int i = 0; i < input_rank - temporary_2d_dim; ++i) { + outer_batch_ *= SizeToInt(input_shape.at(i)); + } + input_shape_ = input_shape; + inner_rows_ = SizeToInt(input_shape.at(input_rank - temporary_2d_dim)); + inner_cols_ = SizeToInt(input_shape.at(input_rank - temporary_1d_dim)); + + expected_num_diags_ = + SizeToInt(diag_shape.size()) == input_rank ? SizeToInt(diag_shape.at(input_rank - temporary_2d_dim)) : 1; + + data_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); +} + +bool MatrixSetDiagCPUKernel::Launch(const std::vector &inputs, + const std::vector &workspaces, + const std::vector &outputs) { + if (data_type_ == kNumberTypeFloat16) { + LaunchKernel(inputs, workspaces, outputs); + } else if (data_type_ == kNumberTypeFloat32) { + LaunchKernel(inputs, workspaces, outputs); + } else if (data_type_ == kNumberTypeFloat64) { + LaunchKernel(inputs, workspaces, outputs); + } else if (data_type_ == kNumberTypeInt32) { + LaunchKernel(inputs, workspaces, outputs); + } else { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ + << "', the data_type of input should be float16, float32, float64, int, int32." + " but got " + << TypeIdToType(data_type_)->ToString(); + } + return true; +} + +template +void MatrixSetDiagCPUKernel::LaunchKernel(const std::vector &inputs, + const std::vector &workspaces, + const std::vector &outputs) { + auto input = inputs.at(0); + auto diag = inputs.at(1); + constexpr int diag_k_index = 2; + auto k = inputs.at(diag_k_index); + auto output = outputs.at(0); + + T *input_addr = reinterpret_cast(input->addr); + T *diag_addr = reinterpret_cast(diag->addr); + int *diag_k_addr = reinterpret_cast(k->addr); + T *output_addr = reinterpret_cast(output->addr); + lower_diag_index_ = diag_k_addr[0]; + upper_diag_index_ = diag_k_addr[1]; + is_single_diag_ = (lower_diag_index_ == upper_diag_index_); + if (lower_diag_index_ <= -inner_rows_ || lower_diag_index_ >= inner_cols_) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ + << "', the dimension of diag_region's lower_diag_index is invalid, which must be between " + << -inner_rows_ << " and " << inner_cols_; + } + if (upper_diag_index_ <= -inner_rows_ || upper_diag_index_ >= inner_cols_) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ + << "', the dimension of diag_region's upper_diag_index is invalid, which must be between " + << -inner_rows_ << " and " << inner_cols_; + } + + if (lower_diag_index_ > upper_diag_index_) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ + << "', the dimension of diag_region's lower_diag_index_ must less than upper_diag_index " + << lower_diag_index_ << " < " << upper_diag_index_; + } + num_diags_ = upper_diag_index_ - lower_diag_index_ + 1; + if (lower_diag_index_ != upper_diag_index_ && expected_num_diags_ != num_diags_) { + MS_LOG(EXCEPTION) + << "For '" << kernel_name_ + << "', the dimension of diag_region's lower_diag_index and upper_diag_index are not consistent with input shape."; + } + max_diag_len_ = std::min(inner_rows_ + std::min(upper_diag_index_, 0), inner_cols_ - std::max(lower_diag_index_, 0)); + + // copy input to output first, then set diagonal value to output. + (void)memcpy_s(output_addr, output->size, input_addr, input->size); + std::vector tasks; + // an arg which depends on hardware. + auto thread_pool = GetActorMgrInnerThreadPool(); + size_t task_nums = thread_pool->GetKernelThreadNum(); + if (task_nums == 0) { + MS_LOG(EXCEPTION) << "MatrixSetDiagCPUKernel get kernel thread_pool nums, but kernel thread is 0!"; + } + tasks.reserve(task_nums); + size_t max_index = IntToSize(outer_batch_ * inner_rows_ * inner_cols_); + size_t region = IntToSize(std::div(SizeToInt(max_index), SizeToInt(task_nums)).quot); + region = (region == 0) ? max_index : region; + for (size_t start = 0; start < max_index; start += region) { + size_t end = start + region; + if (end > max_index) { + end = max_index; + } + (void)tasks.emplace_back([this, max_index, start, end, diag_addr, output_addr]() { + MatrixInfoPtr matrix_info = std::make_shared(max_index, input_shape_); + if (!matrix_info->SetIndex(start, end)) { + MS_LOG(EXCEPTION) << "current data indexes are invalid : [" << start << ", " << end + << "]. you should limit them in [0, " << max_index << "]."; + } + auto get_out_batch = [](const std::vector ¤t_indexes) { + constexpr size_t last_two_dims = 2; + int out_batch = 1; + for (size_t i = 0; i < current_indexes.size() - last_two_dims; ++i) { + out_batch *= (SizeToInt(current_indexes.at(i)) + 1); + } + size_t inner_row = current_indexes.at(current_indexes.size() - last_two_dims); + size_t inner_col = current_indexes.at(current_indexes.size() - 1); + std::tuple flatten_3d_shape = std::make_tuple(out_batch - 1, inner_row, inner_col); + return flatten_3d_shape; + }; + for (size_t inner = start; inner < end; ++inner) { + std::vector current_indexes = matrix_info->IndexIterator(); + auto flatten_3d_shape = get_out_batch(current_indexes); + int batch = std::get<0>(flatten_3d_shape); + int m = std::get<1>(flatten_3d_shape); + constexpr size_t col_index = 2; + int n = std::get(flatten_3d_shape); + int d = n - m; + if (is_single_diag_) { + if (d == upper_diag_index_) { + output_addr[inner] = diag_addr[batch * max_diag_len_ + n - std::max(upper_diag_index_, 0)]; + } + } else { + int diag_index = upper_diag_index_ - d; + int offset = CalDiagOffset(d, max_diag_len_, inner_rows_, inner_cols_, alignment_); + int index_in_diag = n - std::max(d, 0) + offset; + if (d >= lower_diag_index_ && d <= upper_diag_index_) { + output_addr[inner] = + diag_addr[batch * num_diags_ * max_diag_len_ + diag_index * max_diag_len_ + index_in_diag]; + } + } + } + return common::SUCCESS; + }); + } + ParallelLaunch(tasks); +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/matrix_set_diag_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/matrix_set_diag_cpu_kernel.h new file mode 100644 index 00000000000..85dbc4fac7f --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/matrix_set_diag_cpu_kernel.h @@ -0,0 +1,90 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MATRIX_SET_DIAG_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MATRIX_SET_DIAG_KERNEL_H_ + +#include +#include +#include +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" +#include "backend/kernel_compiler/common_utils.h" +namespace mindspore { +namespace kernel { +class MatrixSetDiagCPUKernel : public CPUKernel { + public: + MatrixSetDiagCPUKernel() = default; + ~MatrixSetDiagCPUKernel() override = default; + void InitKernel(const CNodePtr &kernel_node) override; + bool Launch(const std::vector &inputs, const std::vector &workspaces, + const std::vector &outputs) override; + + private: + template + void LaunchKernel(const std::vector &inputs, const std::vector &workspaces, + const std::vector &outputs); + int lower_diag_index_{0}; + int upper_diag_index_{0}; + int inner_rows_{0}; + int inner_cols_{0}; + int num_diags_{0}; + int expected_num_diags_{0}; + int max_diag_len_{0}; + int outer_batch_{1}; + bool is_single_diag_{true}; + std::vector input_shape_; + // + std::pair alignment_{MatrixDiag::RIGHT, MatrixDiag::LEFT}; + TypeId data_type_{0}; +}; + +MS_REG_CPU_KERNEL(MatrixSetDiag, + KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + MatrixSetDiagCPUKernel) + +MS_REG_CPU_KERNEL(MatrixSetDiag, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat16), + MatrixSetDiagCPUKernel) + +MS_REG_CPU_KERNEL(MatrixSetDiag, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat32), + MatrixSetDiagCPUKernel) + +MS_REG_CPU_KERNEL(MatrixSetDiag, + KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat64), + MatrixSetDiagCPUKernel) +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MATRIX_SET_DIAG_KERNEL_H_ diff --git a/mindspore/python/mindspore/scipy/__init__.py b/mindspore/python/mindspore/scipy/__init__.py index 128b2db5a36..c90ed38f682 100644 --- a/mindspore/python/mindspore/scipy/__init__.py +++ b/mindspore/python/mindspore/scipy/__init__.py @@ -14,4 +14,4 @@ # ============================================================================ """Scipy-like interfaces in mindspore.""" -from . import linalg, optimize, sparse +from . import linalg, optimize, sparse, ops_wrapper diff --git a/mindspore/python/mindspore/scipy/ops.py b/mindspore/python/mindspore/scipy/ops.py index 58d086ea443..5905653eef6 100644 --- a/mindspore/python/mindspore/scipy/ops.py +++ b/mindspore/python/mindspore/scipy/ops.py @@ -332,3 +332,25 @@ class LUSolver(PrimitiveWithInfer): 'value': None } return output + + +class MatrixSetDiag(PrimitiveWithInfer): + """ + Inner API to set a [..., M, N] matrix's diagonals by range[k[0], k[1]]. + """ + + @prim_attr_register + def __init__(self, alignment: str): + super().__init__(name="MatrixSetDiag") + self.init_prim_io_names(inputs=['input_x', 'diagonal', 'k'], outputs=['output']) + self.alignment = validator.check_value_type("alignment", alignment, [str], self.name) + + def __infer__(self, input_x, diagonal, k): + in_shape = list(input_x['shape']) + in_dtype = input_x['dtype'] + output = { + 'shape': tuple(in_shape), + 'dtype': in_dtype, + 'value': None + } + return output diff --git a/mindspore/python/mindspore/scipy/ops_wrapper.py b/mindspore/python/mindspore/scipy/ops_wrapper.py new file mode 100644 index 00000000000..6137fb5655c --- /dev/null +++ b/mindspore/python/mindspore/scipy/ops_wrapper.py @@ -0,0 +1,101 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Linear algebra submodule""" +from .. import numpy as mnp +from .ops import MatrixSetDiag +from ..common import dtype as mstype +from .utils_const import _raise_value_error + + +def matrix_set_diag(input_x, diagonal, k=0, alignment="RIGHT_LEFT"): + """ + Returns a batched matrix tensor with new batched diagonal values. + Given `input` and `diagonal`, this operation returns a tensor with the same shape and values as `input`, + except for the specified diagonals of the innermost matrices. These will be overwritten by the values in `diagonal`. + `input` has `r+1` dimensions `[I, J, ..., L, M, N]`. When `k` is scalar or `k[0] == k[1]`, + `diagonal` has `r` dimensions `[I, J, ..., L, max_diag_len]`. Otherwise, it has `r+1` dimensions + `[I, J, ..., L, num_diags, max_diag_len]`. `num_diags` is the number of diagonals, `num_diags = k[1] - k[0] + 1`. + `max_diag_len` is the longest diagonal in the range `[k[0], k[1]]`, + `max_diag_len = min(M + min(k[1], 0), N + min(-k[0], 0))` The output is a tensor of rank `k+1` with + dimensions `[I, J, ..., L, M, N]`. If `k` is scalar or `k[0] == k[1]`: + ``` + output[i, j, ..., l, m, n] + = diagonal[i, j, ..., l, n-max(k[1], 0)] ; if n - m == k[1] + input[i, j, ..., l, m, n] ; otherwise + ``` + Otherwise, + ``` + output[i, j, ..., l, m, n] + = diagonal[i, j, ..., l, diag_index, index_in_diag] ; if k[0] <= d <= k[1] + input[i, j, ..., l, m, n] ; otherwise + ``` + where `d = n - m`, `diag_index = k[1] - d`, and `index_in_diag = n - max(d, 0) + offset`. + `offset` is zero except when the alignment of the diagonal is to the right. + ``` + offset = max_diag_len - diag_len(d) ; if (`align` in {RIGHT_LEFT, RIGHT_RIGHT} + and `d >= 0`) or + (`align` in {LEFT_RIGHT, RIGHT_RIGHT} + and `d <= 0`) + 0 ; otherwise + ``` + where `diag_len(d) = min(cols - max(d, 0), rows + min(d, 0))`. + + Args: + input_x (Tensor): a :math:`(..., M, N)` matrix to be set diag. + diagonal (Tensor): a :math`(..., max_diag_len)`, or `(..., num_diags, max_diag_len)` vector to be placed to + output's diags. + k (Tensor): a scalar or 1D list. it's max length is to which indicates the diag's lower index and upper index. + (k[0], k[1]). + alignment (str): Some diagonals are shorter than `max_diag_len` and need to be padded. + `align` is a string specifying how superdiagonals and subdiagonals should be aligned, + respectively. There are four possible alignments: "RIGHT_LEFT" (default), + "LEFT_RIGHT", "LEFT_LEFT", and "RIGHT_RIGHT". "RIGHT_LEFT" aligns superdiagonals to + the right (left-pads the row) and subdiagonals to the left (right-pads the row). + + Returns: + - Tensor, :math:`(...,M, N)`. a batched matrix with the same shape and values as `input`, + except for the specified diagonals of the innermost matrices. + + Supported Platforms: + ``CPU`` ``GPU`` + + Examples: + >>> import numpy as onp + >>> from mindspore.common import Tensor + >>> from mindspore.scipy.ops_wrapper import matrix_set_diag + >>> input_x = Tensor( + >>> onp.array([[[7, 7, 7, 7],[7, 7, 7, 7], [7, 7, 7, 7]], + >>> [[7, 7, 7, 7],[7, 7, 7, 7],[7, 7, 7, 7]]])).astype(onp.int) + >>> diagonal = Tensor(onp.array([[1, 2, 3],[4, 5, 6]])).astype(onp.int) + >>> output = matrix_set_diag(input_x, diagonal) + >>> print(output) + >>> [[[1 7 7 7] + [7 2 7 7] + [7 7 3 7]] + + [[4 7 7 7] + [7 5 7 7] + [7 7 6 7]] + """ + matrix_set_diag_net = MatrixSetDiag(alignment) + k_vec = mnp.zeros((2,), dtype=mstype.int32) + if isinstance(k, int): + k_vec += k + elif isinstance(k, (list, tuple)): + k_vec = k + else: + _raise_value_error("input k to indicate diagonal region is invalid.") + output = matrix_set_diag_net(input_x, diagonal, k_vec) + return output diff --git a/tests/st/scipy_st/test_linalg.py b/tests/st/scipy_st/test_linalg.py index cd5ef702cdc..e3b6af1cdea 100644 --- a/tests/st/scipy_st/test_linalg.py +++ b/tests/st/scipy_st/test_linalg.py @@ -54,16 +54,16 @@ def test_block_diag(args): @pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard -@pytest.mark.parametrize('dtype', [onp.float32, onp.float64]) +@pytest.mark.parametrize('data_type', [onp.float32, onp.float64]) @pytest.mark.parametrize('shape', [(4, 4), (50, 50), (2, 5, 5)]) -def test_inv(dtype, shape): +def test_inv(data_type, shape): """ Feature: ALL TO ALL Description: test cases for inv Expectation: the result match numpy """ onp.random.seed(0) - x = create_full_rank_matrix(shape, dtype) + x = create_full_rank_matrix(shape, data_type) ms_res = msp.linalg.inv(Tensor(x)) scipy_res = onp.linalg.inv(x) @@ -76,14 +76,14 @@ def test_inv(dtype, shape): @pytest.mark.env_onecard @pytest.mark.parametrize('n', [4, 5, 6]) @pytest.mark.parametrize('lower', [True, False]) -@pytest.mark.parametrize('dtype', [onp.float64]) -def test_cholesky(n: int, lower: bool, dtype: Generic): +@pytest.mark.parametrize('data_type', [onp.float64]) +def test_cholesky(n: int, lower: bool, data_type: Generic): """ Feature: ALL TO ALL Description: test cases for cholesky [N,N] Expectation: the result match scipy cholesky """ - a = create_sym_pos_matrix((n, n), dtype) + a = create_sym_pos_matrix((n, n), data_type) tensor_a = Tensor(a) rtol = 1.e-5 atol = 1.e-8 @@ -98,14 +98,14 @@ def test_cholesky(n: int, lower: bool, dtype: Generic): @pytest.mark.env_onecard @pytest.mark.parametrize('n', [4, 5, 6]) @pytest.mark.parametrize('lower', [True, False]) -@pytest.mark.parametrize('dtype', [onp.float64]) -def test_cho_factor(n: int, lower: bool, dtype: Generic): +@pytest.mark.parametrize('data_type', [onp.float64]) +def test_cho_factor(n: int, lower: bool, data_type: Generic): """ Feature: ALL TO ALL Description: test cases for cholesky [N,N] Expectation: the result match scipy cholesky """ - a = create_sym_pos_matrix((n, n), dtype) + a = create_sym_pos_matrix((n, n), data_type) tensor_a = Tensor(a) msp_c, _ = msp.linalg.cho_factor(tensor_a, lower=lower) if lower: @@ -121,15 +121,15 @@ def test_cho_factor(n: int, lower: bool, dtype: Generic): @pytest.mark.env_onecard @pytest.mark.parametrize('n', [4, 5, 6]) @pytest.mark.parametrize('lower', [True, False]) -@pytest.mark.parametrize('dtype', [onp.float64]) -def test_cholesky_solver(n: int, lower: bool, dtype): +@pytest.mark.parametrize('data_type', [onp.float64]) +def test_cholesky_solver(n: int, lower: bool, data_type): """ Feature: ALL TO ALL Description: test cases for cholesky solver [N,N] Expectation: the result match scipy cholesky_solve """ - a = create_sym_pos_matrix((n, n), dtype) - b = onp.ones((n, 1), dtype=dtype) + a = create_sym_pos_matrix((n, n), data_type) + b = onp.ones((n, 1), dtype=data_type) tensor_a = Tensor(a) tensor_b = Tensor(b) osp_c, lower = osp.linalg.cho_factor(a, lower=lower) @@ -149,10 +149,10 @@ def test_cholesky_solver(n: int, lower: bool, dtype): @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard @pytest.mark.parametrize('n', [4, 6, 9, 20]) -@pytest.mark.parametrize('dtype', +@pytest.mark.parametrize('data_type', [(onp.int8, "f"), (onp.int16, "f"), (onp.int32, "f"), (onp.int64, "d"), (onp.float32, "f"), (onp.float64, "d")]) -def test_eigh(n: int, dtype): +def test_eigh(n: int, data_type): """ Feature: ALL TO ALL Description: test cases for eigenvalues/eigenvector for symmetric/Hermitian matrix solver [N,N] @@ -160,11 +160,11 @@ def test_eigh(n: int, dtype): """ # test for real scalar float tol = {"f": (1e-3, 1e-4), "d": (1e-5, 1e-8)} - rtol = tol[dtype[1]][0] - atol = tol[dtype[1]][1] - A = create_sym_pos_matrix([n, n], dtype[0]) - msp_wl, msp_vl = msp.linalg.eigh(Tensor(onp.array(A).astype(dtype[0])), lower=True, eigvals_only=False) - msp_wu, msp_vu = msp.linalg.eigh(Tensor(onp.array(A).astype(dtype[0])), lower=False, eigvals_only=False) + rtol = tol[data_type[1]][0] + atol = tol[data_type[1]][1] + A = create_sym_pos_matrix([n, n], data_type[0]) + msp_wl, msp_vl = msp.linalg.eigh(Tensor(onp.array(A).astype(data_type[0])), lower=True, eigvals_only=False) + msp_wu, msp_vu = msp.linalg.eigh(Tensor(onp.array(A).astype(data_type[0])), lower=False, eigvals_only=False) assert onp.allclose(A @ msp_vl.asnumpy() - msp_vl.asnumpy() @ onp.diag(msp_wl.asnumpy()), onp.zeros((n, n)), rtol, atol) @@ -172,8 +172,8 @@ def test_eigh(n: int, dtype): rtol, atol) # test for real scalar float no vector - msp_wl0 = msp.linalg.eigh(Tensor(onp.array(A).astype(dtype[0])), lower=True, eigvals_only=True) - msp_wu0 = msp.linalg.eigh(Tensor(onp.array(A).astype(dtype[0])), lower=False, eigvals_only=True) + msp_wl0 = msp.linalg.eigh(Tensor(onp.array(A).astype(data_type[0])), lower=True, eigvals_only=True) + msp_wu0 = msp.linalg.eigh(Tensor(onp.array(A).astype(data_type[0])), lower=False, eigvals_only=True) assert onp.allclose(msp_wl.asnumpy() - msp_wl0.asnumpy(), onp.zeros((n, n)), rtol, atol) assert onp.allclose(msp_wu.asnumpy() - msp_wu0.asnumpy(), onp.zeros((n, n)), rtol, atol) @@ -183,8 +183,8 @@ def test_eigh(n: int, dtype): @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard @pytest.mark.parametrize('n', [4, 6, 9, 20]) -@pytest.mark.parametrize('dtype', [(onp.complex64, "f"), (onp.complex128, "d")]) -def test_eigh_complex(n: int, dtype): +@pytest.mark.parametrize('data_type', [(onp.complex64, "f"), (onp.complex128, "d")]) +def test_eigh_complex(n: int, data_type): """ Feature: ALL TO ALL Description: test cases for eigenvalues/eigenvector for symmetric/Hermitian matrix solver [N,N] @@ -192,9 +192,9 @@ def test_eigh_complex(n: int, dtype): """ # test case for complex tol = {"f": (1e-3, 1e-4), "d": (1e-5, 1e-8)} - rtol = tol[dtype[1]][0] - atol = tol[dtype[1]][1] - A = onp.array(onp.random.rand(n, n), dtype=dtype[0]) + rtol = tol[data_type[1]][0] + atol = tol[data_type[1]][1] + A = onp.array(onp.random.rand(n, n), dtype=data_type[0]) for i in range(0, n): for j in range(0, n): if i == j: @@ -203,16 +203,16 @@ def test_eigh_complex(n: int, dtype): A[i][j] = complex(onp.random.rand(1, 1), onp.random.rand(1, 1)) sym_al = (onp.tril((onp.tril(A) - onp.tril(A).T)) + onp.tril(A).conj().T) sym_au = (onp.triu((onp.triu(A) - onp.triu(A).T)) + onp.triu(A).conj().T) - msp_wl, msp_vl = msp.linalg.eigh(Tensor(onp.array(sym_al).astype(dtype[0])), lower=True, eigvals_only=False) - msp_wu, msp_vu = msp.linalg.eigh(Tensor(onp.array(sym_au).astype(dtype[0])), lower=False, eigvals_only=False) + msp_wl, msp_vl = msp.linalg.eigh(Tensor(onp.array(sym_al).astype(data_type[0])), lower=True, eigvals_only=False) + msp_wu, msp_vu = msp.linalg.eigh(Tensor(onp.array(sym_au).astype(data_type[0])), lower=False, eigvals_only=False) assert onp.allclose(sym_al @ msp_vl.asnumpy() - msp_vl.asnumpy() @ onp.diag(msp_wl.asnumpy()), onp.zeros((n, n)), rtol, atol) assert onp.allclose(sym_au @ msp_vu.asnumpy() - msp_vu.asnumpy() @ onp.diag(msp_wu.asnumpy()), onp.zeros((n, n)), rtol, atol) # test for real scalar complex no vector - msp_wl0 = msp.linalg.eigh(Tensor(onp.array(sym_al).astype(dtype[0])), lower=True, eigvals_only=True) - msp_wu0 = msp.linalg.eigh(Tensor(onp.array(sym_au).astype(dtype[0])), lower=False, eigvals_only=True) + msp_wl0 = msp.linalg.eigh(Tensor(onp.array(sym_al).astype(data_type[0])), lower=True, eigvals_only=True) + msp_wu0 = msp.linalg.eigh(Tensor(onp.array(sym_au).astype(data_type[0])), lower=False, eigvals_only=True) assert onp.allclose(msp_wl.asnumpy() - msp_wl0.asnumpy(), onp.zeros((n, n)), rtol, atol) assert onp.allclose(msp_wu.asnumpy() - msp_wu0.asnumpy(), onp.zeros((n, n)), rtol, atol) @@ -222,14 +222,14 @@ def test_eigh_complex(n: int, dtype): @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard @pytest.mark.parametrize('shape', [(4, 4), (4, 5), (5, 10), (20, 20)]) -@pytest.mark.parametrize('dtype', [onp.float32, onp.float64]) -def test_lu(shape: (int, int), dtype): +@pytest.mark.parametrize('data_type', [onp.float32, onp.float64]) +def test_lu(shape: (int, int), data_type): """ Feature: ALL To ALL Description: test cases for lu decomposition test cases for A[N,N]x = b[N,1] Expectation: the result match to scipy """ - a = create_random_rank_matrix(shape, dtype) + a = create_random_rank_matrix(shape, data_type) s_p, s_l, s_u = osp.linalg.lu(a) tensor_a = Tensor(a) m_p, m_l, m_u = msp.linalg.lu(tensor_a) @@ -245,14 +245,14 @@ def test_lu(shape: (int, int), dtype): @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard @pytest.mark.parametrize('shape', [(3, 4, 4), (3, 4, 5), (2, 3, 4, 5)]) -@pytest.mark.parametrize('dtype', [onp.float32, onp.float64]) -def test_batch_lu(shape, dtype): +@pytest.mark.parametrize('data_type', [onp.float32, onp.float64]) +def test_batch_lu(shape, data_type): """ Feature: ALL To ALL Description: test cases for lu decomposition test cases for A[N,N]x = b[N,1] Expectation: the result match to scipy """ - b_a = create_random_rank_matrix(shape, dtype) + b_a = create_random_rank_matrix(shape, data_type) b_s_p = list() b_s_l = list() b_s_u = list() @@ -279,14 +279,14 @@ def test_batch_lu(shape, dtype): @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard @pytest.mark.parametrize('n', [4, 5, 10, 20]) -@pytest.mark.parametrize('dtype', [onp.float32, onp.float64]) -def test_lu_factor(n: int, dtype): +@pytest.mark.parametrize('data_type', [onp.float32, onp.float64]) +def test_lu_factor(n: int, data_type): """ Feature: ALL To ALL Description: test cases for lu decomposition test cases for A[N,N]x = b[N,1] Expectation: the result match to scipy """ - a = create_full_rank_matrix((n, n), dtype) + a = create_full_rank_matrix((n, n), data_type) s_lu, s_pivots = osp.linalg.lu_factor(a) tensor_a = Tensor(a) m_lu, m_pivots = msp.linalg.lu_factor(tensor_a) @@ -300,15 +300,15 @@ def test_lu_factor(n: int, dtype): @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard @pytest.mark.parametrize('n', [4, 5, 10, 20]) -@pytest.mark.parametrize('dtype', [onp.float32, onp.float64]) -def test_lu_solve(n: int, dtype): +@pytest.mark.parametrize('data_type', [onp.float32, onp.float64]) +def test_lu_solve(n: int, data_type): """ Feature: ALL To ALL Description: test cases for lu_solve test cases for A[N,N]x = b[N,1] Expectation: the result match to scipy """ - a = create_full_rank_matrix((n, n), dtype) - b = onp.random.random((n, 1)).astype(dtype) + a = create_full_rank_matrix((n, n), data_type) + b = onp.random.random((n, 1)).astype(data_type) s_lu, s_piv = osp.linalg.lu_factor(a) tensor_a = Tensor(a) diff --git a/tests/st/scipy_st/test_ops_wrapper.py b/tests/st/scipy_st/test_ops_wrapper.py new file mode 100644 index 00000000000..91f72d148e5 --- /dev/null +++ b/tests/st/scipy_st/test_ops_wrapper.py @@ -0,0 +1,337 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""st for scipy.ops_wrapper.""" +import pytest +import numpy as onp +import mindspore.scipy as msp +from mindspore import context, Tensor +from tests.st.scipy_st.utils import match_array + +DEFAULT_ALIGNMENT = "LEFT_LEFT" +ALIGNMENT_LIST = ["RIGHT_LEFT", "LEFT_RIGHT", "LEFT_LEFT", "RIGHT_RIGHT"] + + +def repack_diagonals(packed_diagonals, + diag_index, + num_rows, + num_cols, + align=None): + if align == DEFAULT_ALIGNMENT or align is None: + return packed_diagonals + align = align.split("_") + d_lower, d_upper = diag_index + batch_dims = packed_diagonals.ndim - (2 if d_lower < d_upper else 1) + max_diag_len = packed_diagonals.shape[-1] + index = (slice(None),) * batch_dims + repacked_diagonals = onp.zeros_like(packed_diagonals) + for d_index in range(d_lower, d_upper + 1): + diag_len = min(num_rows + min(0, d_index), num_cols - max(0, d_index)) + row_index = d_upper - d_index + padding_len = max_diag_len - diag_len + left_align = (d_index >= 0 and + align[0] == "LEFT") or (d_index <= 0 and + align[1] == "LEFT") + extra_dim = tuple() if d_lower == d_upper else (row_index,) + packed_last_dim = (slice(None),) if left_align else (slice(0, diag_len, 1),) + repacked_last_dim = (slice(None),) if left_align else (slice( + padding_len, max_diag_len, 1),) + packed_index = index + extra_dim + packed_last_dim + repacked_index = index + extra_dim + repacked_last_dim + + repacked_diagonals[repacked_index] = packed_diagonals[packed_index] + return repacked_diagonals + + +def repack_diagonals_in_tests(tests, num_rows, num_cols, align=None): + # The original test cases are LEFT_LEFT aligned. + if align == DEFAULT_ALIGNMENT or align is None: + return tests + new_tests = dict() + # Loops through each case. + for diag_index, (packed_diagonals, padded_diagonals) in tests.items(): + repacked_diagonals = repack_diagonals( + packed_diagonals, diag_index, num_rows, num_cols, align=align) + new_tests[diag_index] = (repacked_diagonals, padded_diagonals) + + return new_tests + + +def square_cases(align=None, dtype=None): + mat = onp.array([[[1, 2, 3, 4, 5], + [6, 7, 8, 9, 1], + [3, 4, 5, 6, 7], + [8, 9, 1, 2, 3], + [4, 5, 6, 7, 8]], + [[9, 1, 2, 3, 4], + [5, 6, 7, 8, 9], + [1, 2, 3, 4, 5], + [6, 7, 8, 9, 1], + [2, 3, 4, 5, 6]]], dtype=dtype) + num_rows, num_cols = mat.shape[-2:] + tests = dict() + # tests[d_lower, d_upper] = packed_diagonals + tests[-1, -1] = (onp.array([[6, 4, 1, 7], + [5, 2, 8, 5]], dtype=dtype), + onp.array([[[0, 0, 0, 0, 0], + [6, 0, 0, 0, 0], + [0, 4, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 7, 0]], + [[0, 0, 0, 0, 0], + [5, 0, 0, 0, 0], + [0, 2, 0, 0, 0], + [0, 0, 8, 0, 0], + [0, 0, 0, 5, 0]]], dtype=dtype)) + tests[-4, -3] = (onp.array([[[8, 5], + [4, 0]], + [[6, 3], + [2, 0]]], dtype=dtype), + onp.array([[[0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [8, 0, 0, 0, 0], + [4, 5, 0, 0, 0]], + [[0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [6, 0, 0, 0, 0], + [2, 3, 0, 0, 0]]], dtype=dtype)) + tests[-2, 1] = (onp.array([[[2, 8, 6, 3, 0], + [1, 7, 5, 2, 8], + [6, 4, 1, 7, 0], + [3, 9, 6, 0, 0]], + [[1, 7, 4, 1, 0], + [9, 6, 3, 9, 6], + [5, 2, 8, 5, 0], + [1, 7, 4, 0, 0]]], dtype=dtype), + onp.array([[[1, 2, 0, 0, 0], + [6, 7, 8, 0, 0], + [3, 4, 5, 6, 0], + [0, 9, 1, 2, 3], + [0, 0, 6, 7, 8]], + [[9, 1, 0, 0, 0], + [5, 6, 7, 0, 0], + [1, 2, 3, 4, 0], + [0, 7, 8, 9, 1], + [0, 0, 4, 5, 6]]], dtype=dtype)) + tests[2, 4] = (onp.array([[[5, 0, 0], + [4, 1, 0], + [3, 9, 7]], + [[4, 0, 0], + [3, 9, 0], + [2, 8, 5]]], dtype=dtype), + onp.array([[[0, 0, 3, 4, 5], + [0, 0, 0, 9, 1], + [0, 0, 0, 0, 7], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0]], + [[0, 0, 2, 3, 4], + [0, 0, 0, 8, 9], + [0, 0, 0, 0, 5], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0]]], dtype=dtype)) + + return mat, repack_diagonals_in_tests(tests, num_rows, num_cols, align) + + +def tall_cases(align=None): + mat = onp.array([[[1, 2, 3], + [4, 5, 6], + [7, 8, 9], + [9, 8, 7], + [6, 5, 4]], + [[3, 2, 1], + [1, 2, 3], + [4, 5, 6], + [7, 8, 9], + [9, 8, 7]]]) + num_rows, num_cols = mat.shape[-2:] + tests = dict() + tests[0, 0] = (onp.array([[1, 5, 9], + [3, 2, 6]]), + onp.array([[[1, 0, 0], + [0, 5, 0], + [0, 0, 9], + [0, 0, 0]], + [[3, 0, 0], + [0, 2, 0], + [0, 0, 6], + [0, 0, 0]]])) + tests[-4, -3] = (onp.array([[[9, 5], + [6, 0]], + [[7, 8], + [9, 0]]]), + onp.array([[[0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [9, 0, 0], + [6, 5, 0]], + [[0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [7, 0, 0], + [9, 8, 0]]])) + tests[-2, -1] = (onp.array([[[4, 8, 7], + [7, 8, 4]], + [[1, 5, 9], + [4, 8, 7]]]), + onp.array([[[0, 0, 0], + [4, 0, 0], + [7, 8, 0], + [0, 8, 7], + [0, 0, 4]], + [[0, 0, 0], + [1, 0, 0], + [4, 5, 0], + [0, 8, 9], + [0, 0, 7]]])) + tests[-2, 1] = (onp.array([[[2, 6, 0], + [1, 5, 9], + [4, 8, 7], + [7, 8, 4]], + [[2, 3, 0], + [3, 2, 6], + [1, 5, 9], + [4, 8, 7]]]), + onp.array([[[1, 2, 0], + [4, 5, 6], + [7, 8, 9], + [0, 8, 7], + [0, 0, 4]], + [[3, 2, 0], + [1, 2, 3], + [4, 5, 6], + [0, 8, 9], + [0, 0, 7]]])) + tests[1, 2] = (onp.array([[[3, 0], + [2, 6]], + [[1, 0], + [2, 3]]]), + onp.array([[[0, 2, 3], + [0, 0, 6], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0]], + [[0, 2, 1], + [0, 0, 3], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0]]])) + + return mat, repack_diagonals_in_tests(tests, num_rows, num_cols, align) + + +def fat_cases(align=None): + mat = onp.array([[[1, 2, 3, 4], + [5, 6, 7, 8], + [9, 1, 2, 3]], + [[4, 5, 6, 7], + [8, 9, 1, 2], + [3, 4, 5, 6]]]) + num_rows, num_cols = mat.shape[-2:] + tests = dict() + tests[2, 2] = (onp.array([[3, 8], + [6, 2]]), + onp.array([[[0, 0, 3, 0], + [0, 0, 0, 8], + [0, 0, 0, 0]], + [[0, 0, 6, 0], + [0, 0, 0, 2], + [0, 0, 0, 0]]])) + tests[-2, 0] = (onp.array([[[1, 6, 2], + [5, 1, 0], + [9, 0, 0]], + [[4, 9, 5], + [8, 4, 0], + [3, 0, 0]]]), + onp.array([[[1, 0, 0, 0], + [5, 6, 0, 0], + [9, 1, 2, 0]], + [[4, 0, 0, 0], + [8, 9, 0, 0], + [3, 4, 5, 0]]])) + tests[-1, 1] = (onp.array([[[2, 7, 3], + [1, 6, 2], + [5, 1, 0]], + [[5, 1, 6], + [4, 9, 5], + [8, 4, 0]]]), + onp.array([[[1, 2, 0, 0], + [5, 6, 7, 0], + [0, 1, 2, 3]], + [[4, 5, 0, 0], + [8, 9, 1, 0], + [0, 4, 5, 6]]])) + tests[0, 3] = (onp.array([[[4, 0, 0], + [3, 8, 0], + [2, 7, 3], + [1, 6, 2]], + [[7, 0, 0], + [6, 2, 0], + [5, 1, 6], + [4, 9, 5]]]), + onp.array([[[1, 2, 3, 4], + [0, 6, 7, 8], + [0, 0, 2, 3]], + [[4, 5, 6, 7], + [0, 9, 1, 2], + [0, 0, 5, 6]]])) + return mat, repack_diagonals_in_tests(tests, num_rows, num_cols, align) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +@pytest.mark.parametrize('data_type', [onp.int]) +def test_matrix_set_diag(data_type): + """ + Feature: ALL TO ALL + Description: test geneal matrix cases for matrix_set_diag in pynative mode + Expectation: the result match expected_diag_matrix. + """ + onp.random.seed(0) + context.set_context(mode=context.PYNATIVE_MODE) + for align in ALIGNMENT_LIST: + for _, tests in [square_cases(align, data_type), tall_cases(align), fat_cases(align)]: + for k_vec, (diagonal, banded_mat) in tests.items(): + mask = banded_mat[0] == 0 + input_mat = onp.random.randint(10, size=mask.shape) + expected_diag_matrix = input_mat * mask + banded_mat[0] + output = msp.ops_wrapper.matrix_set_diag( + Tensor(input_mat), Tensor(diagonal[0]), k=k_vec, alignment=align) + match_array(output.asnumpy(), expected_diag_matrix) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +@pytest.mark.parametrize('data_type', [onp.int]) +def test_graph_matrix_set_diag(data_type): + """ + Feature: ALL TO ALL + Description: test general matrix cases for matrix_set_diag in graph mode + Expectation: the result match expected_diag_matrix. + """ + onp.random.seed(0) + context.set_context(mode=context.GRAPH_MODE) + for align in ALIGNMENT_LIST: + for _, tests in [square_cases(align, data_type), tall_cases(align), fat_cases(align)]: + for k_vec, (diagonal, banded_mat) in tests.items(): + mask = banded_mat[0] == 0 + input_mat = onp.random.randint(10, size=mask.shape) + expected_diag_matrix = input_mat * mask + banded_mat[0] + output = msp.ops_wrapper.matrix_set_diag( + Tensor(input_mat), Tensor(diagonal[0]), k=k_vec, alignment=align) + match_array(output.asnumpy(), expected_diag_matrix)