forked from mindspore-Ecosystem/mindspore
!29155 fix cholesky doc api
Merge pull request !29155 from zhuzhongrui/pub_master
commit e0f149faf7
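The backend change below adds a `clean` flag to TriangleMatrixCopy: `cholesky` (clean=True) zeroes the triangle that is not part of the requested factor, while `cho_factor` (clean=False) leaves it untouched. A minimal NumPy sketch of that post-processing step, assuming a single row-major matrix (the real kernel also handles batching and the column-major/row-major flip):

# Rough NumPy sketch of the new TriangleMatrixCopy semantics, not the CUDA kernel itself.
# Assumption: `factored` already holds the in-place potrf result for one matrix.
import numpy as np

def triangle_copy(factored, lower, clean):
    out = factored.copy()
    if clean:
        # cholesky(): zero the half that is not part of the requested factor.
        out = np.tril(out) if lower else np.triu(out)
    # cho_factor(): clean=False, the other half keeps whatever the factorization left there.
    return out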
@@ -16,31 +16,27 @@
#include "triangle_matrix_copy_impl.cuh"
template <typename T>
__global__ void TriangleMatrixCopyKernel(const T *input, T *output, cublasFillMode_t uplo, const size_t count,
const size_t ldb, const size_t m) {
// If fill mode is 'CUBLAS_FILL_MODE_LOWER', the upper half of the matrix should be all 0;
// If fill mode is 'CUBLAS_FILL_MODE_UPPER', the lower half of the matrix should be all 0;
// special case, only upper triangle data is correct, so copy up to lower, when lower case.
if (uplo == CUBLAS_FILL_MODE_UPPER) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
size_t batchIdx = i / (ldb * m);
size_t row = (i - batchIdx * ldb * m) / m;
size_t col = (i - batchIdx * ldb * m) % m;
if (col < row) {
output[i] = 0;
} else {
__global__ void TriangleMatrixCopyKernel(const T *input, T *output, bool clean, cublasFillMode_t uplo,
const size_t count, const size_t ldb, const size_t m) {
// If fill mode is 'CUBLAS_FILL_MODE_LOWER', if clean is false, the upper half and the positive diagonal of the matrix
// should not be assigned any value, otherwise they should be assigned to 0.
// If fill mode is 'CUBLAS_FILL_MODE_UPPER', if clean is false, the lower half and the positive diagonal of the matrix
// should not be assigned any value, otherwise they should be assigned to 0.
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
size_t batchIdx = i / (ldb * m);
size_t row = (i - batchIdx * ldb * m) / m;
size_t col = (i - batchIdx * ldb * m) % m;
if (uplo == CUBLAS_FILL_MODE_UPPER) {
if (col > row && !clean) {
output[i] = input[i];
}
}
} else {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
size_t batchIdx = i / (ldb * m);
size_t row = (i - batchIdx * ldb * m) / m;
size_t col = (i - batchIdx * ldb * m) % m;
if (col > row) {
} else if (col > row && clean) {
output[i] = 0;
}
} else {
if (col < row && !clean) {
output[i] = input[i];
} else if (col < row && clean) {
output[i] = 0;
} else {
output[row * m + col] = input[col * m + row];
}
}
}
@@ -54,19 +50,21 @@ __global__ void MatrixCopyKernel(const T *input, T *output, const size_t count)
}

template <typename T>
void TriangleMatrixCopy(const T *input, T *output, cublasFillMode_t uplo, const size_t count, const size_t ldb,
const size_t m, cudaStream_t cuda_stream) {
TriangleMatrixCopyKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, uplo, count, ldb, m);
void TriangleMatrixCopy(const T *input, T *output, bool clean, cublasFillMode_t uplo, const size_t count,
const size_t ldb, const size_t m, cudaStream_t cuda_stream) {
TriangleMatrixCopyKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, clean, uplo, count, ldb,
m);
return;
}

template void TriangleMatrixCopy<float>(const float *input, float *output, cublasFillMode_t uplo, const size_t count,
const size_t ldb, const size_t m, cudaStream_t cuda_stream);
template void TriangleMatrixCopy<half>(const half *input, half *output, cublasFillMode_t uplo, const size_t count,
const size_t ldb, const size_t m, cudaStream_t cuda_stream);
template void TriangleMatrixCopy<float>(const float *input, float *output, bool clean, cublasFillMode_t uplo,
const size_t count, const size_t ldb, const size_t m, cudaStream_t cuda_stream);
template void TriangleMatrixCopy<half>(const half *input, half *output, bool clean, cublasFillMode_t uplo,
const size_t count, const size_t ldb, const size_t m, cudaStream_t cuda_stream);

template void TriangleMatrixCopy<double>(const double *input, double *output, cublasFillMode_t uplo, const size_t count,
const size_t ldb, const size_t m, cudaStream_t cuda_stream);
template void TriangleMatrixCopy<double>(const double *input, double *output, bool clean, cublasFillMode_t uplo,
const size_t count, const size_t ldb, const size_t m,
cudaStream_t cuda_stream);

template <typename T>
void MatrixCopy(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
@@ -19,8 +19,8 @@
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void TriangleMatrixCopy(const T *input, T *output, cublasFillMode_t uplo, const size_t count, const size_t ldb,
const size_t m, cudaStream_t cuda_stream);
void TriangleMatrixCopy(const T *input, T *output, bool clean, cublasFillMode_t uplo, const size_t count,
const size_t ldb, const size_t m, cudaStream_t cuda_stream);

template <typename T>
void MatrixCopy(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
@@ -42,6 +42,9 @@ static constexpr char kAvgPoolingModeLowerCase[] = "avg";
// Used by cholesky
static constexpr char kLower[] = "lower";

// Used by cholesky
static constexpr char kClean[] = "clean";

// Used by cholesky
static constexpr char kSplitDim[] = "split_dim";
@@ -53,9 +53,7 @@ class CholeskyGpuKernel : public GpuKernel {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
CHECK_CUSOLVER_RET_WITH_ERROR(cusolverDnSetStream(handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cusolverDnSetStream failed");
CHECK_CUBLAS_RET_WITH_ERROR(cublasSetStream(blas_handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cublasSetStream failed");
"cholesky bind cusolverDnSetStream failed");
if (!use_split_matrix_) {
return NoSplitLaunch(inputs, workspace, outputs, stream_ptr);
}
@@ -66,16 +64,17 @@ class CholeskyGpuKernel : public GpuKernel {
kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
kernel_node_ = kernel_node;
lower_ = static_cast<bool>(GetAttr<bool>(kernel_node, kLower));
clean_ = static_cast<bool>(GetAttr<bool>(kernel_node, kClean));
split_dim_ = static_cast<int>(GetAttr<int64_t>(kernel_node, kSplitDim));
// cholesky input is sys_positive_matrix and saved by col_major in gpu backend.
// so we reverse lower to upper, to fake transpose col_major input to row_major.
if (lower_) {
uplo_ = CUBLAS_FILL_MODE_LOWER;
} else {
uplo_ = CUBLAS_FILL_MODE_UPPER;
} else {
uplo_ = CUBLAS_FILL_MODE_LOWER;
}
// get CuSolver Dense matrix handler
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCusolverDnHandle();
// get Cublas handler
blas_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCublasHandle();

auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kInputIndex);
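The lower-to-upper flip in the hunk above leans on the fact that a column-major buffer reinterpreted as row-major is the transpose, and that for a real symmetric positive-definite matrix the upper Cholesky factor is the transpose of the lower one. A quick SciPy check of that identity (illustrative only, not MindSpore code):

# Check that U == L.T for a real symmetric positive-definite matrix.
import numpy as onp
import scipy.linalg as osp_linalg

a = onp.array([[4., 2.], [2., 3.]])
L = osp_linalg.cholesky(a, lower=True)
U = osp_linalg.cholesky(a, lower=False)
assert onp.allclose(U, L.T)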
@@ -189,11 +188,17 @@ class CholeskyGpuKernel : public GpuKernel {
auto d_array_addr = GetDeviceAddress<pointer>(workspace, kDim0);
auto d_info_array_addr = GetDeviceAddress<int>(workspace, kDim1);

// copy input data to output, cholesky inplace output in gpu backend.
CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
cudaMemcpyAsync(output_addr, input1_addr, batch_ * m_ * lda_ * unit_size_,
cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cuda memcopy input to output Fail");

for (size_t i = 0; i < batch_; i++) {
h_array_[i] = input1_addr + i * lda_ * m_;
h_array_[i] = output_addr + i * lda_ * m_;
}

// copy input's addr to d_array_addr
// copy output's addr to d_array_addr
CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
cudaMemcpyAsync(d_array_addr, h_array_.data(), sizeof(pointer) * batch_,
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
@@ -212,8 +217,8 @@ class CholeskyGpuKernel : public GpuKernel {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the data type only should be float or double, right now.";
}
size_t output_elements = outputs.at(kDim0)->size / unit_size_;
// copy results from written input's matrix to output's matrix by up or lower flag.
TriangleMatrixCopy(input1_addr, output_addr, uplo_, output_elements, ldb_, m_,
// copy results from original input's matrix to output's matrix by up or lower flag.
TriangleMatrixCopy(input1_addr, output_addr, clean_, uplo_, output_elements, ldb_, m_,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
@@ -248,8 +253,9 @@ class CholeskyGpuKernel : public GpuKernel {
} else {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the data type only should be float or double, right now.";
}

TriangleMatrixCopy(d_batch_input_addr, output_addr, uplo_, outputs[0]->size / sizeof(T), ldb_, m_,
size_t output_elements = outputs.at(kDim0)->size / unit_size_;
// copy results from original input's matrix to output's matrix by up or lower flag.
TriangleMatrixCopy(input1_addr, output_addr, clean_, uplo_, output_elements, ldb_, m_,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
@@ -267,10 +273,10 @@ class CholeskyGpuKernel : public GpuKernel {
bool is_null_input_{false};
bool use_split_matrix_{false};
cusolverDnHandle_t handle_{nullptr};
cublasHandle_t blas_handle_{nullptr};
cublasFillMode_t uplo_ = CUBLAS_FILL_MODE_UPPER;
std::vector<pointer> h_array_;
bool lower_{false};
bool clean_{false};
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
@@ -186,7 +186,7 @@ def inv(a, overwrite_a=False, check_finite=True):

Args:
a (Tensor): Square matrix to be inverted. Note that if the input tensor is not a `float`,
then it will be casted to :class:`mstype.float32`.
then it will be cast to :class:`mstype.float32`.
overwrite_a (bool, optional): Discard data in `a` (may improve performance). Default: False.
check_finite (bool, optional): Whether to check that the input matrix contains only finite numbers.
Disabling may give a performance gain, but may result in problems (crashes, non-termination)
@@ -225,7 +225,7 @@ def inv(a, overwrite_a=False, check_finite=True):
def cho_factor(a, lower=False, overwrite_a=False, check_finite=True):
"""
Compute the Cholesky decomposition of a matrix, to use in cho_solve
Note that if the input tensor is not a `float`, then it will be casted to :class:'mstype.float32'.
Note that if the input tensor is not a `float`, then it will be cast to :class:'mstype.float32'.
Returns a matrix containing the Cholesky decomposition,
``A = L L*`` or ``A = U* U`` of a Hermitian positive-definite matrix `a`.
The return value can be directly used as the first parameter to cho_solve.
@@ -236,14 +236,16 @@ def cho_factor(a, lower=False, overwrite_a=False, check_finite=True):
entries, use the function `cholesky` instead.

Args:
a (Tensor): square Matrix of (M, M) to be decomposed. Note that if the input tensor is not a `float`,
then it will be casted to :class:'mstype.float32'.
a (Tensor): square Matrix of (M, M) to be decomposed. Note that if the input tensor is not a `float`
or a `double`, then it will be cast to :class:'mstype.float64'.
lower (bool, optional): Whether to compute the upper or lower triangular Cholesky factorization
(Default: upper-triangular)
overwrite_a(bool, optional): Whether to overwrite data in a (may improve performance)
(Default: upper-triangular (false))
overwrite_a(bool, optional): Whether to overwrite data in a (may improve performance). Default is False.
in mindspore, this arg does not work right now.
check_finite(bool, optional): Whether to check that the input matrix contains only finite numbers.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.
(crashes, non-termination) if the inputs do contain infinities or NaNs. Default is True.
in mindspore, this arg does not work right now.

Returns:
- Tensor, matrix whose upper or lower triangle contains the Cholesky factor of `a`.
@@ -251,7 +253,7 @@ def cho_factor(a, lower=False, overwrite_a=False, check_finite=True):
- bool, flag indicating whether the factor is in the lower or upper triangle

Raises:
LinAlgError: Raised if decomposition fails.
ValueError: If input a tensor is not a square matrix or it's dims not equal to 2D.

Supported Platforms:
``CPU`` ``GPU``
@@ -269,7 +271,12 @@ def cho_factor(a, lower=False, overwrite_a=False, check_finite=True):
[ 5. 1. 2. 1.5541857 ]]
"""
if F.dtype(a) not in float_types:
a = F.cast(a, mstype.float32)
a = F.cast(a, mstype.float64)
a_shape = a.shape
if len(a_shape) != 2:
_raise_value_error("input a to mindspore.scipy.linalg.cho_factor must have 2 dimensions.")
if a_shape[-1] != a_shape[-2]:
_raise_value_error("input a to mindspore.scipy.linalg.cho_factor must be a square matrix.")
cholesky_net = Cholesky(lower=lower, clean=False)
c = cholesky_net(a)
return c, lower
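As in SciPy, the factor returned with clean=False may hold arbitrary values in its unused triangle, so it is meant to be consumed by cho_solve rather than read directly. The SciPy counterpart of that workflow, using the same matrix as the docstring example (illustrative only, not MindSpore code):

# SciPy counterpart of the intended cho_factor/cho_solve workflow.
import numpy as onp
from scipy.linalg import cho_factor, cho_solve

a = onp.array([[9., 3., 1., 5.], [3., 7., 5., 1.],
               [1., 5., 9., 2.], [5., 1., 2., 6.]])
b = onp.ones(4)
c, low = cho_factor(a)       # unused triangle holds arbitrary values
x = cho_solve((c, low), b)   # consume the factor only through cho_solve
assert onp.allclose(onp.dot(a, x), b)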
@@ -283,19 +290,22 @@ def cholesky(a, lower=False, overwrite_a=False, check_finite=True):
:math:`A = U^* U` of a Hermitian positive-definite matrix A.

Args:
a (Tensor): square Matrix of (M, M) to be decomposed
a (Tensor): square Matrix of (M, M) to be decomposed, Note that if the input tensor is not a `float`
or `double`, then it will be casted to :class:'mstype.float64'.
lower (bool, optional): Whether to compute the upper- or lower-triangular Cholesky
factorization. Default is upper-triangular.
overwrite_a (bool, optional): Whether to overwrite data in `a` (may improve performance).
factorization. Default is upper-triangular, which means lower defaults to false.
overwrite_a (bool, optional): Whether to overwrite data in `a` (may improve performance). Default is False.
in mindspore, this arg does not work right now.
check_finite (bool, optional): Whether to check that the input matrix contains only finite numbers.
Disabling may give a performance gain, but may result in problems
Default is True. Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.
in mindspore, this arg does not work right now.

Returns:
Tensor, upper- or lower-triangular Cholesky factor of `a`.

Raises:
LinAlgError: if decomposition fails.
ValueError: If input a tensor is not a square matrix or it's dims not equal to 2D.

Supported Platforms:
``CPU`` ``GPU``
@@ -304,14 +314,22 @@ def cholesky(a, lower=False, overwrite_a=False, check_finite=True):
>>> import numpy as onp
>>> from mindspore.common import Tensor
>>> from mindspore.scipy.linalg import cholesky
>>> a = Tensor(onp.array([[1, -2],[2, 5]]).astype(onp.float32))
>>> a = Tensor(onp.array([[1, 2],[2, 5]]).astype(onp.float32))
>>> L = cholesky(a, lower=True)
>>> print(L)
[[1. 0.]
[2. 1.]]
"""
if F.dtype(a) not in float_types:
a = F.cast(a, mstype.float32)
a = F.cast(a, mstype.float64)

a_shape = a.shape
if len(a_shape) != 2:
_raise_value_error("input a to mindspore.scipy.linalg.cholesky must have 2 dimensions.")

if a_shape[-1] != a_shape[-2]:
_raise_value_error("input a to mindspore.scipy.linalg.cholesky must be a square matrix.")

cholesky_net = Cholesky(lower=lower, clean=True)
c = cholesky_net(a)
return c
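Because cholesky passes clean=True, the unused triangle of the result is expected to be exactly zero, so the factor can be checked directly. A NumPy version of that check, mirroring the corrected docstring example (illustrative only):

# NumPy check of what cholesky(..., lower=True) is expected to return.
import numpy as onp

a = onp.array([[1., 2.], [2., 5.]])
L = onp.linalg.cholesky(a)            # lower-triangular, upper part is exactly zero
assert onp.allclose(onp.dot(L, L.T), a)
assert onp.allclose(onp.triu(L, k=1), 0.0)
print(L)                              # [[1. 0.], [2. 1.]] as in the docstring example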
@@ -506,7 +524,7 @@ def lu_factor(a, overwrite_a=False, check_finite=True):

Args:
a (Tensor): square matrix of :math:`(M, M)` to decompose. Note that if the input tensor is not a `float`,
then it will be casted to :class:'mstype.float32'.
then it will be cast to :class:'mstype.float32'.
overwrite_a (bool, optional): Whether to overwrite data in :math:`A` (may increase performance). Default: False.
check_finite (bool, optional): Whether to check that the input matrix contains only finite numbers.
Disabling may give a performance gain, but may result in problems
@@ -565,7 +583,7 @@ def lu(a, permute_l=False, overwrite_a=False, check_finite=True):

Args:
a (Tensor): a :math:`(M, N)` matrix to decompose. Note that if the input tensor is not a `float`,
then it will be casted to :class:'mstype.float32'.
then it will be cast to :class:'mstype.float32'.
permute_l (bool, optional): Perform the multiplication :math:`P L` (Default: do not permute). Default: False.
overwrite_a (bool, optional): Whether to overwrite data in :math:`A` (may improve performance). Default: False.
check_finite (bool, optional): Whether to check that the input matrix contains
@@ -705,7 +723,7 @@ def det(a, overwrite_a=False, check_finite=True):

Args:
a (Tensor): A square matrix to compute. Note that if the input tensor is not a `float`,
then it will be casted to :class:`mstype.float32`.
then it will be cast to :class:`mstype.float32`.
overwrite_a (bool, optional): Allow overwriting data in a (may enhance performance).
check_finite (bool, optional): Whether to check that the input matrix contains
only finite numbers.
@@ -105,13 +105,13 @@ class SolveTriangular(PrimitiveWithInfer):

class Cholesky(PrimitiveWithInfer):
"""
Cholesky decomposition for A.
Inner API Cholesky Compute the Cholesky decomposition of a matrix.
"""

@prim_attr_register
def __init__(self, lower=False, clean=True, split_dim=0):
super().__init__("Cholesky")
self.init_prim_io_names(inputs=['x1'], outputs=['y'])
self.init_prim_io_names(inputs=['a'], outputs=['l'])
self.lower = validator.check_value_type("lower", lower, [bool], self.lower)
self.clean = validator.check_value_type("clean", clean, [bool], self.clean)
self.lower = lower
@@ -121,10 +121,11 @@ class Cholesky(PrimitiveWithInfer):
self.split_dim = split_dim
self.add_prim_attr('split_dim', self.split_dim)

def infer_shape(self, x1_shape):
def __infer__(self, a):
a_shape = a['shape']
if self.split_dim != 0:
height = x1_shape[0]
width = x1_shape[1]
height = a_shape[0]
width = a_shape[1]
if height <= self.split_dim:
out_shape = [1, height, width]
else:
@@ -133,12 +134,13 @@ class Cholesky(PrimitiveWithInfer):
batch += 1
out_shape = [batch, self.split_dim, self.split_dim]
else:
out_shape = x1_shape
return out_shape

def infer_dtype(self, x1_dtype):
validator.check_tensor_dtype_valid('x1', x1_dtype, [mstype.float32, mstype.float64], self.name)
return x1_dtype
out_shape = a_shape
output = {
'shape': tuple(out_shape),
'dtype': a['dtype'],
'value': None
}
return output


class CholeskySolver(PrimitiveWithInfer):
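The new __infer__ replaces the separate infer_shape/infer_dtype pair with a single dict. A plain-Python sketch of what it returns in the common split_dim == 0 case (not MindSpore code; the split branch follows the batching logic in the hunk above):

# Plain-Python sketch of the dict the new __infer__ builds when split_dim == 0.
def cholesky_infer(a_shape, a_dtype, split_dim=0):
    if split_dim == 0:
        # Output matches the input: same shape, same dtype, no constant value.
        return {'shape': tuple(a_shape), 'dtype': a_dtype, 'value': None}
    raise NotImplementedError("split_dim != 0 re-batches the matrix as shown in the hunk above")

print(cholesky_infer((4, 4), 'float64'))   # {'shape': (4, 4), 'dtype': 'float64', 'value': None}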
@@ -21,7 +21,8 @@ from .utils_const import _raise_value_error

def matrix_set_diag(input_x, diagonal, k=0, alignment="RIGHT_LEFT"):
"""
Returns a batched matrix tensor with new batched diagonal values.
Calculate a batched matrix tensor with new batched diagonal values.

Given `input` and `diagonal`, this operation returns a tensor with the same shape and values as `input`,
except for the specified diagonals of the innermost matrices. These will be overwritten by the values in `diagonal`.
`input` has `r+1` dimensions `[I, J, ..., L, M, N]`. When `k` is scalar or `k[0] == k[1]`,
@@ -19,10 +19,8 @@ from ..numpy import where, zeros_like, dot, greater
from ..ops import functional as F
from ..common import Tensor
from ..common import dtype as mstype
from .utils_const import _type_convert, _raise_value_error
from .utils_const import _type_convert, _raise_value_error, _callable_const
from ..ops.composite import GradOperation
from ..ops.primitive import constexpr
from .._c_expression import typing

grad = GradOperation(get_all=False, get_by_list=False, sens_param=False)
_eps_net = ops.Eps()
@@ -106,12 +104,6 @@ _BOOL_FALSE = _to_tensor(False)
float_types = (mstype.float32, mstype.float64)


@constexpr
def _callable_const(x):
"""Returns true if x is a function in graph mode."""
return isinstance(x, typing.Function)


def _normalize_matvec(f):
"""Normalize an argument for computing matrix-vector products."""
if _callable_const(F.typeof(f)):
@@ -14,6 +14,13 @@
# ============================================================================
"""internal graph-compatible utility functions"""
from ..ops.primitive import constexpr
from .._c_expression import typing


@constexpr
def _callable_const(x):
"""Returns true if x is a function in graph mode."""
return isinstance(x, typing.Function)


@constexpr
@@ -76,7 +76,7 @@ def test_inv(data_type, shape):
@pytest.mark.env_onecard
@pytest.mark.parametrize('n', [4, 5, 6])
@pytest.mark.parametrize('lower', [True, False])
@pytest.mark.parametrize('data_type', [onp.float64])
@pytest.mark.parametrize('data_type', [onp.float32, onp.float64])
def test_cholesky(n: int, lower: bool, data_type: Generic):
"""
Feature: ALL TO ALL
@@ -85,8 +85,11 @@ def test_cholesky(n: int, lower: bool, data_type: Generic):
"""
a = create_sym_pos_matrix((n, n), data_type)
tensor_a = Tensor(a)
rtol = 1.e-5
atol = 1.e-8
rtol = 1.e-3
atol = 1.e-3
if data_type == onp.float64:
rtol = 1.e-5
atol = 1.e-8
osp_c = osp.linalg.cholesky(a, lower=lower)
msp_c = msp.linalg.cholesky(tensor_a, lower=lower)
assert onp.allclose(osp_c, msp_c.asnumpy(), rtol=rtol, atol=atol)
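The looser float32 tolerances reflect single-precision rounding in the factorization. A self-contained NumPy/SciPy version of the same dtype-dependent check; make_spd here is only an assumed stand-in for the test helper create_sym_pos_matrix:

# Self-contained NumPy/SciPy version of the tolerance-per-dtype check above.
# make_spd is a stand-in for create_sym_pos_matrix (construction is an assumption).
import numpy as onp
import scipy.linalg as osp_linalg

def make_spd(n, dtype):
    x = onp.random.random((n, n))
    return (onp.dot(x, x.T) + n * onp.eye(n)).astype(dtype)

for dtype, rtol, atol in [(onp.float32, 1.e-3, 1.e-3), (onp.float64, 1.e-5, 1.e-8)]:
    a = make_spd(6, dtype)
    c = osp_linalg.cholesky(a, lower=True)
    assert onp.allclose(onp.dot(c, c.T), a, rtol=rtol, atol=atol)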
@@ -98,21 +101,23 @@ def test_cholesky(n: int, lower: bool, data_type: Generic):
@pytest.mark.env_onecard
@pytest.mark.parametrize('n', [4, 5, 6])
@pytest.mark.parametrize('lower', [True, False])
@pytest.mark.parametrize('data_type', [onp.float64])
@pytest.mark.parametrize('data_type', [onp.float32, onp.float64])
def test_cho_factor(n: int, lower: bool, data_type: Generic):
"""
Feature: ALL TO ALL
Description: test cases for cholesky [N,N]
Description: test cases for cho_factor [N,N]
Expectation: the result match scipy cholesky
"""
a = create_sym_pos_matrix((n, n), data_type)
tensor_a = Tensor(a)
msp_c, _ = msp.linalg.cho_factor(tensor_a, lower=lower)
if lower:
msp_reconstruct_a = mnp.dot(mnp.tril(msp_c), mnp.tril(msp_c).T)
else:
msp_reconstruct_a = mnp.dot(mnp.triu(msp_c).T, mnp.triu(msp_c))
assert onp.allclose(a, msp_reconstruct_a.asnumpy())
osp_c, _ = osp.linalg.cho_factor(a, lower=lower)
rtol = 1.e-3
atol = 1.e-3
if data_type == onp.float64:
rtol = 1.e-5
atol = 1.e-8
assert onp.allclose(osp_c, msp_c.asnumpy(), rtol=rtol, atol=atol)


@pytest.mark.level0