!43951 [assistant][ops] Add MatrixTriangularSolve

Merge pull request !43951 from 张渝/MatrixTriangularSolve
This commit is contained in:
i-robot 2022-11-05 02:38:51 +00:00 committed by Gitee
commit 37011775a9
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
10 changed files with 539 additions and 0 deletions

View File

@ -0,0 +1,156 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/matrix_triangular_solve_cpu_kernel.h"
#include <Eigen/Dense>
#include <algorithm>
#include <vector>
#include <string>
#include <utility>
namespace mindspore {
namespace kernel {
using Eigen::ColMajor;
using Eigen::Dynamic;
using Eigen::Lower;
using Eigen::Map;
using Eigen::MatrixBase;
using Eigen::RowMajor;
using Eigen::UnitLower;
using Eigen::UnitUpper;
using Eigen::Upper;
template <typename T, int Major>
using Matrix = Eigen::Matrix<T, Dynamic, Dynamic, Major>;
constexpr auto kSolveTriangularInputsNum = 2;
constexpr auto kSolveTriangularOutputsNum = 1;
constexpr auto kAVectorxDimNum = 1;
constexpr auto kAMatrixDimNum = 2;
constexpr size_t kRowIndex = 2;
constexpr size_t kColIndex = 1;
void MatrixTriangularSolveCpuKernelMod::InitShape(const CNodePtr &kernel_node) {
auto a_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto b_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
// Since the shape check is done in frontend, we can suppose that the shape of a, b here is valid.
size_t a_dims = a_shape.size();
size_t aRowIndex = a_dims - kRowIndex;
m_ = a_shape[aRowIndex];
size_t b_sims = b_shape.size();
bool vector_b = b_sims == a_dims - 1;
if (vector_b) {
n_ = 1;
} else {
n_ = b_shape[b_sims - 1];
}
batch_ = 1;
for (size_t batch = 0; batch < a_dims - kRowIndex; ++batch) {
batch_ *= a_shape[batch];
}
}
void MatrixTriangularSolveCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
InitShape(kernel_node);
trans_ = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, ADJOINT);
lower_ = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, LOWER);
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(EXCEPTION) << "MatrixTriangularSolve does not support this kernel data type: " << kernel_attr;
}
kernel_func_ = func_list_[index].second;
}
template <typename Derived_a, typename Derived_b, typename T>
inline void solve(const MatrixBase<Derived_a> &a, const MatrixBase<Derived_b> &b, T *output_addr, int m, int n,
bool lower, bool unit_diagonal) {
Map<Matrix<T, RowMajor>> output(output_addr, m, n);
if (unit_diagonal) {
if (lower) {
output.noalias() = a.template triangularView<UnitLower>().solve(b);
} else {
output.noalias() = a.template triangularView<UnitUpper>().solve(b);
}
} else {
if (lower) {
output.noalias() = a.template triangularView<Lower>().solve(b);
} else {
output.noalias() = a.template triangularView<Upper>().solve(b);
}
}
}
template <typename T>
bool MatrixTriangularSolveCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSolveTriangularInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSolveTriangularOutputsNum, kernel_name_);
auto a_addr = reinterpret_cast<T *>(inputs[0]->addr);
auto b_addr = reinterpret_cast<T *>(inputs[1]->addr);
auto output_addr = reinterpret_cast<T *>(outputs[0]->addr);
size_t a_batch_size = m_ * m_;
size_t b_batch_size = m_ * n_;
size_t output_batch_size = m_ * n_;
for (size_t i = 0; i < batch_; ++i) {
T *a_batch_addr = a_addr + i * a_batch_size;
T *b_batch_addr = b_addr + i * b_batch_size;
T *output_batch_addr = output_addr + i * output_batch_size;
Map<Matrix<T, RowMajor>> b(b_batch_addr, m_, n_);
if (trans_) {
Map<Matrix<T, ColMajor>> a(a_batch_addr, m_, m_);
auto a_conj = a.conjugate();
solve(a_conj, b, output_batch_addr, m_, n_, !lower_, unit_diagonal_);
} else {
Map<Matrix<T, RowMajor>> a(a_batch_addr, m_, m_);
solve(a, b, output_batch_addr, m_, n_, lower_, unit_diagonal_);
}
}
return true;
}
std::vector<std::pair<KernelAttr, MatrixTriangularSolveCpuKernelMod::MatrixTriangularSolveFunc>>
MatrixTriangularSolveCpuKernelMod::func_list_ = {
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&MatrixTriangularSolveCpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
&MatrixTriangularSolveCpuKernelMod::LaunchKernel<double>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex64)
.AddInputAttr(kNumberTypeComplex64)
.AddOutputAttr(kNumberTypeComplex64),
&MatrixTriangularSolveCpuKernelMod::LaunchKernel<std::complex<float>>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeComplex128),
&MatrixTriangularSolveCpuKernelMod::LaunchKernel<std::complex<double>>}};
std::vector<KernelAttr> MatrixTriangularSolveCpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, MatrixTriangularSolveFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, MatrixTriangularSolve, MatrixTriangularSolveCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,62 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_MATRIX_TRIANGULAR_SOLVE_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_MATRIX_TRIANGULAR_SOLVE_CPU_KERNEL_H_
#include <vector>
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
class MatrixTriangularSolveCpuKernelMod : public DeprecatedNativeCpuKernelMod {
public:
MatrixTriangularSolveCpuKernelMod() = default;
~MatrixTriangularSolveCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override {
return kernel_func_(this, inputs, workspace, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
void InitShape(const CNodePtr &kernel_node);
template <typename T>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs);
using MatrixTriangularSolveFunc =
std::function<bool(MatrixTriangularSolveCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
static std::vector<std::pair<KernelAttr, MatrixTriangularSolveFunc>> func_list_;
MatrixTriangularSolveFunc kernel_func_;
size_t m_{0};
size_t n_{0};
size_t batch_{1};
bool lower_{false};
bool trans_{false};
bool unit_diagonal_{false};
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_MATRIX_TRIANGULAR_SOLVE_CPU_KERNEL_H_

View File

@ -147,6 +147,7 @@ constexpr auto kTrilIndices = "TrilIndices";
constexpr auto kTrace = "Trace";
constexpr auto kTraceGrad = "TraceGrad";
constexpr auto kMatrixLogarithm = "MatrixLogarithm";
constexpr auto kMatrixTriangularSolve = "MatrixTriangularSolve";
// Arrays
constexpr auto kLeftShift = "LeftShift";
@ -1305,6 +1306,7 @@ GVAR_DEF(PrimitivePtr, kPrimEig, std::make_shared<Primitive>("Eig"));
GVAR_DEF(PrimitivePtr, kPrimEigh, std::make_shared<Primitive>("Eigh"));
GVAR_DEF(PrimitivePtr, kPrimQr, std::make_shared<Primitive>("Qr"));
GVAR_DEF(PrimitivePtr, kPrimMatrixLogarithm, std::make_shared<Primitive>(kMatrixLogarithm));
GVAR_DEF(PrimitivePtr, kPrimMatrixTriangularSolve, std::make_shared<Primitive>(kMatrixTriangularSolve));
// linalg
GVAR_DEF(PrimitivePtr, kPrimGeqrf, std::make_shared<Primitive>("Geqrf"));

View File

@ -0,0 +1,90 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <set>
#include <string>
#include <vector>
#include <memory>
#include <map>
#include "ops/matrix_triangular_solve.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr MatrixTriangularSolveInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto matrix_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto rhs_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
(void)CheckAndConvertUtils::CheckInteger("input matrix rank", SizeToLong(matrix_shape.size()), kGreaterEqual, 2L,
prim_name);
(void)CheckAndConvertUtils::CheckInteger("input rhs rank", SizeToLong(rhs_shape.size()), kGreaterEqual, 2L,
prim_name);
constexpr size_t offset = 2;
std::vector<int> matrix_last(matrix_shape.end() - offset, matrix_shape.end());
std::vector<int> rhs_last(rhs_shape.end() - offset, rhs_shape.end());
int64_t matrix_row = matrix_last[0];
int64_t matrix_col = matrix_last[1];
int64_t rhs_row = rhs_last[0];
for (size_t i = 0; i < matrix_shape.size() - offset; ++i) {
if (matrix_shape[i] != rhs_shape[i]) {
MS_EXCEPTION(ValueError) << "For " << prim_name << " shapes in batch dimension should be same, dim[" << i
<< "] are not the same, "
<< "while matrix is " << matrix_shape[i] << ", rhs is " << rhs_shape[i];
}
}
if (matrix_row != rhs_row) {
MS_EXCEPTION(ValueError) << "For " << prim_name << " evaluator shapes of inputs can not do this operator, "
<< "got " << matrix_row << " and " << rhs_row << " , with matrix row " << matrix_row
<< ", rhs row " << rhs_row << ", matrix's row rank should be same as rhs's row rank";
}
if (matrix_row != matrix_col) {
MS_EXCEPTION(ValueError) << "For " << prim_name << " evaluator shapes of inputs can not do this operator, "
<< "got " << matrix_row << " and " << matrix_col << " , with matrix row " << matrix_row
<< ", matrix col " << matrix_col
<< ". Inner-most 2 demision of input matrix must be square";
}
return std::make_shared<abstract::Shape>(rhs_shape);
}
TypePtr MatrixTriangularSolveInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
std::map<std::string, TypePtr> types;
(void)types.emplace("matrix", input_args[0]->BuildType());
(void)types.emplace("rhs", input_args[1]->BuildType());
const std::set<TypePtr> valid_types = {kFloat32, kFloat64, kComplex64, kComplex128};
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, primitive->name());
}
} // namespace
MIND_API_OPERATOR_IMPL(MatrixTriangularSolve, BaseOperator);
AbstractBasePtr MatrixTriangularSolveInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t kTwo = 2;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kTwo, primitive->name());
auto infer_type = MatrixTriangularSolveInferType(primitive, input_args);
auto infer_shape = MatrixTriangularSolveInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(MatrixTriangularSolve, prim::kPrimMatrixTriangularSolve, MatrixTriangularSolveInfer,
nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,42 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_MATRIX_TRIANGULAR_SOLVE_H_
#define MINDSPORE_CORE_OPS_MATRIX_TRIANGULAR_SOLVE_H_
#include <vector>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameMatrixTriangularSolve = "MatrixTriangularSolve";
class MIND_API MatrixTriangularSolve : public BaseOperator {
public:
MIND_API_BASE_MEMBER(MatrixTriangularSolve);
/// \brief Constructor.
MatrixTriangularSolve() : BaseOperator(kNameMatrixTriangularSolve) { InitIOName({"matrix", "rhs"}, {"y"}); }
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.MatrixTriangularSolve for the inputs.
void Init() const {}
};
abstract::AbstractBasePtr MatrixTriangularSolveInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimMatrixTriangularSolvePtr = std::shared_ptr<MatrixTriangularSolve>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_MATRIX_TRIANGULAR_SOLVE_H_

View File

@ -48,6 +48,7 @@ from mindspore.ops.operations.math_ops import CumulativeLogsumexp
from mindspore.ops.operations.math_ops import MatrixSolve
from mindspore.ops.operations.math_ops import MatrixPower
from mindspore.ops.operations.math_ops import Median
from mindspore.ops.operations.math_ops import MatrixTriangularSolve
from mindspore.ops.operations.math_ops import Betainc
from mindspore.ops.operations.math_ops import Cholesky
from mindspore.ops.operations.math_ops import CholeskySolve
@ -63,6 +64,7 @@ from mindspore.ops._utils.utils import is_shape_unknown, is_dim_unknown
from mindspore.ops._grad.grad_base import bprop_getters, create_tensor_by_element, dyn_rank
from mindspore.ops._grad.grad_base import dyn_ones, dyn_fill, sum_grad_reduce_axis
from mindspore.ops._grad.grad_math_ops import binop_grad_common
from mindspore.ops.operations.array_ops import MatrixBandPart
transpose = P.Transpose()
dyn_shape_op = P.TensorShape()
@ -458,6 +460,61 @@ def get_brop_cumulative_logsumexp(self):
return bprop
@bprop_getters.register(MatrixTriangularSolve)
def get_bprop_matrix_triangular_solve(self):
"""Grad definition for 'MatrixTriangularSolve' operation"""
adjoint_a = self.adjoint
lower_a = self.lower
matrix_triangular_solve_op = P.MatrixTriangularSolve(lower=lower_a, adjoint=not adjoint_a)
mat_mul_2d_op = P.MatMul()
mat_mul_op = P.BatchMatMul()
real_op = P.Real()
imag_op = P.Imag()
neg_op = P.Neg()
complex_op = P.Complex()
matrix_band_part_op = MatrixBandPart()
def bprop(matrix, rhs, out, dout):
grad_rhs = matrix_triangular_solve_op(matrix, dout)
if matrix.dtype == mstype.complex64 or matrix.dtype == mstype.complex128:
grad_rhs_temp = _adjoint(grad_rhs)
out_temp = _adjoint(out)
else:
grad_rhs_temp = cholesky_transpose(grad_rhs)
out_temp = cholesky_transpose(out)
if adjoint_a:
if len(matrix.shape) == 2:
grad_matrix = mat_mul_2d_op(out, grad_rhs_temp)
grad_matrix = neg_op(grad_matrix)
else:
grad_matrix = mat_mul_op(out, grad_rhs_temp)
grad_matrix = neg_op(grad_matrix)
else:
if len(matrix.shape) == 2:
grad_matrix = mat_mul_2d_op(grad_rhs, out_temp)
grad_matrix = neg_op(grad_matrix)
else:
grad_matrix = mat_mul_op(grad_rhs, out_temp)
grad_matrix = neg_op(grad_matrix)
if lower_a:
if grad_matrix.dtype == mstype.complex64 or grad_matrix.dtype == mstype.complex128:
grad_matrix_real = matrix_band_part_op(real_op(grad_matrix), -1, 0)
grad_matrix_imag = matrix_band_part_op(imag_op(grad_matrix), -1, 0)
grad_matrix = complex_op(grad_matrix_real, grad_matrix_imag)
else:
grad_matrix = matrix_band_part_op(grad_matrix, -1, 0)
else:
if grad_matrix.dtype == mstype.complex64 or grad_matrix.dtype == mstype.complex128:
grad_matrix_real = matrix_band_part_op(real_op(grad_matrix), 0, -1)
grad_matrix_imag = matrix_band_part_op(imag_op(grad_matrix), 0, -1)
grad_matrix = complex_op(grad_matrix_real, grad_matrix_imag)
else:
grad_matrix = matrix_band_part_op(grad_matrix, 0, -1)
return (grad_matrix, grad_rhs)
return bprop
@bprop_getters.register(MatrixExp)
def get_bprop_matrix_exp(self):
"""Gegerate brop for MatrixExp"""

View File

@ -0,0 +1,43 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BatchMatMul op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
batch_matmul_op_info = AiCPURegOp("BatchMatMul") \
.fusion_type("OPAQUE") \
.input(0, "x1", "required") \
.input(1, "x2", "required") \
.output(0, "output", "required") \
.attr("adj_x1", "bool") \
.attr("adj_x2", "bool") \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.U16_Default, DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
.dtype_format(DataType.C64_Default, DataType.C64_Default, DataType.C64_Default) \
.dtype_format(DataType.C128_Default, DataType.C128_Default, DataType.C128_Default) \
.get_op_info()
@op_info_register(batch_matmul_op_info)
def _batch_matmul_aicpu():
"""BatchMatMul AiCPU register"""
return

View File

@ -0,0 +1,36 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MatrixTriangularSolve op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
matrix_triangular_solve_op_info = AiCPURegOp("MatrixTriangularSolve") \
.fusion_type("OPAQUE") \
.attr("lower", "bool") \
.attr("adjoint", "bool") \
.input(0, "matrix", "required") \
.input(1, "rhs", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
.dtype_format(DataType.C64_Default, DataType.C64_Default, DataType.C64_Default) \
.dtype_format(DataType.C128_Default, DataType.C128_Default, DataType.C128_Default) \
.get_op_info()
@op_info_register(matrix_triangular_solve_op_info)
def _matrix_triangular_solve_aicpu():
"""MatrixTriangularSolve aicpu register"""
return

View File

@ -6996,6 +6996,51 @@ class TrilIndices(Primitive):
validator.check_type_name("dtype", dtype, valid_values, self.name)
class MatrixTriangularSolve(Primitive):
r"""
Returns a new tensor with the solotion of a linear equation system with an upper or lower triangular matrix.
Args:
lower (bool): If true, the innermost matrices in `matrix` is are lower triangular. Default: True.
adjoint (bool): If true, solve with the adjoint of `matrix`. Default: False.
Inputs:
- **matrix** (Tensor) - Tensor of shape :math:`(*, M, M)`,
with float32, float64, complex64 and complex128 data type.
- **rhs** (Tensor) - Tensor of shape :math:`(*, M, N)`,
with float32, float64, complex64 and complex128 data type.
Outputs:
Tensor, has the shape of math:`(*, M, N)` and the same data type as `matrix`.
Raises:
TypeError: If `matrix` or `rhs` is not a Tensor.
TypeError: If `lower` or `adjoint` is not bool.
ValueError: If the batch sizes of `matrix`and `rhs` are not equal.
ValueError: If the inner-most 2 dimensions of `matrix` are not equal.
ValueError: If the second-last dimensions of `matrix`and `rhs` are not equal.
Supported Platforms:
``Ascend`` ``CPU``
Examples:
>>> matrix_triangular_solve = ops.MatrixTriangularSolve(lower=True, adjoint=False)
>>> a = np.array([[3, 0, 0, 0], [2, 1, 0, 0], [1, 0, 1, 0], [1, 1, 1, 1]])
>>> b = np.array([[1, 0],[2, 2],[1, 5],[0, 3]])
>>> output = matrix_triangular_solve(Tensor(a, mindspore.float32), Tensor(b, mindspore.float32))
>>> print(output)
[[ 0.33333334 0. ]
[ 1.3333333 2. ]
[ 0.6666666 5. ]
[-2.3333333 -4. ]]
"""
@prim_attr_register
def __init__(self, lower=True, adjoint=False):
"""Initialize MatrixTriangularSolve"""
validator.check_value_type('adjoint', adjoint, [bool], self.name)
validator.check_value_type('lower', lower, [bool], self.name)
class CompareAndBitpack(Primitive):
"""
Compare values of `x` to `threshold` and pack resulting bits into a `uint8`.

View File

@ -29,6 +29,7 @@ from mindspore.ops import functional as F
from mindspore.ops.operations._grad_ops import IgammaGradA
from mindspore.ops import prim_attr_register, PrimitiveWithInfer
from mindspore.ops.operations.math_ops import Zeta, Igamma, Igammac
from mindspore.ops.operations.math_ops import MatrixTriangularSolve
from mindspore.ops.operations.sparse_ops import DenseToDenseSetOperation
from mindspore.ops.operations.sparse_ops import DenseToSparseSetOperation
@ -894,6 +895,11 @@ raise_set = [
'block': P.Trunc(),
'desc_inputs': [Tensor(np.array([[1.1, 2.2, -4.1]], np.float32))],
'skip': ['backward']}),
('MatrixTriangularSolve', {
'block': MatrixTriangularSolve(adjoint=False, lower=True),
'desc_inputs': [Tensor(np.array([4, 4, 4]).astype(np.float32)),
Tensor(np.array([4, 4, 4]).astype(np.float32))],
'desc_bprop': [Tensor(np.array([4, 4, 4]).astype(np.float32))]}),
('Gcd', {
'block': GcdFunc(),
'desc_inputs': [Tensor(np.array([2, 5, 8]).astype(np.int32)),