!29597 opt cholesky attributes.

Merge pull request !29597 from zhuzhongrui/pub_master2
i-robot 2022-01-28 07:55:38 +00:00 committed by Gitee
commit 04d7f716e7
8 changed files with 44 additions and 38 deletions

View File

@@ -34,7 +34,7 @@ constexpr size_t kColIndex = 1;
 template <typename T>
 void CholeskyCpuKernelMod<T>::InitMatrixInfo(const std::vector<size_t> &shape, size_t *row, size_t *col) {
   if (shape.empty()) {
-    MS_LOG_EXCEPTION << kernel_name_ << "shape is invalid.";
+    MS_LOG_EXCEPTION << kernel_name_ << " input shape is invalid.";
   }
   if (shape.size() == kDefaultShape) {
     *row = shape.front();
@@ -62,8 +62,13 @@ void CholeskyCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
   InitMatrixInfo(input_shape, &input_row_, &input_col_);
   auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, kOutputIndex);
   InitMatrixInfo(output_shape, &output_row_, &output_col_);
-  lower_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, LOWER);
-  clean_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, CLEAN);
+  // If the clean attribute exists, the clean flag decides whether to keep the redundant triangular data; otherwise clean it to zero.
+  if (AnfAlgo::HasNodeAttr(CLEAN, kernel_node)) {
+    clean_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, CLEAN);
+  }
+  if (AnfAlgo::HasNodeAttr(LOWER, kernel_node)) {
+    lower_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, LOWER);
+  }
 }
 
 template <typename T>
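For reference, the clean flag controls whether the kernel zeroes the triangle it never writes: an in-place factorization leaves stale input data in the other triangle, and clean=true wipes it. A minimal NumPy sketch of that semantics (reference_cholesky is an illustrative helper, not the kernel):

import numpy as np

def reference_cholesky(a, clean=True):
    # Illustrative only: emulate a kernel that writes the lower factor in place,
    # leaving the untouched upper triangle full of stale input data.
    l = np.linalg.cholesky(a)                # lower-triangular factor of SPD a
    raw = l + np.triu(a, 1)                  # stale upper triangle still present
    return np.tril(raw) if clean else raw    # clean=True zeroes the stale part

a = np.array([[4.0, 2.0], [2.0, 3.0]])
assert np.allclose(reference_cholesky(a, clean=True), np.linalg.cholesky(a))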

View File

@@ -35,7 +35,7 @@ class CholeskyCpuKernelMod : public NativeCpuKernelMod {
  private:
   void InitMatrixInfo(const std::vector<size_t> &shape, size_t *row, size_t *col);
   bool lower_{true};
-  bool clean_{false};
+  bool clean_{true};
   size_t input_row_{1};
   size_t input_col_{1};
   size_t output_row_{1};

View File

@@ -60,9 +60,17 @@ class CholeskyGpuKernelMod : public NativeGpuKernelMod {
   bool Init(const CNodePtr &kernel_node) override {
     kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
     kernel_node_ = kernel_node;
-    lower_ = static_cast<bool>(GetAttr<bool>(kernel_node, kLower));
-    clean_ = static_cast<bool>(GetAttr<bool>(kernel_node, kClean));
-    split_dim_ = static_cast<int>(GetAttr<int64_t>(kernel_node, kSplitDim));
+    // If the clean attribute exists, the clean flag decides whether to keep the redundant triangular data; otherwise clean it to zero.
+    if (AnfAlgo::HasNodeAttr(kClean, kernel_node)) {
+      clean_ = static_cast<bool>(GetAttr<bool>(kernel_node, kClean));
+    }
+    if (AnfAlgo::HasNodeAttr(kLower, kernel_node)) {
+      lower_ = static_cast<bool>(GetAttr<bool>(kernel_node, kLower));
+    }
+    if (AnfAlgo::HasNodeAttr(kSplitDim, kernel_node)) {
+      split_dim_ = static_cast<int>(GetAttr<int64_t>(kernel_node, kSplitDim));
+    }
     // cholesky input is sys_positive_matrix and saved by col_major in gpu backend.
     // so we reverse lower to upper, to fake transpose col_major input to row_major.
     if (lower_) {
@@ -272,8 +280,8 @@ class CholeskyGpuKernelMod : public NativeGpuKernelMod {
   cusolverDnHandle_t handle_{nullptr};
   cublasFillMode_t uplo_ = CUBLAS_FILL_MODE_UPPER;
   std::vector<pointer> h_array_;
-  bool lower_{false};
-  bool clean_{false};
+  bool lower_{true};
+  bool clean_{true};
 };
 } // namespace kernel
 } // namespace mindspore
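The lower-to-upper flip in Init() above leans on two facts: for a real symmetric positive-definite matrix the upper Cholesky factor is the transpose of the lower one, and a row-major buffer reinterpreted as column-major holds the transpose, which for symmetric input is the matrix itself. A small SciPy/NumPy check of both (a sanity sketch, not backend code):

import numpy as np
import scipy.linalg as scp

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 4))
a = x @ x.T + 4 * np.eye(4)          # real symmetric positive-definite

l = scp.cholesky(a, lower=True)      # a = l @ l.T
u = scp.cholesky(a, lower=False)     # a = u.T @ u
assert np.allclose(u, l.T)           # upper factor is the transpose of the lower one
assert np.allclose(a.T, a)           # a col-major reinterpretation sees the same data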

View File

@@ -300,8 +300,10 @@ def cho_factor(a, lower=False, overwrite_a=False, check_finite=True):
         _raise_value_error("input a to mindspore.scipy.linalg.cho_factor must have 2 dimensions.")
     if a_shape[-1] != a_shape[-2]:
         _raise_value_error("input a to mindspore.scipy.linalg.cho_factor must be a square matrix.")
-    cholesky_net = Cholesky(lower=lower, clean=False)
+    cholesky_net = Cholesky(clean=False)
     c = cholesky_net(a)
+    if not lower:
+        c = c.T
     return c, lower
@@ -360,9 +362,10 @@ def cholesky(a, lower=False, overwrite_a=False, check_finite=True):
     if a_shape[-1] != a_shape[-2]:
         _raise_value_error("input a to mindspore.scipy.linalg.cholesky must be a square matrix.")
-    cholesky_net = Cholesky(lower=lower, clean=True)
+    cholesky_net = Cholesky(clean=True)
     c = cholesky_net(a)
+    if not lower:
+        c = c.T
     return c
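Since the primitive now always emits the lower factor, the wrappers recover the upper one with a transpose when lower=False. A hedged stand-in for the new code path (cholesky_like is illustrative; np.linalg.cholesky plays the role of Cholesky(clean=True)):

import numpy as np
import scipy.linalg as scp

def cholesky_like(a, lower=False):
    # The backend always returns the lower factor; the wrapper transposes it
    # for lower=False instead of pushing a 'lower' attribute down to the kernel.
    c = np.linalg.cholesky(a)
    if not lower:
        c = c.T
    return c

a = np.array([[4.0, 2.0], [2.0, 3.0]])
assert np.allclose(cholesky_like(a, lower=True), scp.cholesky(a, lower=True))
assert np.allclose(cholesky_like(a, lower=False), scp.cholesky(a, lower=False))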

View File

@@ -106,35 +106,20 @@ class SolveTriangular(PrimitiveWithInfer):
 class Cholesky(PrimitiveWithInfer):
     """
     Inner API Cholesky Compute the Cholesky decomposition of a matrix.
+    clean is a special arg for mindspore.scipy, indicating whether to clean the useless triangular matrix data.
     """
     @prim_attr_register
-    def __init__(self, lower=False, clean=True, split_dim=0):
+    def __init__(self, clean=True):
         super().__init__("Cholesky")
         self.init_prim_io_names(inputs=['a'], outputs=['l'])
-        self.lower = validator.check_value_type("lower", lower, [bool], self.name)
         self.clean = validator.check_value_type("clean", clean, [bool], self.name)
-        self.lower = lower
-        self.add_prim_attr('lower', self.lower)
         self.clean = clean
         self.add_prim_attr('clean', self.clean)
-        self.split_dim = split_dim
-        self.add_prim_attr('split_dim', self.split_dim)
 
     def __infer__(self, a):
         a_shape = a['shape']
-        if self.split_dim != 0:
-            height = a_shape[0]
-            width = a_shape[1]
-            if height <= self.split_dim:
-                out_shape = [1, height, width]
-            else:
-                batch = height // self.split_dim
-                if height != batch * self.split_dim:
-                    batch += 1
-                out_shape = [batch, self.split_dim, self.split_dim]
-        else:
-            out_shape = a_shape
+        out_shape = a_shape
         output = {
             'shape': tuple(out_shape),
             'dtype': a['dtype'],
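A usage sketch of the slimmed-down primitive, assuming the inner API is importable as mindspore.scipy.ops.Cholesky (the relative import in the grad file below suggests so) and that a CPU or GPU kernel is available; the output shape simply mirrors the input and the result is the lower factor:

import numpy as np
import mindspore as ms
from mindspore.scipy.ops import Cholesky  # inner API; import path is an assumption

a = np.array([[4.0, 2.0], [2.0, 3.0]], dtype=np.float64)
l = Cholesky(clean=True)(ms.Tensor(a))    # no more lower/split_dim attributes

assert l.shape == a.shape                 # __infer__ now just returns the input shape
assert np.allclose(l.asnumpy() @ l.asnumpy().T, a)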

View File

@@ -16,6 +16,7 @@
 from .. import numpy as mnp
 from .ops import Eigh, Eig, Cholesky, MatrixBandPart, SolveTriangular
 from .ops_wrapper import matrix_set_diag
+from .utils_const import _raise_value_error
 from .. import dtype as mstype
 from ..ops import operations as P
 from ..ops import functional as F
@@ -53,17 +54,21 @@ def get_bprop_cholesky(self):
     """Grad definition for `Cholesky` operation."""
     inverse = P.MatrixInverse()
     matmul = P.MatMul()
+    clean = self.clean
+    if not clean:
+        _raise_value_error(
+            "primitive Cholesky does not support attribute clean being false right now, please set it to true.")
 
     def bprop(a, out, dout):
         l = out
         l_inverse = inverse(l)
         dout_middle = matmul(_adjoint(l), dout)
-        middle_diag = 0.5 * mnp.diag(dout_middle)
+        middle_diag = 0.5 * dout_middle.diagonal(0, -2, -1)
         dout_middle = matrix_set_diag(dout_middle, middle_diag)
         dout_middle = _matrix_band_part(dout_middle, -1, 0)
         grad_a = matmul(matmul(_adjoint(l_inverse), dout_middle), l_inverse)
-        grad_a = mnp.tril(grad_a + _adjoint(grad_a))
-        middle_diag = 0.5 * mnp.diag(grad_a)
+        grad_a = _matrix_band_part(grad_a + _adjoint(grad_a), -1, 0)
+        middle_diag = 0.5 * grad_a.diagonal(0, -2, -1)
         grad_a = matrix_set_diag(grad_a, middle_diag)
         return (grad_a,)
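The bprop above is the standard reverse-mode rule for A = L @ L.T on real input, built around an operator that takes the lower triangle and halves its diagonal (the matrix_band_part plus matrix_set_diag pair in the code). A NumPy transcription for checking the algebra (a sketch, not MindSpore code; phi is an illustrative helper name):

import numpy as np

def phi(x):
    # Lower-triangular part with the diagonal halved.
    x = np.tril(x)
    x[np.diag_indices_from(x)] *= 0.5
    return x

def cholesky_bprop(l, dout):
    # l: lower Cholesky factor, dout: cotangent of l; returns the cotangent of a.
    l_inverse = np.linalg.inv(l)
    dout_middle = phi(l.T @ dout)
    grad_a = l_inverse.T @ dout_middle @ l_inverse
    return phi(grad_a + grad_a.T)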

View File

@@ -31,9 +31,8 @@ np.random.seed(0)
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
 @pytest.mark.parametrize('n', [3, 5, 7])
-@pytest.mark.parametrize('lower', [True, False])
 @pytest.mark.parametrize('dtype', [np.float64])
-def test_cholesky(n: int, lower: bool, dtype: Generic):
+def test_cholesky(n: int, dtype: Generic):
     """
     Feature: ALL TO ALL
     Description: test cases for cholesky [N,N]
@@ -42,8 +41,8 @@ def test_cholesky(n: int, lower: bool, dtype: Generic):
     context.set_context(mode=context.GRAPH_MODE)
     a = create_sym_pos_matrix((n, n), dtype)
     tensor_a = Tensor(a)
-    expect = scp.linalg.cholesky(a, lower=lower)
-    cholesky_net = Cholesky(lower=lower, clean=True)
+    expect = scp.linalg.cholesky(a, lower=True)
+    cholesky_net = Cholesky(clean=True)
     output = cholesky_net(tensor_a)
     assert np.allclose(expect, output.asnumpy())

View File

@@ -41,7 +41,8 @@ def test_cholesky_grad(shape, data_type):
         def __init__(self):
             super(CholeskyNet, self).__init__()
             self.mean = ops.ReduceMean()
-            self.cholesky = Cholesky(lower=True, clean=True)
+            # clean=False does not support grad right now, so default to clean=True.
+            self.cholesky = Cholesky(clean=True)
 
         def construct(self, a):
             c = self.cholesky(a)