!29182 add kernel matrix_set_diag of backend cpu
Merge pull request !29182 from zhuzhongrui/pub_master4
commit 562ac8be54
@@ -102,6 +102,26 @@ const std::unordered_map<FusionType, std::string> fusion_type_name_maps = {
  {FusionType::DROPOUT_DOMASKV3D, "DropOutDoMaskV3D"},
  {FusionType::UNKNOWN_FUSION_TYPE, ""}};

std::pair<MatrixDiag::Alignment, MatrixDiag::Alignment> GetAlignments(const std::string &alignment) {
  auto alignment_iter = MatrixDiag::AlignmentMap.find(alignment);
  if (alignment_iter == MatrixDiag::AlignmentMap.end()) {
    MS_LOG(EXCEPTION) << "For current kernel, input alignment is invalid: " << alignment
                      << ". Please limit it to {RIGHT_LEFT, LEFT_RIGHT, RIGHT_RIGHT, LEFT_LEFT}.";
  }
  return alignment_iter->second;
}

int CalDiagOffset(int diag_index, int max_diag_len, int inner_rows, int inner_cols,
                  const std::pair<MatrixDiag::Alignment, MatrixDiag::Alignment> &alignment) {
  // A superdiagonal (diag_index >= 0) is right-aligned when the first alignment is RIGHT;
  // a subdiagonal (diag_index <= 0) is right-aligned when the second alignment is RIGHT.
  bool right_align_super_diagonal = (alignment.first == MatrixDiag::RIGHT);
  bool right_align_sub_diagonal = (alignment.second == MatrixDiag::RIGHT);
  const bool right_align =
    (diag_index >= 0 && right_align_super_diagonal) || (diag_index <= 0 && right_align_sub_diagonal);
  // The length of diagonal d of an MxN matrix is min(M + min(0, d), N - max(0, d)).
  const int diag_len = std::min(inner_rows + std::min(0, diag_index), inner_cols - std::max(0, diag_index));
  const int offset = (right_align) ? (max_diag_len - diag_len) : 0;
  return offset;
}

std::string GetFusionNameByType(const kernel::FusionType &type) {
  auto iter = fusion_type_name_maps.find(type);
  if (iter == fusion_type_name_maps.end()) {
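For readers outside the C++ backend, the offset rule above fits in a few lines of Python. This is a reading aid only; the helper name is hypothetical and not part of the patch:

```python
# A minimal Python sketch of CalDiagOffset (hypothetical helper, for illustration).
def cal_diag_offset(diag_index, max_diag_len, inner_rows, inner_cols, align=("RIGHT", "LEFT")):
    right_align = (diag_index >= 0 and align[0] == "RIGHT") or \
                  (diag_index <= 0 and align[1] == "RIGHT")
    diag_len = min(inner_rows + min(0, diag_index), inner_cols - max(0, diag_index))
    return max_diag_len - diag_len if right_align else 0

# For a 3x4 matrix with diagonals k in [-1, 2], max_diag_len = 3; the superdiagonal
# d=2 has length 2, so with RIGHT alignment its packed row is left-padded by 1.
assert cal_diag_offset(2, 3, 3, 4) == 1
```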
@@ -46,6 +46,17 @@ constexpr unsigned int AUTODIFF_COMPILE_OVERTIME = 600;

const std::vector<std::string> support_devices = {"aicore", "aicpu", "cuda"};

// An enum to indicate the alignment direction of a vector or matrix diagonal.
// real_data: [1,2,3]  left_align: [1,2,3,0]  right_align: [0,1,2,3]
namespace MatrixDiag {
enum Alignment { RIGHT = 0, LEFT = 1 };
static const mindspore::HashMap<std::string, std::pair<MatrixDiag::Alignment, MatrixDiag::Alignment>> AlignmentMap{
  {"RIGHT_LEFT", {MatrixDiag::RIGHT, MatrixDiag::LEFT}},
  {"LEFT_RIGHT", {MatrixDiag::LEFT, MatrixDiag::RIGHT}},
  {"RIGHT_RIGHT", {MatrixDiag::RIGHT, MatrixDiag::RIGHT}},
  {"LEFT_LEFT", {MatrixDiag::LEFT, MatrixDiag::LEFT}}};
}  // namespace MatrixDiag

struct KernelMetaInfo {
  uintptr_t func_stub_;
  uint32_t block_dim_;
@@ -72,6 +83,53 @@ class KernelMeta {
  std::unordered_map<std::string, std::string> kernel_meta_map_;
};

class MatrixInfo {
 public:
  explicit MatrixInfo(size_t max_index, const std::vector<size_t> &matrix_shapes)
      : max_index_(max_index), shapes_(matrix_shapes) {
    current_indexes_.resize(shapes_.size(), 0);
  }
  ~MatrixInfo() = default;
  bool SetIndex(size_t start, size_t end) {
    // Check whether the range [start, end) is valid.
    if (start < min_index || end > max_index_ || start >= end) {
      return false;
    }
    // Initialize the current indexes by decomposing the flat index `start`.
    int last_rank = SizeToInt(current_indexes_.size()) - 1;
    for (int i = last_rank; start != 0 && i >= 0; --i) {
      current_indexes_[i] = start % shapes_.at(i);
      start = start / shapes_.at(i);
    }
    return true;
  }
  std::vector<size_t> IndexIterator() {
    if (is_first_iterator_) {
      is_first_iterator_ = false;
      return current_indexes_;
    }
    // Advance the multi-index like an odometer, carrying into higher dimensions.
    size_t last_rank = current_indexes_.size() - 1;
    current_indexes_[last_rank]++;
    for (size_t i = last_rank; current_indexes_.at(i) >= shapes_.at(i) && i > 0; --i) {
      current_indexes_[i] = 0;
      current_indexes_[i - 1] += 1;
    }
    is_first_iterator_ = false;
    return current_indexes_;
  }

 private:
  bool is_first_iterator_{true};
  size_t min_index{0};
  size_t max_index_{1};
  std::vector<size_t> shapes_;
  std::vector<size_t> current_indexes_;
};
using MatrixInfoPtr = std::shared_ptr<MatrixInfo>;

std::pair<MatrixDiag::Alignment, MatrixDiag::Alignment> GetAlignments(const std::string &alignment);
int CalDiagOffset(int diag_index, int max_diag_len, int inner_rows, int inner_cols,
                  const std::pair<MatrixDiag::Alignment, MatrixDiag::Alignment> &alignment);
std::string GetCompilerCachePath();
bool CheckCache(const std::string &kernel_name);
KernelPackPtr SearchCache(const std::string &kernel_name, const std::string &processor);
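MatrixInfo walks the flat positions [start, end) of a shape as row-major multi-indexes. A rough Python equivalent of SetIndex plus IndexIterator, for intuition only (names are assumptions, not part of the patch):

```python
def flat_range_indexes(shape, start, end):
    """Yield the multi-indexes for flat positions [start, end) of `shape`."""
    idx = [0] * len(shape)
    s = start
    for i in range(len(shape) - 1, -1, -1):   # mirrors MatrixInfo::SetIndex
        idx[i] = s % shape[i]
        s //= shape[i]
    for _ in range(start, end):
        yield tuple(idx)
        idx[-1] += 1                          # mirrors MatrixInfo::IndexIterator
        for i in range(len(shape) - 1, 0, -1):
            if idx[i] >= shape[i]:            # carry into the next dimension
                idx[i] = 0
                idx[i - 1] += 1

assert list(flat_range_indexes((2, 2), 1, 4)) == [(0, 1), (1, 0), (1, 1)]
```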
@@ -99,6 +99,7 @@ constexpr char MODE[] = "mode";
constexpr char UNIT_DIAGONAL[] = "unit_diagonal";
constexpr char C_EIEH_VECTOR[] = "compute_eigenvectors";
constexpr char ADJOINT[] = "adjoint";
constexpr char ALIGNMENT[] = "alignment";

struct ParallelSearchInfo {
  double min_cost_time{DBL_MAX};

@@ -111,7 +112,7 @@ struct ParallelSearchInfo {
class CpuDynamicKernel : public device::DynamicKernel {
 public:
  explicit CpuDynamicKernel(const CNodePtr &cnode_ptr) : DynamicKernel(nullptr, cnode_ptr) {}
-  ~CpuDynamicKernel() = default;
+  ~CpuDynamicKernel() override = default;

  void UpdateArgs() override;
  void PostExecute() final { MS_LOG(EXCEPTION) << "`PostExecute()` should not be invoked with the CPU backend"; };
@@ -45,7 +45,8 @@ void CholeskyCPUKernel<T>::InitMatrixInfo(const std::vector<size_t> &shape, size
    *col = shape.at(shape.size() - kColIndex);
  }
  if (*row != *col) {
-    MS_LOG_EXCEPTION << kernel_name_ << "input shape is invalid: " << *row << ", " << *col;
+    MS_LOG_EXCEPTION << kernel_name_ << " input shape is invalid. "
+                     << "Cholesky expects a square matrix, but the input shape is: " << *row << ", " << *col;
  }
}
@@ -0,0 +1,200 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/cpu/matrix_set_diag_cpu_kernel.h"
#include <algorithm>
#include <tuple>
#include "runtime/device/cpu/cpu_device_address.h"
#include "common/thread_pool.h"
#include "backend/kernel_compiler/common_utils.h"

namespace mindspore {
namespace kernel {
void MatrixSetDiagCPUKernel::InitKernel(const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
  constexpr size_t required_input_nums = 3;
  constexpr size_t required_output_nums = 1;
  if (AnfAlgo::GetInputNum(kernel_node) != required_input_nums ||
      AnfAlgo::GetOutputTensorNum(kernel_node) != required_output_nums) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_
                      << "', the number of inputs must be 3 (input, diagonal, k; alignment is an attribute) "
                         "and the number of outputs must be 1.";
  }

  // An invalid alignment will throw an exception.
  auto alignment = AnfAlgo::GetNodeAttr<std::string>(kernel_node, ALIGNMENT);
  alignment_ = GetAlignments(alignment);
  constexpr int input_index = 0;
  constexpr int diag_index = 1;
  constexpr int diag_k_index = 2;
  constexpr int output_index = 0;
  auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, input_index);
  auto diag_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, diag_index);
  auto diag_k_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, diag_k_index);
  auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, output_index);

  constexpr int temporary_2d_dim = 2;
  constexpr int temporary_1d_dim = 1;
  if (SizeToInt(input_shape.size()) < temporary_2d_dim || SizeToInt(diag_shape.size()) < temporary_1d_dim ||
      input_shape != output_shape) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_
                      << "', the dimensions are invalid: the input shape must be at least 2D, the diag shape "
                         "at least 1D, and the input shape must equal the output shape.";
  }
  if (SizeToInt(diag_k_shape.size()) != temporary_1d_dim) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_
                      << "', k must be a 1D tensor giving the diagonal range (k[0], k[1]).";
  }
  int input_rank = SizeToInt(input_shape.size());
  for (int i = 0; i < input_rank - temporary_2d_dim; ++i) {
    outer_batch_ *= SizeToInt(input_shape.at(i));
  }
  input_shape_ = input_shape;
  inner_rows_ = SizeToInt(input_shape.at(input_rank - temporary_2d_dim));
  inner_cols_ = SizeToInt(input_shape.at(input_rank - temporary_1d_dim));

  expected_num_diags_ =
    SizeToInt(diag_shape.size()) == input_rank ? SizeToInt(diag_shape.at(input_rank - temporary_2d_dim)) : 1;

  data_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
}

bool MatrixSetDiagCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                    const std::vector<kernel::AddressPtr> &workspaces,
                                    const std::vector<kernel::AddressPtr> &outputs) {
  if (data_type_ == kNumberTypeFloat16) {
    LaunchKernel<float16>(inputs, workspaces, outputs);
  } else if (data_type_ == kNumberTypeFloat32) {
    LaunchKernel<float>(inputs, workspaces, outputs);
  } else if (data_type_ == kNumberTypeFloat64) {
    LaunchKernel<double>(inputs, workspaces, outputs);
  } else if (data_type_ == kNumberTypeInt32) {
    LaunchKernel<int>(inputs, workspaces, outputs);
  } else {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_
                      << "', the data type of the input should be float16, float32, float64, or int32, but got "
                      << TypeIdToType(data_type_)->ToString();
  }
  return true;
}

template <typename T>
void MatrixSetDiagCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
                                          const std::vector<AddressPtr> &workspaces,
                                          const std::vector<AddressPtr> &outputs) {
  auto input = inputs.at(0);
  auto diag = inputs.at(1);
  constexpr int diag_k_index = 2;
  auto k = inputs.at(diag_k_index);
  auto output = outputs.at(0);

  T *input_addr = reinterpret_cast<T *>(input->addr);
  T *diag_addr = reinterpret_cast<T *>(diag->addr);
  int *diag_k_addr = reinterpret_cast<int *>(k->addr);
  T *output_addr = reinterpret_cast<T *>(output->addr);
  lower_diag_index_ = diag_k_addr[0];
  upper_diag_index_ = diag_k_addr[1];
  is_single_diag_ = (lower_diag_index_ == upper_diag_index_);
  if (lower_diag_index_ <= -inner_rows_ || lower_diag_index_ >= inner_cols_) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_
                      << "', lower_diag_index is invalid: it must be strictly between " << -inner_rows_ << " and "
                      << inner_cols_;
  }
  if (upper_diag_index_ <= -inner_rows_ || upper_diag_index_ >= inner_cols_) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_
                      << "', upper_diag_index is invalid: it must be strictly between " << -inner_rows_ << " and "
                      << inner_cols_;
  }

  if (lower_diag_index_ > upper_diag_index_) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_
                      << "', lower_diag_index must not be greater than upper_diag_index, but got "
                      << lower_diag_index_ << " > " << upper_diag_index_;
  }
  num_diags_ = upper_diag_index_ - lower_diag_index_ + 1;
  if (lower_diag_index_ != upper_diag_index_ && expected_num_diags_ != num_diags_) {
    MS_LOG(EXCEPTION)
      << "For '" << kernel_name_
      << "', lower_diag_index and upper_diag_index are not consistent with the shape of diagonal.";
  }
  max_diag_len_ = std::min(inner_rows_ + std::min(upper_diag_index_, 0), inner_cols_ - std::max(lower_diag_index_, 0));

  // Copy input to output first, then overwrite the diagonal values in output.
  (void)memcpy_s(output_addr, output->size, input_addr, input->size);
  std::vector<common::Task> tasks;
  // The number of tasks depends on the hardware.
  auto thread_pool = GetActorMgrInnerThreadPool();
  size_t task_nums = thread_pool->GetKernelThreadNum();
  if (task_nums == 0) {
    MS_LOG(EXCEPTION) << "MatrixSetDiagCPUKernel queried the kernel thread pool, but the kernel thread number is 0!";
  }
  tasks.reserve(task_nums);
  size_t max_index = IntToSize(outer_batch_ * inner_rows_ * inner_cols_);
  size_t region = IntToSize(std::div(SizeToInt(max_index), SizeToInt(task_nums)).quot);
  region = (region == 0) ? max_index : region;
  for (size_t start = 0; start < max_index; start += region) {
    size_t end = start + region;
    if (end > max_index) {
      end = max_index;
    }
    (void)tasks.emplace_back([this, max_index, start, end, diag_addr, output_addr]() {
      MatrixInfoPtr matrix_info = std::make_shared<MatrixInfo>(max_index, input_shape_);
      if (!matrix_info->SetIndex(start, end)) {
        MS_LOG(EXCEPTION) << "The current data indexes are invalid: [" << start << ", " << end
                          << "]. They should be limited to [0, " << max_index << "].";
      }
      auto get_out_batch = [](const std::vector<size_t> &current_indexes) {
        constexpr size_t last_two_dims = 2;
        int out_batch = 1;
        for (size_t i = 0; i < current_indexes.size() - last_two_dims; ++i) {
          out_batch *= (SizeToInt(current_indexes.at(i)) + 1);
        }
        size_t inner_row = current_indexes.at(current_indexes.size() - last_two_dims);
        size_t inner_col = current_indexes.at(current_indexes.size() - 1);
        std::tuple<int, int, int> flatten_3d_shape = std::make_tuple(out_batch - 1, inner_row, inner_col);
        return flatten_3d_shape;
      };
      for (size_t inner = start; inner < end; ++inner) {
        std::vector<size_t> current_indexes = matrix_info->IndexIterator();
        auto flatten_3d_shape = get_out_batch(current_indexes);
        int batch = std::get<0>(flatten_3d_shape);
        int m = std::get<1>(flatten_3d_shape);
        constexpr size_t col_index = 2;
        int n = std::get<col_index>(flatten_3d_shape);
        // d is the diagonal index of element (m, n): d = n - m.
        int d = n - m;
        if (is_single_diag_) {
          if (d == upper_diag_index_) {
            output_addr[inner] = diag_addr[batch * max_diag_len_ + n - std::max(upper_diag_index_, 0)];
          }
        } else {
          int diag_index = upper_diag_index_ - d;
          int offset = CalDiagOffset(d, max_diag_len_, inner_rows_, inner_cols_, alignment_);
          int index_in_diag = n - std::max(d, 0) + offset;
          if (d >= lower_diag_index_ && d <= upper_diag_index_) {
            output_addr[inner] =
              diag_addr[batch * num_diags_ * max_diag_len_ + diag_index * max_diag_len_ + index_in_diag];
          }
        }
      }
      return common::SUCCESS;
    });
  }
  ParallelLaunch(tasks);
}
}  // namespace kernel
}  // namespace mindspore
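The per-element rule the kernel applies can be checked against a plain NumPy reference. This is a hedged sketch for a single (non-batched) matrix with a packed band of diagonals; the helper name and packing layout are taken from the docstring in ops_wrapper.py below, not from this patch:

```python
import numpy as onp

def set_diag_reference(mat, packed, k0, k1, align=("RIGHT", "LEFT")):
    """NumPy reference: packed has shape (k1 - k0 + 1, max_diag_len)."""
    rows, cols = mat.shape
    max_diag_len = min(rows + min(k1, 0), cols - max(k0, 0))
    out = mat.copy()
    for m in range(rows):
        for n in range(cols):
            d = n - m                      # diagonal index of element (m, n)
            if k0 <= d <= k1:
                right = (d >= 0 and align[0] == "RIGHT") or (d <= 0 and align[1] == "RIGHT")
                diag_len = min(rows + min(0, d), cols - max(0, d))
                offset = max_diag_len - diag_len if right else 0
                out[m, n] = packed[k1 - d, n - max(d, 0) + offset]
    return out
```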
@@ -0,0 +1,90 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MATRIX_SET_DIAG_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MATRIX_SET_DIAG_KERNEL_H_

#include <vector>
#include <memory>
#include <string>
#include <algorithm>
#include <utility>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include "backend/kernel_compiler/common_utils.h"
namespace mindspore {
namespace kernel {
class MatrixSetDiagCPUKernel : public CPUKernel {
 public:
  MatrixSetDiagCPUKernel() = default;
  ~MatrixSetDiagCPUKernel() override = default;
  void InitKernel(const CNodePtr &kernel_node) override;
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces,
              const std::vector<AddressPtr> &outputs) override;

 private:
  template <typename T>
  void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces,
                    const std::vector<AddressPtr> &outputs);
  int lower_diag_index_{0};
  int upper_diag_index_{0};
  int inner_rows_{0};
  int inner_cols_{0};
  int num_diags_{0};
  int expected_num_diags_{0};
  int max_diag_len_{0};
  int outer_batch_{1};
  bool is_single_diag_{true};
  std::vector<size_t> input_shape_;
  // <super_matrix_diag_align, sub_matrix_diag_align>
  std::pair<MatrixDiag::Alignment, MatrixDiag::Alignment> alignment_{MatrixDiag::RIGHT, MatrixDiag::LEFT};
  TypeId data_type_{0};
};

MS_REG_CPU_KERNEL(MatrixSetDiag,
                  KernelAttr()
                    .AddInputAttr(kNumberTypeInt32)
                    .AddInputAttr(kNumberTypeInt32)
                    .AddInputAttr(kNumberTypeInt32)
                    .AddOutputAttr(kNumberTypeInt32),
                  MatrixSetDiagCPUKernel)

MS_REG_CPU_KERNEL(MatrixSetDiag,
                  KernelAttr()
                    .AddInputAttr(kNumberTypeFloat16)
                    .AddInputAttr(kNumberTypeFloat16)
                    .AddInputAttr(kNumberTypeInt32)
                    .AddOutputAttr(kNumberTypeFloat16),
                  MatrixSetDiagCPUKernel)

MS_REG_CPU_KERNEL(MatrixSetDiag,
                  KernelAttr()
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeInt32)
                    .AddOutputAttr(kNumberTypeFloat32),
                  MatrixSetDiagCPUKernel)

MS_REG_CPU_KERNEL(MatrixSetDiag,
                  KernelAttr()
                    .AddInputAttr(kNumberTypeFloat64)
                    .AddInputAttr(kNumberTypeFloat64)
                    .AddInputAttr(kNumberTypeInt32)
                    .AddOutputAttr(kNumberTypeFloat64),
                  MatrixSetDiagCPUKernel)
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MATRIX_SET_DIAG_KERNEL_H_
@@ -14,4 +14,4 @@
# ============================================================================
"""Scipy-like interfaces in mindspore."""

-from . import linalg, optimize, sparse
+from . import linalg, optimize, sparse, ops_wrapper
@@ -332,3 +332,25 @@ class LUSolver(PrimitiveWithInfer):
            'value': None
        }
        return output


class MatrixSetDiag(PrimitiveWithInfer):
    """
    Inner API to set the diagonals in the range [k[0], k[1]] of a [..., M, N] matrix.
    """

    @prim_attr_register
    def __init__(self, alignment: str):
        super().__init__(name="MatrixSetDiag")
        self.init_prim_io_names(inputs=['input_x', 'diagonal', 'k'], outputs=['output'])
        self.alignment = validator.check_value_type("alignment", alignment, [str], self.name)

    def __infer__(self, input_x, diagonal, k):
        in_shape = list(input_x['shape'])
        in_dtype = input_x['dtype']
        output = {
            'shape': tuple(in_shape),
            'dtype': in_dtype,
            'value': None
        }
        return output
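Since the primitive only declares its IO names and infers a shape/dtype pass-through, a minimal direct invocation looks like the sketch below. Values are illustrative; the intended public entry point is `mindspore.scipy.ops_wrapper.matrix_set_diag`, added next:

```python
# Illustrative use of the inner primitive (normally reached via ops_wrapper).
import numpy as onp
from mindspore import Tensor
from mindspore.scipy.ops import MatrixSetDiag

set_diag = MatrixSetDiag(alignment="RIGHT_LEFT")
input_x = Tensor(onp.zeros((3, 3), dtype=onp.int32))
diagonal = Tensor(onp.array([1, 2, 3], dtype=onp.int32))
k = Tensor(onp.array([0, 0], dtype=onp.int32))   # single main diagonal
output = set_diag(input_x, diagonal, k)          # yields diag(1, 2, 3)
```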
@@ -0,0 +1,101 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Linear algebra submodule"""
from .. import numpy as mnp
from .ops import MatrixSetDiag
from ..common import dtype as mstype
from .utils_const import _raise_value_error


def matrix_set_diag(input_x, diagonal, k=0, alignment="RIGHT_LEFT"):
    """
    Returns a batched matrix tensor with new batched diagonal values.

    Given `input` and `diagonal`, this operation returns a tensor with the same shape and values as `input`,
    except for the specified diagonals of the innermost matrices. These will be overwritten by the values in
    `diagonal`. `input` has `r+1` dimensions `[I, J, ..., L, M, N]`. When `k` is a scalar or `k[0] == k[1]`,
    `diagonal` has `r` dimensions `[I, J, ..., L, max_diag_len]`. Otherwise, it has `r+1` dimensions
    `[I, J, ..., L, num_diags, max_diag_len]`, where `num_diags` is the number of diagonals,
    `num_diags = k[1] - k[0] + 1`, and `max_diag_len` is the length of the longest diagonal in the range
    `[k[0], k[1]]`, `max_diag_len = min(M + min(k[1], 0), N + min(-k[0], 0))`. The output is a tensor of
    rank `r+1` with dimensions `[I, J, ..., L, M, N]`. If `k` is a scalar or `k[0] == k[1]`:

    ```
    output[i, j, ..., l, m, n]
      = diagonal[i, j, ..., l, n-max(k[1], 0)] ; if n - m == k[1]
        input[i, j, ..., l, m, n]              ; otherwise
    ```

    Otherwise,

    ```
    output[i, j, ..., l, m, n]
      = diagonal[i, j, ..., l, diag_index, index_in_diag] ; if k[0] <= d <= k[1]
        input[i, j, ..., l, m, n]                         ; otherwise
    ```

    where `d = n - m`, `diag_index = k[1] - d`, and `index_in_diag = n - max(d, 0) + offset`.
    `offset` is zero except when the alignment of the diagonal is to the right.

    ```
    offset = max_diag_len - diag_len(d) ; if (`align` in {RIGHT_LEFT, RIGHT_RIGHT}
                                             and `d >= 0`) or
                                            (`align` in {LEFT_RIGHT, RIGHT_RIGHT}
                                             and `d <= 0`)
             0                          ; otherwise
    ```

    where `diag_len(d) = min(cols - max(d, 0), rows + min(d, 0))`.

    Args:
        input_x (Tensor): a :math:`(..., M, N)` matrix whose diagonals are to be set.
        diagonal (Tensor): a :math:`(..., max_diag_len)` or :math:`(..., num_diags, max_diag_len)` tensor of
            values to be placed on the output's diagonals.
        k (Union[int, list, tuple]): a scalar or a pair (k[0], k[1]) giving the lower and upper indexes
            of the diagonal region to set.
        alignment (str): Some diagonals are shorter than `max_diag_len` and need to be padded.
            `alignment` is a string specifying how superdiagonals and subdiagonals should be aligned,
            respectively. There are four possible alignments: "RIGHT_LEFT" (default),
            "LEFT_RIGHT", "LEFT_LEFT", and "RIGHT_RIGHT". "RIGHT_LEFT" aligns superdiagonals to
            the right (left-pads the row) and subdiagonals to the left (right-pads the row).

    Returns:
        - Tensor, :math:`(..., M, N)`. A batched matrix with the same shape and values as `input`,
          except for the specified diagonals of the innermost matrices.

    Supported Platforms:
        ``CPU`` ``GPU``

    Examples:
        >>> import numpy as onp
        >>> from mindspore.common import Tensor
        >>> from mindspore.scipy.ops_wrapper import matrix_set_diag
        >>> input_x = Tensor(
        ...     onp.array([[[7, 7, 7, 7], [7, 7, 7, 7], [7, 7, 7, 7]],
        ...                [[7, 7, 7, 7], [7, 7, 7, 7], [7, 7, 7, 7]]]).astype(onp.int))
        >>> diagonal = Tensor(onp.array([[1, 2, 3], [4, 5, 6]]).astype(onp.int))
        >>> output = matrix_set_diag(input_x, diagonal)
        >>> print(output)
        [[[1 7 7 7]
          [7 2 7 7]
          [7 7 3 7]]

         [[4 7 7 7]
          [7 5 7 7]
          [7 7 6 7]]]
    """
    matrix_set_diag_net = MatrixSetDiag(alignment)
    k_vec = mnp.zeros((2,), dtype=mstype.int32)
    if isinstance(k, int):
        k_vec += k
    elif isinstance(k, (list, tuple)):
        k_vec = k
    else:
        _raise_value_error("input k to indicate diagonal region is invalid.")
    output = matrix_set_diag_net(input_x, diagonal, k_vec)
    return output
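The docstring example covers the single-diagonal case; a whole band works the same way. A hedged sketch under the semantics documented above (expected output shown as a comment, not a doctest):

```python
# Setting the band k=(-1, 1) of a 3x4 matrix. With the default RIGHT_LEFT
# alignment the subdiagonal row is left-aligned, so its unused tail is padding.
import numpy as onp
from mindspore.common import Tensor
from mindspore.scipy.ops_wrapper import matrix_set_diag

input_x = Tensor(onp.zeros((3, 4)).astype(onp.int32))
diagonal = Tensor(onp.array([[9, 9, 9],    # d = 1 (superdiagonal), length 3
                             [8, 8, 8],    # d = 0 (main diagonal),  length 3
                             [7, 7, 0]],   # d = -1 (subdiagonal),   length 2 + padding
                            dtype=onp.int32))
output = matrix_set_diag(input_x, diagonal, k=(-1, 1))
# output:
# [[8 9 0 0]
#  [7 8 9 0]
#  [0 7 8 9]]
```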
@@ -54,16 +54,16 @@ def test_block_diag(args):
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
-@pytest.mark.parametrize('dtype', [onp.float32, onp.float64])
+@pytest.mark.parametrize('data_type', [onp.float32, onp.float64])
@pytest.mark.parametrize('shape', [(4, 4), (50, 50), (2, 5, 5)])
-def test_inv(dtype, shape):
+def test_inv(data_type, shape):
    """
    Feature: ALL TO ALL
    Description: test cases for inv
    Expectation: the result match numpy
    """
    onp.random.seed(0)
-    x = create_full_rank_matrix(shape, dtype)
+    x = create_full_rank_matrix(shape, data_type)

    ms_res = msp.linalg.inv(Tensor(x))
    scipy_res = onp.linalg.inv(x)
@@ -76,14 +76,14 @@ def test_inv(dtype, shape):
@pytest.mark.env_onecard
@pytest.mark.parametrize('n', [4, 5, 6])
@pytest.mark.parametrize('lower', [True, False])
-@pytest.mark.parametrize('dtype', [onp.float64])
-def test_cholesky(n: int, lower: bool, dtype: Generic):
+@pytest.mark.parametrize('data_type', [onp.float64])
+def test_cholesky(n: int, lower: bool, data_type: Generic):
    """
    Feature: ALL TO ALL
    Description: test cases for cholesky [N,N]
    Expectation: the result match scipy cholesky
    """
-    a = create_sym_pos_matrix((n, n), dtype)
+    a = create_sym_pos_matrix((n, n), data_type)
    tensor_a = Tensor(a)
    rtol = 1.e-5
    atol = 1.e-8
@@ -98,14 +98,14 @@ def test_cholesky(n: int, lower: bool, dtype: Generic):
@pytest.mark.env_onecard
@pytest.mark.parametrize('n', [4, 5, 6])
@pytest.mark.parametrize('lower', [True, False])
-@pytest.mark.parametrize('dtype', [onp.float64])
-def test_cho_factor(n: int, lower: bool, dtype: Generic):
+@pytest.mark.parametrize('data_type', [onp.float64])
+def test_cho_factor(n: int, lower: bool, data_type: Generic):
    """
    Feature: ALL TO ALL
    Description: test cases for cholesky [N,N]
    Expectation: the result match scipy cholesky
    """
-    a = create_sym_pos_matrix((n, n), dtype)
+    a = create_sym_pos_matrix((n, n), data_type)
    tensor_a = Tensor(a)
    msp_c, _ = msp.linalg.cho_factor(tensor_a, lower=lower)
    if lower:
@@ -121,15 +121,15 @@ def test_cho_factor(n: int, lower: bool, dtype: Generic):
@pytest.mark.env_onecard
@pytest.mark.parametrize('n', [4, 5, 6])
@pytest.mark.parametrize('lower', [True, False])
-@pytest.mark.parametrize('dtype', [onp.float64])
-def test_cholesky_solver(n: int, lower: bool, dtype):
+@pytest.mark.parametrize('data_type', [onp.float64])
+def test_cholesky_solver(n: int, lower: bool, data_type):
    """
    Feature: ALL TO ALL
    Description: test cases for cholesky solver [N,N]
    Expectation: the result match scipy cholesky_solve
    """
-    a = create_sym_pos_matrix((n, n), dtype)
-    b = onp.ones((n, 1), dtype=dtype)
+    a = create_sym_pos_matrix((n, n), data_type)
+    b = onp.ones((n, 1), dtype=data_type)
    tensor_a = Tensor(a)
    tensor_b = Tensor(b)
    osp_c, lower = osp.linalg.cho_factor(a, lower=lower)
@@ -149,10 +149,10 @@ def test_cholesky_solver(n: int, lower: bool, dtype):
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('n', [4, 6, 9, 20])
-@pytest.mark.parametrize('dtype',
+@pytest.mark.parametrize('data_type',
                         [(onp.int8, "f"), (onp.int16, "f"), (onp.int32, "f"), (onp.int64, "d"), (onp.float32, "f"),
                          (onp.float64, "d")])
-def test_eigh(n: int, dtype):
+def test_eigh(n: int, data_type):
    """
    Feature: ALL TO ALL
    Description: test cases for eigenvalues/eigenvector for symmetric/Hermitian matrix solver [N,N]
@@ -160,11 +160,11 @@ def test_eigh(n: int, dtype):
    """
    # test for real scalar float
    tol = {"f": (1e-3, 1e-4), "d": (1e-5, 1e-8)}
-    rtol = tol[dtype[1]][0]
-    atol = tol[dtype[1]][1]
-    A = create_sym_pos_matrix([n, n], dtype[0])
-    msp_wl, msp_vl = msp.linalg.eigh(Tensor(onp.array(A).astype(dtype[0])), lower=True, eigvals_only=False)
-    msp_wu, msp_vu = msp.linalg.eigh(Tensor(onp.array(A).astype(dtype[0])), lower=False, eigvals_only=False)
+    rtol = tol[data_type[1]][0]
+    atol = tol[data_type[1]][1]
+    A = create_sym_pos_matrix([n, n], data_type[0])
+    msp_wl, msp_vl = msp.linalg.eigh(Tensor(onp.array(A).astype(data_type[0])), lower=True, eigvals_only=False)
+    msp_wu, msp_vu = msp.linalg.eigh(Tensor(onp.array(A).astype(data_type[0])), lower=False, eigvals_only=False)
    assert onp.allclose(A @ msp_vl.asnumpy() - msp_vl.asnumpy() @ onp.diag(msp_wl.asnumpy()), onp.zeros((n, n)),
                        rtol,
                        atol)
@@ -172,8 +172,8 @@ def test_eigh(n: int, dtype):
                        rtol,
                        atol)
    # test for real scalar float no vector
-    msp_wl0 = msp.linalg.eigh(Tensor(onp.array(A).astype(dtype[0])), lower=True, eigvals_only=True)
-    msp_wu0 = msp.linalg.eigh(Tensor(onp.array(A).astype(dtype[0])), lower=False, eigvals_only=True)
+    msp_wl0 = msp.linalg.eigh(Tensor(onp.array(A).astype(data_type[0])), lower=True, eigvals_only=True)
+    msp_wu0 = msp.linalg.eigh(Tensor(onp.array(A).astype(data_type[0])), lower=False, eigvals_only=True)
    assert onp.allclose(msp_wl.asnumpy() - msp_wl0.asnumpy(), onp.zeros((n, n)), rtol, atol)
    assert onp.allclose(msp_wu.asnumpy() - msp_wu0.asnumpy(), onp.zeros((n, n)), rtol, atol)
@@ -183,8 +183,8 @@ def test_eigh(n: int, dtype):
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('n', [4, 6, 9, 20])
-@pytest.mark.parametrize('dtype', [(onp.complex64, "f"), (onp.complex128, "d")])
-def test_eigh_complex(n: int, dtype):
+@pytest.mark.parametrize('data_type', [(onp.complex64, "f"), (onp.complex128, "d")])
+def test_eigh_complex(n: int, data_type):
    """
    Feature: ALL TO ALL
    Description: test cases for eigenvalues/eigenvector for symmetric/Hermitian matrix solver [N,N]
@@ -192,9 +192,9 @@ def test_eigh_complex(n: int, dtype):
    """
    # test case for complex
    tol = {"f": (1e-3, 1e-4), "d": (1e-5, 1e-8)}
-    rtol = tol[dtype[1]][0]
-    atol = tol[dtype[1]][1]
-    A = onp.array(onp.random.rand(n, n), dtype=dtype[0])
+    rtol = tol[data_type[1]][0]
+    atol = tol[data_type[1]][1]
+    A = onp.array(onp.random.rand(n, n), dtype=data_type[0])
    for i in range(0, n):
        for j in range(0, n):
            if i == j:
@@ -203,16 +203,16 @@ def test_eigh_complex(n: int, dtype):
                A[i][j] = complex(onp.random.rand(1, 1), onp.random.rand(1, 1))
    sym_al = (onp.tril((onp.tril(A) - onp.tril(A).T)) + onp.tril(A).conj().T)
    sym_au = (onp.triu((onp.triu(A) - onp.triu(A).T)) + onp.triu(A).conj().T)
-    msp_wl, msp_vl = msp.linalg.eigh(Tensor(onp.array(sym_al).astype(dtype[0])), lower=True, eigvals_only=False)
-    msp_wu, msp_vu = msp.linalg.eigh(Tensor(onp.array(sym_au).astype(dtype[0])), lower=False, eigvals_only=False)
+    msp_wl, msp_vl = msp.linalg.eigh(Tensor(onp.array(sym_al).astype(data_type[0])), lower=True, eigvals_only=False)
+    msp_wu, msp_vu = msp.linalg.eigh(Tensor(onp.array(sym_au).astype(data_type[0])), lower=False, eigvals_only=False)
    assert onp.allclose(sym_al @ msp_vl.asnumpy() - msp_vl.asnumpy() @ onp.diag(msp_wl.asnumpy()),
                        onp.zeros((n, n)), rtol, atol)
    assert onp.allclose(sym_au @ msp_vu.asnumpy() - msp_vu.asnumpy() @ onp.diag(msp_wu.asnumpy()),
                        onp.zeros((n, n)), rtol, atol)

    # test for real scalar complex no vector
-    msp_wl0 = msp.linalg.eigh(Tensor(onp.array(sym_al).astype(dtype[0])), lower=True, eigvals_only=True)
-    msp_wu0 = msp.linalg.eigh(Tensor(onp.array(sym_au).astype(dtype[0])), lower=False, eigvals_only=True)
+    msp_wl0 = msp.linalg.eigh(Tensor(onp.array(sym_al).astype(data_type[0])), lower=True, eigvals_only=True)
+    msp_wu0 = msp.linalg.eigh(Tensor(onp.array(sym_au).astype(data_type[0])), lower=False, eigvals_only=True)
    assert onp.allclose(msp_wl.asnumpy() - msp_wl0.asnumpy(), onp.zeros((n, n)), rtol, atol)
    assert onp.allclose(msp_wu.asnumpy() - msp_wu0.asnumpy(), onp.zeros((n, n)), rtol, atol)
@@ -222,14 +222,14 @@ def test_eigh_complex(n: int, dtype):
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('shape', [(4, 4), (4, 5), (5, 10), (20, 20)])
-@pytest.mark.parametrize('dtype', [onp.float32, onp.float64])
-def test_lu(shape: (int, int), dtype):
+@pytest.mark.parametrize('data_type', [onp.float32, onp.float64])
+def test_lu(shape: (int, int), data_type):
    """
    Feature: ALL To ALL
    Description: test cases for lu decomposition test cases for A[N,N]x = b[N,1]
    Expectation: the result match to scipy
    """
-    a = create_random_rank_matrix(shape, dtype)
+    a = create_random_rank_matrix(shape, data_type)
    s_p, s_l, s_u = osp.linalg.lu(a)
    tensor_a = Tensor(a)
    m_p, m_l, m_u = msp.linalg.lu(tensor_a)
@@ -245,14 +245,14 @@ def test_lu(shape: (int, int), dtype):
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('shape', [(3, 4, 4), (3, 4, 5), (2, 3, 4, 5)])
-@pytest.mark.parametrize('dtype', [onp.float32, onp.float64])
-def test_batch_lu(shape, dtype):
+@pytest.mark.parametrize('data_type', [onp.float32, onp.float64])
+def test_batch_lu(shape, data_type):
    """
    Feature: ALL To ALL
    Description: test cases for lu decomposition test cases for A[N,N]x = b[N,1]
    Expectation: the result match to scipy
    """
-    b_a = create_random_rank_matrix(shape, dtype)
+    b_a = create_random_rank_matrix(shape, data_type)
    b_s_p = list()
    b_s_l = list()
    b_s_u = list()
@@ -279,14 +279,14 @@ def test_batch_lu(shape, dtype):
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('n', [4, 5, 10, 20])
-@pytest.mark.parametrize('dtype', [onp.float32, onp.float64])
-def test_lu_factor(n: int, dtype):
+@pytest.mark.parametrize('data_type', [onp.float32, onp.float64])
+def test_lu_factor(n: int, data_type):
    """
    Feature: ALL To ALL
    Description: test cases for lu decomposition test cases for A[N,N]x = b[N,1]
    Expectation: the result match to scipy
    """
-    a = create_full_rank_matrix((n, n), dtype)
+    a = create_full_rank_matrix((n, n), data_type)
    s_lu, s_pivots = osp.linalg.lu_factor(a)
    tensor_a = Tensor(a)
    m_lu, m_pivots = msp.linalg.lu_factor(tensor_a)
@@ -300,15 +300,15 @@ def test_lu_factor(n: int, dtype):
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('n', [4, 5, 10, 20])
-@pytest.mark.parametrize('dtype', [onp.float32, onp.float64])
-def test_lu_solve(n: int, dtype):
+@pytest.mark.parametrize('data_type', [onp.float32, onp.float64])
+def test_lu_solve(n: int, data_type):
    """
    Feature: ALL To ALL
    Description: test cases for lu_solve test cases for A[N,N]x = b[N,1]
    Expectation: the result match to scipy
    """
-    a = create_full_rank_matrix((n, n), dtype)
-    b = onp.random.random((n, 1)).astype(dtype)
+    a = create_full_rank_matrix((n, n), data_type)
+    b = onp.random.random((n, 1)).astype(data_type)
    s_lu, s_piv = osp.linalg.lu_factor(a)

    tensor_a = Tensor(a)
@@ -0,0 +1,337 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""st for scipy.ops_wrapper."""
import pytest
import numpy as onp
import mindspore.scipy as msp
from mindspore import context, Tensor
from tests.st.scipy_st.utils import match_array

DEFAULT_ALIGNMENT = "LEFT_LEFT"
ALIGNMENT_LIST = ["RIGHT_LEFT", "LEFT_RIGHT", "LEFT_LEFT", "RIGHT_RIGHT"]


def repack_diagonals(packed_diagonals,
                     diag_index,
                     num_rows,
                     num_cols,
                     align=None):
    if align == DEFAULT_ALIGNMENT or align is None:
        return packed_diagonals
    align = align.split("_")
    d_lower, d_upper = diag_index
    batch_dims = packed_diagonals.ndim - (2 if d_lower < d_upper else 1)
    max_diag_len = packed_diagonals.shape[-1]
    index = (slice(None),) * batch_dims
    repacked_diagonals = onp.zeros_like(packed_diagonals)
    for d_index in range(d_lower, d_upper + 1):
        diag_len = min(num_rows + min(0, d_index), num_cols - max(0, d_index))
        row_index = d_upper - d_index
        padding_len = max_diag_len - diag_len
        left_align = (d_index >= 0 and
                      align[0] == "LEFT") or (d_index <= 0 and
                                              align[1] == "LEFT")
        extra_dim = tuple() if d_lower == d_upper else (row_index,)
        packed_last_dim = (slice(None),) if left_align else (slice(0, diag_len, 1),)
        repacked_last_dim = (slice(None),) if left_align else (slice(
            padding_len, max_diag_len, 1),)
        packed_index = index + extra_dim + packed_last_dim
        repacked_index = index + extra_dim + repacked_last_dim

        repacked_diagonals[repacked_index] = packed_diagonals[packed_index]
    return repacked_diagonals


def repack_diagonals_in_tests(tests, num_rows, num_cols, align=None):
    # The original test cases are LEFT_LEFT aligned.
    if align == DEFAULT_ALIGNMENT or align is None:
        return tests
    new_tests = dict()
    # Loops through each case.
    for diag_index, (packed_diagonals, padded_diagonals) in tests.items():
        repacked_diagonals = repack_diagonals(
            packed_diagonals, diag_index, num_rows, num_cols, align=align)
        new_tests[diag_index] = (repacked_diagonals, padded_diagonals)

    return new_tests
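As a quick check of what `repack_diagonals` does, consider the LEFT_LEFT-packed band of a 3x4 matrix repacked to "LEFT_RIGHT": subdiagonals are re-padded on the left, while full-length rows copy through unchanged. The values below are illustrative, not one of the base cases:

```python
packed = onp.array([[9, 9, 9],   # d = 1, length 3 (== max_diag_len)
                    [8, 8, 8],   # d = 0, length 3
                    [7, 7, 0]])  # d = -1, length 2, LEFT-aligned
repacked = repack_diagonals(packed, (-1, 1), num_rows=3, num_cols=4, align="LEFT_RIGHT")
# repacked == [[9, 9, 9], [8, 8, 8], [0, 7, 7]]
```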
def square_cases(align=None, dtype=None):
    mat = onp.array([[[1, 2, 3, 4, 5],
                      [6, 7, 8, 9, 1],
                      [3, 4, 5, 6, 7],
                      [8, 9, 1, 2, 3],
                      [4, 5, 6, 7, 8]],
                     [[9, 1, 2, 3, 4],
                      [5, 6, 7, 8, 9],
                      [1, 2, 3, 4, 5],
                      [6, 7, 8, 9, 1],
                      [2, 3, 4, 5, 6]]], dtype=dtype)
    num_rows, num_cols = mat.shape[-2:]
    tests = dict()
    # tests[d_lower, d_upper] = (packed_diagonals, padded_diagonals)
    tests[-1, -1] = (onp.array([[6, 4, 1, 7],
                                [5, 2, 8, 5]], dtype=dtype),
                     onp.array([[[0, 0, 0, 0, 0],
                                 [6, 0, 0, 0, 0],
                                 [0, 4, 0, 0, 0],
                                 [0, 0, 1, 0, 0],
                                 [0, 0, 0, 7, 0]],
                                [[0, 0, 0, 0, 0],
                                 [5, 0, 0, 0, 0],
                                 [0, 2, 0, 0, 0],
                                 [0, 0, 8, 0, 0],
                                 [0, 0, 0, 5, 0]]], dtype=dtype))
    tests[-4, -3] = (onp.array([[[8, 5],
                                 [4, 0]],
                                [[6, 3],
                                 [2, 0]]], dtype=dtype),
                     onp.array([[[0, 0, 0, 0, 0],
                                 [0, 0, 0, 0, 0],
                                 [0, 0, 0, 0, 0],
                                 [8, 0, 0, 0, 0],
                                 [4, 5, 0, 0, 0]],
                                [[0, 0, 0, 0, 0],
                                 [0, 0, 0, 0, 0],
                                 [0, 0, 0, 0, 0],
                                 [6, 0, 0, 0, 0],
                                 [2, 3, 0, 0, 0]]], dtype=dtype))
    tests[-2, 1] = (onp.array([[[2, 8, 6, 3, 0],
                                [1, 7, 5, 2, 8],
                                [6, 4, 1, 7, 0],
                                [3, 9, 6, 0, 0]],
                               [[1, 7, 4, 1, 0],
                                [9, 6, 3, 9, 6],
                                [5, 2, 8, 5, 0],
                                [1, 7, 4, 0, 0]]], dtype=dtype),
                    onp.array([[[1, 2, 0, 0, 0],
                                [6, 7, 8, 0, 0],
                                [3, 4, 5, 6, 0],
                                [0, 9, 1, 2, 3],
                                [0, 0, 6, 7, 8]],
                               [[9, 1, 0, 0, 0],
                                [5, 6, 7, 0, 0],
                                [1, 2, 3, 4, 0],
                                [0, 7, 8, 9, 1],
                                [0, 0, 4, 5, 6]]], dtype=dtype))
    tests[2, 4] = (onp.array([[[5, 0, 0],
                               [4, 1, 0],
                               [3, 9, 7]],
                              [[4, 0, 0],
                               [3, 9, 0],
                               [2, 8, 5]]], dtype=dtype),
                   onp.array([[[0, 0, 3, 4, 5],
                               [0, 0, 0, 9, 1],
                               [0, 0, 0, 0, 7],
                               [0, 0, 0, 0, 0],
                               [0, 0, 0, 0, 0]],
                              [[0, 0, 2, 3, 4],
                               [0, 0, 0, 8, 9],
                               [0, 0, 0, 0, 5],
                               [0, 0, 0, 0, 0],
                               [0, 0, 0, 0, 0]]], dtype=dtype))

    return mat, repack_diagonals_in_tests(tests, num_rows, num_cols, align)
def tall_cases(align=None):
    mat = onp.array([[[1, 2, 3],
                      [4, 5, 6],
                      [7, 8, 9],
                      [9, 8, 7],
                      [6, 5, 4]],
                     [[3, 2, 1],
                      [1, 2, 3],
                      [4, 5, 6],
                      [7, 8, 9],
                      [9, 8, 7]]])
    num_rows, num_cols = mat.shape[-2:]
    tests = dict()
    tests[0, 0] = (onp.array([[1, 5, 9],
                              [3, 2, 6]]),
                   onp.array([[[1, 0, 0],
                               [0, 5, 0],
                               [0, 0, 9],
                               [0, 0, 0]],
                              [[3, 0, 0],
                               [0, 2, 0],
                               [0, 0, 6],
                               [0, 0, 0]]]))
    tests[-4, -3] = (onp.array([[[9, 5],
                                 [6, 0]],
                                [[7, 8],
                                 [9, 0]]]),
                     onp.array([[[0, 0, 0],
                                 [0, 0, 0],
                                 [0, 0, 0],
                                 [9, 0, 0],
                                 [6, 5, 0]],
                                [[0, 0, 0],
                                 [0, 0, 0],
                                 [0, 0, 0],
                                 [7, 0, 0],
                                 [9, 8, 0]]]))
    tests[-2, -1] = (onp.array([[[4, 8, 7],
                                 [7, 8, 4]],
                                [[1, 5, 9],
                                 [4, 8, 7]]]),
                     onp.array([[[0, 0, 0],
                                 [4, 0, 0],
                                 [7, 8, 0],
                                 [0, 8, 7],
                                 [0, 0, 4]],
                                [[0, 0, 0],
                                 [1, 0, 0],
                                 [4, 5, 0],
                                 [0, 8, 9],
                                 [0, 0, 7]]]))
    tests[-2, 1] = (onp.array([[[2, 6, 0],
                                [1, 5, 9],
                                [4, 8, 7],
                                [7, 8, 4]],
                               [[2, 3, 0],
                                [3, 2, 6],
                                [1, 5, 9],
                                [4, 8, 7]]]),
                    onp.array([[[1, 2, 0],
                                [4, 5, 6],
                                [7, 8, 9],
                                [0, 8, 7],
                                [0, 0, 4]],
                               [[3, 2, 0],
                                [1, 2, 3],
                                [4, 5, 6],
                                [0, 8, 9],
                                [0, 0, 7]]]))
    tests[1, 2] = (onp.array([[[3, 0],
                               [2, 6]],
                              [[1, 0],
                               [2, 3]]]),
                   onp.array([[[0, 2, 3],
                               [0, 0, 6],
                               [0, 0, 0],
                               [0, 0, 0],
                               [0, 0, 0]],
                              [[0, 2, 1],
                               [0, 0, 3],
                               [0, 0, 0],
                               [0, 0, 0],
                               [0, 0, 0]]]))

    return mat, repack_diagonals_in_tests(tests, num_rows, num_cols, align)
def fat_cases(align=None):
    mat = onp.array([[[1, 2, 3, 4],
                      [5, 6, 7, 8],
                      [9, 1, 2, 3]],
                     [[4, 5, 6, 7],
                      [8, 9, 1, 2],
                      [3, 4, 5, 6]]])
    num_rows, num_cols = mat.shape[-2:]
    tests = dict()
    tests[2, 2] = (onp.array([[3, 8],
                              [6, 2]]),
                   onp.array([[[0, 0, 3, 0],
                               [0, 0, 0, 8],
                               [0, 0, 0, 0]],
                              [[0, 0, 6, 0],
                               [0, 0, 0, 2],
                               [0, 0, 0, 0]]]))
    tests[-2, 0] = (onp.array([[[1, 6, 2],
                                [5, 1, 0],
                                [9, 0, 0]],
                               [[4, 9, 5],
                                [8, 4, 0],
                                [3, 0, 0]]]),
                    onp.array([[[1, 0, 0, 0],
                                [5, 6, 0, 0],
                                [9, 1, 2, 0]],
                               [[4, 0, 0, 0],
                                [8, 9, 0, 0],
                                [3, 4, 5, 0]]]))
    tests[-1, 1] = (onp.array([[[2, 7, 3],
                                [1, 6, 2],
                                [5, 1, 0]],
                               [[5, 1, 6],
                                [4, 9, 5],
                                [8, 4, 0]]]),
                    onp.array([[[1, 2, 0, 0],
                                [5, 6, 7, 0],
                                [0, 1, 2, 3]],
                               [[4, 5, 0, 0],
                                [8, 9, 1, 0],
                                [0, 4, 5, 6]]]))
    tests[0, 3] = (onp.array([[[4, 0, 0],
                               [3, 8, 0],
                               [2, 7, 3],
                               [1, 6, 2]],
                              [[7, 0, 0],
                               [6, 2, 0],
                               [5, 1, 6],
                               [4, 9, 5]]]),
                   onp.array([[[1, 2, 3, 4],
                               [0, 6, 7, 8],
                               [0, 0, 2, 3]],
                              [[4, 5, 6, 7],
                               [0, 9, 1, 2],
                               [0, 0, 5, 6]]]))
    return mat, repack_diagonals_in_tests(tests, num_rows, num_cols, align)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('data_type', [onp.int])
def test_matrix_set_diag(data_type):
    """
    Feature: ALL TO ALL
    Description: test general matrix cases for matrix_set_diag in pynative mode
    Expectation: the result match expected_diag_matrix.
    """
    onp.random.seed(0)
    context.set_context(mode=context.PYNATIVE_MODE)
    for align in ALIGNMENT_LIST:
        for _, tests in [square_cases(align, data_type), tall_cases(align), fat_cases(align)]:
            for k_vec, (diagonal, banded_mat) in tests.items():
                mask = banded_mat[0] == 0
                input_mat = onp.random.randint(10, size=mask.shape)
                expected_diag_matrix = input_mat * mask + banded_mat[0]
                output = msp.ops_wrapper.matrix_set_diag(
                    Tensor(input_mat), Tensor(diagonal[0]), k=k_vec, alignment=align)
                match_array(output.asnumpy(), expected_diag_matrix)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('data_type', [onp.int])
def test_graph_matrix_set_diag(data_type):
    """
    Feature: ALL TO ALL
    Description: test general matrix cases for matrix_set_diag in graph mode
    Expectation: the result match expected_diag_matrix.
    """
    onp.random.seed(0)
    context.set_context(mode=context.GRAPH_MODE)
    for align in ALIGNMENT_LIST:
        for _, tests in [square_cases(align, data_type), tall_cases(align), fat_cases(align)]:
            for k_vec, (diagonal, banded_mat) in tests.items():
                mask = banded_mat[0] == 0
                input_mat = onp.random.randint(10, size=mask.shape)
                expected_diag_matrix = input_mat * mask + banded_mat[0]
                output = msp.ops_wrapper.matrix_set_diag(
                    Tensor(input_mat), Tensor(diagonal[0]), k=k_vec, alignment=align)
                match_array(output.asnumpy(), expected_diag_matrix)