!43951 [assistant][ops] Add MatrixTriangularSolve
Merge pull request !43951 from 张渝/MatrixTriangularSolve
This commit is contained in: commit 37011775a9

@@ -0,0 +1,156 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/cpu/kernel/matrix_triangular_solve_cpu_kernel.h"
#include <Eigen/Dense>
#include <algorithm>
#include <vector>
#include <string>
#include <utility>

namespace mindspore {
namespace kernel {
using Eigen::ColMajor;
using Eigen::Dynamic;
using Eigen::Lower;
using Eigen::Map;
using Eigen::MatrixBase;
using Eigen::RowMajor;
using Eigen::UnitLower;
using Eigen::UnitUpper;
using Eigen::Upper;
template <typename T, int Major>
using Matrix = Eigen::Matrix<T, Dynamic, Dynamic, Major>;
constexpr auto kSolveTriangularInputsNum = 2;
constexpr auto kSolveTriangularOutputsNum = 1;
constexpr auto kAVectorxDimNum = 1;
constexpr auto kAMatrixDimNum = 2;
constexpr size_t kRowIndex = 2;
constexpr size_t kColIndex = 1;

void MatrixTriangularSolveCpuKernelMod::InitShape(const CNodePtr &kernel_node) {
  auto a_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
  auto b_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
  // Since the shape check is done in the frontend, we can assume that the shapes of a and b are valid here.
  size_t a_dims = a_shape.size();
  size_t a_row_index = a_dims - kRowIndex;
  m_ = a_shape[a_row_index];
  size_t b_dims = b_shape.size();
  bool vector_b = b_dims == a_dims - 1;
  if (vector_b) {
    n_ = 1;
  } else {
    n_ = b_shape[b_dims - 1];
  }
  batch_ = 1;
  for (size_t batch = 0; batch < a_dims - kRowIndex; ++batch) {
    batch_ *= a_shape[batch];
  }
}

void MatrixTriangularSolveCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
  kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
  InitShape(kernel_node);
  trans_ = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, ADJOINT);
  lower_ = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, LOWER);

  auto kernel_attr = GetKernelAttrFromNode(kernel_node);
  auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
  if (!is_match) {
    MS_LOG(EXCEPTION) << "MatrixTriangularSolve does not support this kernel data type: " << kernel_attr;
  }
  kernel_func_ = func_list_[index].second;
}

template <typename Derived_a, typename Derived_b, typename T>
inline void solve(const MatrixBase<Derived_a> &a, const MatrixBase<Derived_b> &b, T *output_addr, int m, int n,
                  bool lower, bool unit_diagonal) {
  Map<Matrix<T, RowMajor>> output(output_addr, m, n);
  if (unit_diagonal) {
    if (lower) {
      output.noalias() = a.template triangularView<UnitLower>().solve(b);
    } else {
      output.noalias() = a.template triangularView<UnitUpper>().solve(b);
    }
  } else {
    if (lower) {
      output.noalias() = a.template triangularView<Lower>().solve(b);
    } else {
      output.noalias() = a.template triangularView<Upper>().solve(b);
    }
  }
}

template <typename T>
bool MatrixTriangularSolveCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
                                                     const std::vector<AddressPtr> &,
                                                     const std::vector<AddressPtr> &outputs) {
  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSolveTriangularInputsNum, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSolveTriangularOutputsNum, kernel_name_);

  auto a_addr = reinterpret_cast<T *>(inputs[0]->addr);
  auto b_addr = reinterpret_cast<T *>(inputs[1]->addr);
  auto output_addr = reinterpret_cast<T *>(outputs[0]->addr);

  size_t a_batch_size = m_ * m_;
  size_t b_batch_size = m_ * n_;
  size_t output_batch_size = m_ * n_;

  for (size_t i = 0; i < batch_; ++i) {
    T *a_batch_addr = a_addr + i * a_batch_size;
    T *b_batch_addr = b_addr + i * b_batch_size;
    T *output_batch_addr = output_addr + i * output_batch_size;

    Map<Matrix<T, RowMajor>> b(b_batch_addr, m_, n_);
    if (trans_) {
      // Mapping the row-major buffer as column-major is an implicit transpose; conjugating it
      // then yields the adjoint of a. Transposing also swaps the lower and upper triangular
      // parts, so the triangular side passed to solve() is flipped.
      Map<Matrix<T, ColMajor>> a(a_batch_addr, m_, m_);
      auto a_conj = a.conjugate();
      solve(a_conj, b, output_batch_addr, m_, n_, !lower_, unit_diagonal_);
    } else {
      Map<Matrix<T, RowMajor>> a(a_batch_addr, m_, m_);
      solve(a, b, output_batch_addr, m_, n_, lower_, unit_diagonal_);
    }
  }

  return true;
}

std::vector<std::pair<KernelAttr, MatrixTriangularSolveCpuKernelMod::MatrixTriangularSolveFunc>>
  MatrixTriangularSolveCpuKernelMod::func_list_ = {
    {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
     &MatrixTriangularSolveCpuKernelMod::LaunchKernel<float>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
     &MatrixTriangularSolveCpuKernelMod::LaunchKernel<double>},
    {KernelAttr()
       .AddInputAttr(kNumberTypeComplex64)
       .AddInputAttr(kNumberTypeComplex64)
       .AddOutputAttr(kNumberTypeComplex64),
     &MatrixTriangularSolveCpuKernelMod::LaunchKernel<std::complex<float>>},
    {KernelAttr()
       .AddInputAttr(kNumberTypeComplex128)
       .AddInputAttr(kNumberTypeComplex128)
       .AddOutputAttr(kNumberTypeComplex128),
     &MatrixTriangularSolveCpuKernelMod::LaunchKernel<std::complex<double>>}};

std::vector<KernelAttr> MatrixTriangularSolveCpuKernelMod::GetOpSupport() {
  std::vector<KernelAttr> support_list;
  (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
                       [](const std::pair<KernelAttr, MatrixTriangularSolveFunc> &pair) { return pair.first; });
  return support_list;
}

MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, MatrixTriangularSolve, MatrixTriangularSolveCpuKernelMod);
}  // namespace kernel
}  // namespace mindspore
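
For reference, the per-batch computation this kernel performs can be sketched in NumPy; this is a minimal illustration, not part of the commit, and it assumes SciPy's solve_triangular is available:

import numpy as np
from scipy.linalg import solve_triangular

def matrix_triangular_solve_ref(a, b, lower=True, adjoint=False):
    """Reference for one batch: solves op(a) @ x = b, where op(a) is a
    itself, or its adjoint conj(a).T when adjoint is True."""
    if adjoint:
        # Same idea as the kernel's ColMajor re-map plus conjugate: conj(a).T
        # swaps the triangular side, hence lower is flipped.
        return solve_triangular(np.conj(a).T, b, lower=not lower)
    return solve_triangular(a, b, lower=lower)

a = np.array([[3., 0.], [2., 1.]])  # lower triangular
b = np.array([[3.], [4.]])
print(matrix_triangular_solve_ref(a, b))                # [[1.], [2.]]
print(matrix_triangular_solve_ref(a, b, adjoint=True))  # solves a.T @ x = b
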
@@ -0,0 +1,62 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_MATRIX_TRIANGULAR_SOLVE_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_MATRIX_TRIANGULAR_SOLVE_CPU_KERNEL_H_

#include <functional>
#include <vector>
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"

namespace mindspore {
namespace kernel {
class MatrixTriangularSolveCpuKernelMod : public DeprecatedNativeCpuKernelMod {
 public:
  MatrixTriangularSolveCpuKernelMod() = default;
  ~MatrixTriangularSolveCpuKernelMod() override = default;
  void InitKernel(const CNodePtr &kernel_node) override;
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override {
    return kernel_func_(this, inputs, workspace, outputs);
  }

 protected:
  std::vector<KernelAttr> GetOpSupport() override;

 private:
  void InitShape(const CNodePtr &kernel_node);

  template <typename T>
  bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
                    const std::vector<kernel::AddressPtr> &outputs);
  using MatrixTriangularSolveFunc =
    std::function<bool(MatrixTriangularSolveCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
                       const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
  static std::vector<std::pair<KernelAttr, MatrixTriangularSolveFunc>> func_list_;
  MatrixTriangularSolveFunc kernel_func_;

  size_t m_{0};
  size_t n_{0};
  size_t batch_{1};
  bool lower_{false};
  bool trans_{false};
  bool unit_diagonal_{false};
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_MATRIX_TRIANGULAR_SOLVE_CPU_KERNEL_H_
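
The class above dispatches through a static table: Launch() forwards to kernel_func_, which InitKernel selects from func_list_ by matching the node's dtype signature. A minimal Python sketch of the same dispatch pattern (names here are illustrative, not MindSpore API):

import numpy as np

# (dtype signature, implementation) pairs, analogous to func_list_.
FUNC_LIST = [
    (np.float32, lambda a, b: "launch<float>"),
    (np.float64, lambda a, b: "launch<double>"),
]

def launch(a, b):
    # Analogous to MatchKernelAttr + kernel_func_: pick the entry whose
    # signature matches the inputs, or fail loudly.
    for dtype, func in FUNC_LIST:
        if a.dtype == dtype and b.dtype == dtype:
            return func(a, b)
    raise TypeError(f"unsupported data type: {a.dtype}")

print(launch(np.zeros(1, np.float32), np.zeros(1, np.float32)))  # launch<float>
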
@@ -147,6 +147,7 @@ constexpr auto kTrilIndices = "TrilIndices";
constexpr auto kTrace = "Trace";
constexpr auto kTraceGrad = "TraceGrad";
constexpr auto kMatrixLogarithm = "MatrixLogarithm";
constexpr auto kMatrixTriangularSolve = "MatrixTriangularSolve";

// Arrays
constexpr auto kLeftShift = "LeftShift";
@@ -1305,6 +1306,7 @@ GVAR_DEF(PrimitivePtr, kPrimEig, std::make_shared<Primitive>("Eig"));
GVAR_DEF(PrimitivePtr, kPrimEigh, std::make_shared<Primitive>("Eigh"));
GVAR_DEF(PrimitivePtr, kPrimQr, std::make_shared<Primitive>("Qr"));
GVAR_DEF(PrimitivePtr, kPrimMatrixLogarithm, std::make_shared<Primitive>(kMatrixLogarithm));
GVAR_DEF(PrimitivePtr, kPrimMatrixTriangularSolve, std::make_shared<Primitive>(kMatrixTriangularSolve));

// linalg
GVAR_DEF(PrimitivePtr, kPrimGeqrf, std::make_shared<Primitive>("Geqrf"));
@@ -0,0 +1,90 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <set>
#include <string>
#include <vector>
#include <memory>
#include <map>
#include "ops/matrix_triangular_solve.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"

namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr MatrixTriangularSolveInferShape(const PrimitivePtr &primitive,
                                                   const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto prim_name = primitive->name();
  auto matrix_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
  auto rhs_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
  (void)CheckAndConvertUtils::CheckInteger("input matrix rank", SizeToLong(matrix_shape.size()), kGreaterEqual, 2L,
                                           prim_name);
  (void)CheckAndConvertUtils::CheckInteger("input rhs rank", SizeToLong(rhs_shape.size()), kGreaterEqual, 2L,
                                           prim_name);
  constexpr size_t offset = 2;
  std::vector<int64_t> matrix_last(matrix_shape.end() - offset, matrix_shape.end());
  std::vector<int64_t> rhs_last(rhs_shape.end() - offset, rhs_shape.end());
  int64_t matrix_row = matrix_last[0];
  int64_t matrix_col = matrix_last[1];
  int64_t rhs_row = rhs_last[0];
  for (size_t i = 0; i < matrix_shape.size() - offset; ++i) {
    if (matrix_shape[i] != rhs_shape[i]) {
      MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the batch dimensions of the inputs should be the same, "
                               << "but dim[" << i << "] of matrix is " << matrix_shape[i] << " while dim[" << i
                               << "] of rhs is " << rhs_shape[i];
    }
  }
  if (matrix_row != rhs_row) {
    MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the row dimension of matrix should be the same as the "
                             << "row dimension of rhs, but got matrix row " << matrix_row << " and rhs row "
                             << rhs_row;
  }
  if (matrix_row != matrix_col) {
    MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the innermost 2 dimensions of matrix must be square, "
                             << "but got " << matrix_row << " rows and " << matrix_col << " columns";
  }
  return std::make_shared<abstract::Shape>(rhs_shape);
}

TypePtr MatrixTriangularSolveInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  std::map<std::string, TypePtr> types;
  (void)types.emplace("matrix", input_args[0]->BuildType());
  (void)types.emplace("rhs", input_args[1]->BuildType());

  const std::set<TypePtr> valid_types = {kFloat32, kFloat64, kComplex64, kComplex128};
  return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, primitive->name());
}
}  // namespace
MIND_API_OPERATOR_IMPL(MatrixTriangularSolve, BaseOperator);
AbstractBasePtr MatrixTriangularSolveInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                           const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  const int64_t kTwo = 2;
  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kTwo, primitive->name());
  auto infer_type = MatrixTriangularSolveInferType(primitive, input_args);
  auto infer_shape = MatrixTriangularSolveInferShape(primitive, input_args);
  return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(MatrixTriangularSolve, prim::kPrimMatrixTriangularSolve, MatrixTriangularSolveInfer,
                             nullptr, true);
}  // namespace ops
}  // namespace mindspore
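
The shape contract enforced by the checks above, restated as a small Python sketch (illustrative only; the helper name is hypothetical):

def infer_output_shape(matrix_shape, rhs_shape):
    """Mirrors the C++ infer-shape rules: matrix is (*, M, M), rhs is (*, M, N),
    the batch dims (*) must match exactly, and the output takes rhs's shape."""
    assert len(matrix_shape) >= 2 and len(rhs_shape) >= 2
    assert matrix_shape[:-2] == rhs_shape[:-2], "batch dimensions must match"
    assert matrix_shape[-2] == matrix_shape[-1], "matrix must be square"
    assert matrix_shape[-2] == rhs_shape[-2], "matrix rows must match rhs rows"
    return rhs_shape

print(infer_output_shape((2, 4, 4), (2, 4, 3)))  # (2, 4, 3)
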
@@ -0,0 +1,42 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_OPS_MATRIX_TRIANGULAR_SOLVE_H_
#define MINDSPORE_CORE_OPS_MATRIX_TRIANGULAR_SOLVE_H_
#include <vector>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"

namespace mindspore {
namespace ops {
constexpr auto kNameMatrixTriangularSolve = "MatrixTriangularSolve";

class MIND_API MatrixTriangularSolve : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(MatrixTriangularSolve);
  /// \brief Constructor.
  MatrixTriangularSolve() : BaseOperator(kNameMatrixTriangularSolve) { InitIOName({"matrix", "rhs"}, {"y"}); }
  /// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.MatrixTriangularSolve for the inputs.
  void Init() const {}
};
abstract::AbstractBasePtr MatrixTriangularSolveInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                                     const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimMatrixTriangularSolvePtr = std::shared_ptr<MatrixTriangularSolve>;
}  // namespace ops
}  // namespace mindspore

#endif  // MINDSPORE_CORE_OPS_MATRIX_TRIANGULAR_SOLVE_H_
@@ -48,6 +48,7 @@ from mindspore.ops.operations.math_ops import CumulativeLogsumexp
from mindspore.ops.operations.math_ops import MatrixSolve
from mindspore.ops.operations.math_ops import MatrixPower
from mindspore.ops.operations.math_ops import Median
from mindspore.ops.operations.math_ops import MatrixTriangularSolve
from mindspore.ops.operations.math_ops import Betainc
from mindspore.ops.operations.math_ops import Cholesky
from mindspore.ops.operations.math_ops import CholeskySolve
@@ -63,6 +64,7 @@ from mindspore.ops._utils.utils import is_shape_unknown, is_dim_unknown
from mindspore.ops._grad.grad_base import bprop_getters, create_tensor_by_element, dyn_rank
from mindspore.ops._grad.grad_base import dyn_ones, dyn_fill, sum_grad_reduce_axis
from mindspore.ops._grad.grad_math_ops import binop_grad_common
from mindspore.ops.operations.array_ops import MatrixBandPart

transpose = P.Transpose()
dyn_shape_op = P.TensorShape()
@@ -458,6 +460,61 @@ def get_brop_cumulative_logsumexp(self):
    return bprop


@bprop_getters.register(MatrixTriangularSolve)
def get_bprop_matrix_triangular_solve(self):
    """Grad definition for 'MatrixTriangularSolve' operation"""
    adjoint_a = self.adjoint
    lower_a = self.lower
    matrix_triangular_solve_op = P.MatrixTriangularSolve(lower=lower_a, adjoint=not adjoint_a)
    mat_mul_2d_op = P.MatMul()
    mat_mul_op = P.BatchMatMul()
    real_op = P.Real()
    imag_op = P.Imag()
    neg_op = P.Neg()
    complex_op = P.Complex()
    matrix_band_part_op = MatrixBandPart()

    def bprop(matrix, rhs, out, dout):
        # Gradient w.r.t. rhs: solve the adjoint system with the incoming gradient.
        grad_rhs = matrix_triangular_solve_op(matrix, dout)
        if matrix.dtype == mstype.complex64 or matrix.dtype == mstype.complex128:
            grad_rhs_temp = _adjoint(grad_rhs)
            out_temp = _adjoint(out)
        else:
            grad_rhs_temp = cholesky_transpose(grad_rhs)
            out_temp = cholesky_transpose(out)
        # Gradient w.r.t. matrix: -grad_rhs @ adjoint(out), with the operands
        # swapped when solving the adjoint system; use the 2-D MatMul for
        # unbatched inputs and BatchMatMul otherwise.
        if adjoint_a:
            if len(matrix.shape) == 2:
                grad_matrix = neg_op(mat_mul_2d_op(out, grad_rhs_temp))
            else:
                grad_matrix = neg_op(mat_mul_op(out, grad_rhs_temp))
        else:
            if len(matrix.shape) == 2:
                grad_matrix = neg_op(mat_mul_2d_op(grad_rhs, out_temp))
            else:
                grad_matrix = neg_op(mat_mul_op(grad_rhs, out_temp))
        # Only the active triangle of matrix participates in the solve, so mask
        # the gradient to that triangle; complex types are masked per component.
        if lower_a:
            if grad_matrix.dtype == mstype.complex64 or grad_matrix.dtype == mstype.complex128:
                grad_matrix_real = matrix_band_part_op(real_op(grad_matrix), -1, 0)
                grad_matrix_imag = matrix_band_part_op(imag_op(grad_matrix), -1, 0)
                grad_matrix = complex_op(grad_matrix_real, grad_matrix_imag)
            else:
                grad_matrix = matrix_band_part_op(grad_matrix, -1, 0)
        else:
            if grad_matrix.dtype == mstype.complex64 or grad_matrix.dtype == mstype.complex128:
                grad_matrix_real = matrix_band_part_op(real_op(grad_matrix), 0, -1)
                grad_matrix_imag = matrix_band_part_op(imag_op(grad_matrix), 0, -1)
                grad_matrix = complex_op(grad_matrix_real, grad_matrix_imag)
            else:
                grad_matrix = matrix_band_part_op(grad_matrix, 0, -1)
        return (grad_matrix, grad_rhs)

    return bprop


@bprop_getters.register(MatrixExp)
def get_bprop_matrix_exp(self):
    """Generate bprop for MatrixExp"""
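
For intuition, the MatrixTriangularSolve bprop above implements grad_rhs = solve(adjoint(matrix), dout) and grad_matrix = -grad_rhs @ adjoint(out) (operands swapped when adjoint is set), masked to the active triangle. A NumPy finite-difference spot check of the real, adjoint=False case; a minimal sketch assuming SciPy, not MindSpore code:

import numpy as np
from scipy.linalg import solve_triangular

def bprop_ref(a, b, dout, lower=True):
    """Reference gradients for out = solve(a, b), real dtype, adjoint=False."""
    out = solve_triangular(a, b, lower=lower)
    grad_rhs = solve_triangular(a.T, dout, lower=not lower)  # adjoint solve
    grad_matrix = -grad_rhs @ out.T
    tri = np.tril if lower else np.triu  # mask to the active triangle
    return tri(grad_matrix), grad_rhs

rng = np.random.default_rng(0)
a = np.tril(rng.random((3, 3))) + 3.0 * np.eye(3)
b, dout = rng.random((3, 2)), rng.random((3, 2))
grad_a, _ = bprop_ref(a, b, dout)

# Finite-difference check of one lower-triangular entry of a.
eps = 1e-6
da = np.zeros_like(a)
da[2, 0] = eps
fd = np.sum(dout * (solve_triangular(a + da, b, lower=True) -
                    solve_triangular(a, b, lower=True))) / eps
print(np.isclose(fd, grad_a[2, 0], atol=1e-4))  # expected: True
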
@@ -0,0 +1,43 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""BatchMatMul op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

batch_matmul_op_info = AiCPURegOp("BatchMatMul") \
    .fusion_type("OPAQUE") \
    .input(0, "x1", "required") \
    .input(1, "x2", "required") \
    .output(0, "output", "required") \
    .attr("adj_x1", "bool") \
    .attr("adj_x2", "bool") \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.I16_Default) \
    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
    .dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default) \
    .dtype_format(DataType.U16_Default, DataType.U16_Default, DataType.U16_Default) \
    .dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.U64_Default) \
    .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
    .dtype_format(DataType.C64_Default, DataType.C64_Default, DataType.C64_Default) \
    .dtype_format(DataType.C128_Default, DataType.C128_Default, DataType.C128_Default) \
    .get_op_info()


@op_info_register(batch_matmul_op_info)
def _batch_matmul_aicpu():
    """BatchMatMul AiCPU register"""
    return
@@ -0,0 +1,36 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""MatrixTriangularSolve op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

matrix_triangular_solve_op_info = AiCPURegOp("MatrixTriangularSolve") \
    .fusion_type("OPAQUE") \
    .attr("lower", "bool") \
    .attr("adjoint", "bool") \
    .input(0, "matrix", "required") \
    .input(1, "rhs", "required") \
    .output(0, "y", "required") \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
    .dtype_format(DataType.C64_Default, DataType.C64_Default, DataType.C64_Default) \
    .dtype_format(DataType.C128_Default, DataType.C128_Default, DataType.C128_Default) \
    .get_op_info()


@op_info_register(matrix_triangular_solve_op_info)
def _matrix_triangular_solve_aicpu():
    """MatrixTriangularSolve AiCPU register"""
    return
@@ -6996,6 +6996,51 @@ class TrilIndices(Primitive):
        validator.check_type_name("dtype", dtype, valid_values, self.name)


class MatrixTriangularSolve(Primitive):
    r"""
    Returns a new tensor with the solution of a linear equation system with an upper or lower triangular matrix.

    Args:
        lower (bool): If True, the innermost matrices in `matrix` are lower triangular. Default: True.
        adjoint (bool): If True, solve with the adjoint of `matrix`. Default: False.

    Inputs:
        - **matrix** (Tensor) - Tensor of shape :math:`(*, M, M)`,
          with float32, float64, complex64 or complex128 data type.
        - **rhs** (Tensor) - Tensor of shape :math:`(*, M, N)`,
          with float32, float64, complex64 or complex128 data type.

    Outputs:
        Tensor, has the shape of :math:`(*, M, N)` and the same data type as `matrix`.

    Raises:
        TypeError: If `matrix` or `rhs` is not a Tensor.
        TypeError: If `lower` or `adjoint` is not bool.
        ValueError: If the batch sizes of `matrix` and `rhs` are not equal.
        ValueError: If the inner-most 2 dimensions of `matrix` are not equal.
        ValueError: If the second-last dimensions of `matrix` and `rhs` are not equal.

    Supported Platforms:
        ``Ascend`` ``CPU``

    Examples:
        >>> matrix_triangular_solve = ops.MatrixTriangularSolve(lower=True, adjoint=False)
        >>> a = np.array([[3, 0, 0, 0], [2, 1, 0, 0], [1, 0, 1, 0], [1, 1, 1, 1]])
        >>> b = np.array([[1, 0], [2, 2], [1, 5], [0, 3]])
        >>> output = matrix_triangular_solve(Tensor(a, mindspore.float32), Tensor(b, mindspore.float32))
        >>> print(output)
        [[ 0.33333334  0.        ]
         [ 1.3333333   2.        ]
         [ 0.6666666   5.        ]
         [-2.3333333  -4.        ]]
    """

    @prim_attr_register
    def __init__(self, lower=True, adjoint=False):
        """Initialize MatrixTriangularSolve"""
        validator.check_value_type('adjoint', adjoint, [bool], self.name)
        validator.check_value_type('lower', lower, [bool], self.name)


class CompareAndBitpack(Primitive):
    """
    Compare values of `x` to `threshold` and pack resulting bits into a `uint8`.
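
As a quick sanity check, the docstring example for MatrixTriangularSolve above reproduces with SciPy (an illustrative snippet, not part of the commit; the printed values match up to float32 rounding):

import numpy as np
from scipy.linalg import solve_triangular

a = np.array([[3, 0, 0, 0], [2, 1, 0, 0], [1, 0, 1, 0], [1, 1, 1, 1]], dtype=np.float32)
b = np.array([[1, 0], [2, 2], [1, 5], [0, 3]], dtype=np.float32)
print(solve_triangular(a, b, lower=True))
# [[ 0.33333334  0.        ]
#  [ 1.3333334   2.        ]
#  [ 0.6666666   5.        ]
#  [-2.3333333  -4.        ]]
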
@@ -29,6 +29,7 @@ from mindspore.ops import functional as F
from mindspore.ops.operations._grad_ops import IgammaGradA
from mindspore.ops import prim_attr_register, PrimitiveWithInfer
from mindspore.ops.operations.math_ops import Zeta, Igamma, Igammac
from mindspore.ops.operations.math_ops import MatrixTriangularSolve
from mindspore.ops.operations.sparse_ops import DenseToDenseSetOperation
from mindspore.ops.operations.sparse_ops import DenseToSparseSetOperation
@@ -894,6 +895,11 @@ raise_set = [
        'block': P.Trunc(),
        'desc_inputs': [Tensor(np.array([[1.1, 2.2, -4.1]], np.float32))],
        'skip': ['backward']}),
    ('MatrixTriangularSolve', {
        'block': MatrixTriangularSolve(adjoint=False, lower=True),
        'desc_inputs': [Tensor(np.array([4, 4, 4]).astype(np.float32)),
                        Tensor(np.array([4, 4, 4]).astype(np.float32))],
        'desc_bprop': [Tensor(np.array([4, 4, 4]).astype(np.float32))]}),
    ('Gcd', {
        'block': GcdFunc(),
        'desc_inputs': [Tensor(np.array([2, 5, 8]).astype(np.int32)),