!30187 Optimize lu and change the lu_solve implementation to solve_triangular

Merge pull request !30187 from zhuzhongrui/pub_master2
i-robot 2022-02-18 09:24:23 +00:00 committed by Gitee
commit 8010714c1c
2 changed files with 53 additions and 72 deletions
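This merge does two things. In the MindSpore NumPy module it applies PEP 8 whitespace cleanups and lets tril/triu build their masks directly in the input dtype instead of round-tripping through float32. In mindspore.scipy.linalg it removes the dedicated LUSolver kernel and reimplements lu_solve on top of the SolveTriangular op. For orientation, a minimal reference of the upstream SciPy API that lu_factor/lu_solve mirror (this is stock SciPy, not MindSpore code):

```python
# Stock NumPy/SciPy reference for the API surface mirrored here.
import numpy as np
from scipy.linalg import lu_factor, lu_solve

a = np.array([[3., 1.], [1., 2.]])
b = np.array([9., 8.])
lu_piv = lu_factor(a)     # packed LU factors plus pivot indices
x = lu_solve(lu_piv, b)   # solves a @ x = b
assert np.allclose(a @ x, b)
```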

Changed file 1 of 2

@@ -425,13 +425,13 @@ def arange(start, stop=None, step=None, dtype=None):
     # infer the dtype
     if dtype is None:
         dtype = _get_dtype_from_scalar(start, stop, step)
-    if stop is None and step is None: # (start, stop, step) -> (0, start, 1)
+    if stop is None and step is None:  # (start, stop, step) -> (0, start, 1)
         num = _ceil(start)
         out = _iota(mstype.float32, num)
     elif step is None:  # (start, stop, step) -> (start, stop, 1)
         num = _ceil(stop - start)
         out = _iota(mstype.float32, num) + start
-    elif stop is None: # (start, stop, step) -> (0, start, step)
+    elif stop is None:  # (start, stop, step) -> (0, start, step)
         num = _ceil((start + 0.0) / step)
         out = _iota(mstype.float32, num) * step
     else:
@@ -466,8 +466,8 @@ def _compute_shapes(start, axis, num, endpoint):
     """Computes shapes for local variables for np.linspace"""
     bounds_shape = start.shape
     bounds_shape = _tuple_slice(bounds_shape, None, axis) + (1,) + _tuple_slice(bounds_shape, axis, None)
-    iota_shape = _list_comprehensions(start.ndim+1, 1, True)
-    iota_shape = _tuple_slice(iota_shape, None, axis) + (num,) + _tuple_slice(iota_shape, axis+1, None)
+    iota_shape = _list_comprehensions(start.ndim + 1, 1, True)
+    iota_shape = _tuple_slice(iota_shape, None, axis) + (num,) + _tuple_slice(iota_shape, axis + 1, None)
     num_tensor = _type_convert(Tensor, num).astype(mstype.float32)
     div = (num_tensor - 1) if endpoint else num_tensor
     return bounds_shape, iota_shape, div
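For reference, this shape bookkeeping is what lets linspace accept tensor-valued start/stop and insert the sample axis at an arbitrary position. A minimal NumPy sketch of the same broadcasting scheme (not MindSpore code; endpoint=True assumed):

```python
# Hedged NumPy sketch of the broadcasting that _compute_shapes supports:
# start/stop get a length-1 slot where the samples go, and the iota ramp
# gets length `num` there, so start + step * iota broadcasts cleanly.
import numpy as np

start, stop, num, axis = np.zeros((2, 3)), np.ones((2, 3)), 5, 1
bounds_shape = start.shape[:axis] + (1,) + start.shape[axis:]   # (2, 1, 3)
iota_shape = (1,) * axis + (num,) + (1,) * (start.ndim - axis)  # (1, 5, 1)
iota = np.arange(num, dtype=np.float32).reshape(iota_shape)
step = (stop - start).reshape(bounds_shape) / (num - 1)         # endpoint=True
out = start.reshape(bounds_shape) + step * iota                 # (2, 5, 3)
assert np.allclose(out, np.linspace(start, stop, num, axis=axis))
```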
@@ -515,7 +515,7 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis
     """
     # This implementation was inspired by jax.numpy.linspace and numpy.linspace
     start, stop, num, endpoint, dtype = _type_checking_for_xspace(start, stop, num, endpoint, dtype)
-    axis = _canonicalize_axis(axis, start.ndim+1)
+    axis = _canonicalize_axis(axis, start.ndim + 1)
     if not isinstance(retstep, bool):
         _raise_type_error("retstep should be an boolean, but got ", retstep)
     bounds_shape, iota_shape, div = _compute_shapes(start, axis, num, endpoint)
@@ -588,7 +588,7 @@ def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0):
     """
     # This implementation was inspired by jax.numpy.linspace and numpy.linspace
     start, stop, num, endpoint, dtype = _type_checking_for_xspace(start, stop, num, endpoint, dtype)
-    axis = _canonicalize_axis(axis, start.ndim+1)
+    axis = _canonicalize_axis(axis, start.ndim + 1)
     if not isinstance(base, (int, float, bool)):
         _raise_type_error("base should be a number, but got ", base)
     linspace_res = linspace(start, stop, num, endpoint=endpoint, retstep=False, dtype=None, axis=axis)
@@ -637,11 +637,11 @@ def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0):
         [ 1. 2. 4. 8. 16. 32. 64. 128.]
     """
     start, stop, num, endpoint, dtype = _type_checking_for_xspace(start, stop, num, endpoint, dtype)
-    axis = _canonicalize_axis(axis, start.ndim+1)
+    axis = _canonicalize_axis(axis, start.ndim + 1)
     root = num
     if endpoint:
         root -= 1
-    bases = F.tensor_pow(F.tensor_div(stop, start), asarray_const(1./(root)))
+    bases = F.tensor_pow(F.tensor_div(stop, start), asarray_const(1. / (root)))
     exponents = linspace(zeros(F.shape(bases)), F.fill(F.dtype(bases), F.shape(bases), root),
                          num, endpoint=endpoint, dtype=dtype, axis=axis)
     shape = F.shape(bases)
@@ -703,11 +703,11 @@ def eye(N, M=None, k=0, dtype=mstype.float32):
         out = out.astype(mstype.float32)
     if k > 0:
         out_left = full((N, k), 0, dtype)
-        out_right = out[..., 0:M-k:1]
+        out_right = out[..., 0:M - k:1]
         return concatenate((out_left, out_right), 1).astype(dtype)
     if k < 0:
         out_upper = full((-k, M), 0, dtype)
-        out_lower = out[0:N+k:1, ...]
+        out_lower = out[0:N + k:1, ...]
         return concatenate((out_upper, out_lower), 0).astype(dtype)
     return out
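The k-offset branches shift the main diagonal by slicing the zero-offset identity and concatenating a zero block. A minimal NumPy sketch of the k > 0 case:

```python
# Hedged NumPy sketch of the slice-and-concatenate trick for k > 0:
# drop the last k columns of the identity and glue a zero block on the
# left, which shifts the diagonal right by k.
import numpy as np

N, M, k = 4, 5, 2
out = np.eye(N, M)
out_left = np.zeros((N, k))
out_right = out[..., 0:M - k:1]
shifted = np.concatenate((out_left, out_right), axis=1)
assert np.array_equal(shifted, np.eye(N, M, k=2))
```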
@@ -1041,8 +1041,7 @@ def tril(m, k=0):
     if not isinstance(m, Tensor):
         m = asarray_const(m)
     dtype = m.dtype
-    m = m.astype(mstype.float32)
-    assist = nn_tril(m.shape, mstype.float32, k)
+    assist = nn_tril(m.shape, dtype, k)
     return F.tensor_mul(assist, m).astype(dtype)
@@ -1079,8 +1078,7 @@ def triu(m, k=0):
     if not isinstance(m, Tensor):
         m = asarray_const(m)
     dtype = m.dtype
-    m = m.astype(mstype.float32)
-    assist = nn_triu(m.shape, mstype.float32, k)
+    assist = nn_triu(m.shape, dtype, k)
     return F.tensor_mul(assist, m).astype(dtype)
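Both tril and triu multiply the input by a triangular 0/1 mask of the same shape; the change above builds that mask directly in the input dtype rather than casting m to float32 and back. A minimal NumPy equivalent of the new behavior:

```python
# Hedged NumPy equivalent of the mask-multiply approach, with the mask
# built in the input's dtype as the new code does (no float32 round-trip).
import numpy as np

def tril_like(m, k=0):
    assist = np.tril(np.ones(m.shape, dtype=m.dtype), k)  # mask in m's dtype
    return assist * m

m = np.arange(9, dtype=np.int64).reshape(3, 3)
assert np.array_equal(tril_like(m, -1), np.tril(m, -1))
```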
@@ -1321,6 +1319,7 @@ class NdGrid:
     tensors are all of the same dimensions; and if ``sparse=True``,
     returns tensors with only one dimension not equal to `1`.
     """
+
     def __init__(self, sparse=False):
         self.sparse = sparse
@@ -1401,6 +1400,7 @@ class MGridClass(NdGrid):
         >>> print(output)
         [-1. -0.5 0. 0.5 1. ]
     """
+
     def __init__(self):
         super(MGridClass, self).__init__(sparse=False)
@@ -1443,13 +1443,13 @@ class OGridClass(NdGrid):
         >>> print(output)
         [-1. -0.5 0. 0.5 1. ]
     """
+
     def __init__(self):
         super(OGridClass, self).__init__(sparse=True)

-
 mgrid = MGridClass()

 ogrid = OGridClass()
@@ -1801,7 +1801,7 @@ def bartlett(M):
         return ones(_max(0, M))
     n = _iota(mstype.float32, M)
     m_minus_one = _to_tensor(M - 1)
-    return _to_tensor(1) - F.absolute(_to_tensor(2)*n - m_minus_one)/m_minus_one
+    return _to_tensor(1) - F.absolute(_to_tensor(2) * n - m_minus_one) / m_minus_one


 def blackman(M):
@@ -1835,8 +1835,8 @@ def blackman(M):
     if not _check_window_size(M):
         return ones(_max(0, M))
     n_doubled = arange(1 - M, M, 2, dtype=mstype.float32)
-    return (_to_tensor(0.42) + _to_tensor(0.5)*F.cos(_to_tensor(pi/(M - 1))*n_doubled) +
-            _to_tensor(0.08)*F.cos(_to_tensor(2*pi/(M - 1))*n_doubled))
+    return (_to_tensor(0.42) + _to_tensor(0.5) * F.cos(_to_tensor(pi / (M - 1)) * n_doubled) +
+            _to_tensor(0.08) * F.cos(_to_tensor(2 * pi / (M - 1)) * n_doubled))


 def hamming(M):
@@ -1867,7 +1867,7 @@ def hamming(M):
     if not _check_window_size(M):
         return ones(_max(0, M))
     n = _iota(mstype.float32, M)
-    return _to_tensor(0.54) - _to_tensor(0.46)*F.cos(_to_tensor(2*pi/(M - 1))*n)
+    return _to_tensor(0.54) - _to_tensor(0.46) * F.cos(_to_tensor(2 * pi / (M - 1)) * n)


 def hanning(M):
@@ -1898,7 +1898,7 @@ def hanning(M):
     if not _check_window_size(M):
         return ones(_max(0, M))
     n = _iota(mstype.float32, M)
-    return _to_tensor(0.5) - _to_tensor(0.5)*F.cos(_to_tensor(2*pi/(M - 1))*n)
+    return _to_tensor(0.5) - _to_tensor(0.5) * F.cos(_to_tensor(2 * pi / (M - 1)) * n)


 @constexpr
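The windows touched in these hunks are the standard cosine-sum definitions, e.g. hanning is w[n] = 0.5 - 0.5*cos(2*pi*n/(M - 1)). A NumPy cross-check of two of them:

```python
# Hedged NumPy cross-check of the window formulas above; numpy's own
# window functions implement the same cosine-sum definitions.
import numpy as np

M = 8
n = np.arange(M, dtype=np.float32)
hanning = 0.5 - 0.5 * np.cos(2 * np.pi * n / (M - 1))
hamming = 0.54 - 0.46 * np.cos(2 * np.pi * n / (M - 1))
assert np.allclose(hanning, np.hanning(M))
assert np.allclose(hamming, np.hamming(M))
```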
@@ -2050,7 +2050,7 @@ def tril_indices_from(arr, k=0):
     return tril_indices(arr.shape[-2], k=k, m=arr.shape[-1])


-def histogram_bin_edges(a, bins=10, range=None, weights=None): # pylint: disable=redefined-builtin
+def histogram_bin_edges(a, bins=10, range=None, weights=None):  # pylint: disable=redefined-builtin
     """
     Function to calculate only the edges of the bins used by the histogram function.
@@ -2170,9 +2170,9 @@ def _pad_statistic(arr, pad_width, stat_length, stat_op):
     stat_length = _limit_stat_length(stat_length, shape)
     for i in range(ndim):
         pad_before = stat_op(_slice_along_axis(arr, i, 0, stat_length[i][0]), i)
-        pad_before = (F.tile(pad_before, _tuple_setitem((1,)*ndim, i, pad_width[i][0])),)
-        pad_after = stat_op(_slice_along_axis(arr, i, shape[i]-stat_length[i][1], shape[i]), i)
-        pad_after = (F.tile(pad_after, _tuple_setitem((1,)*ndim, i, pad_width[i][1])),)
+        pad_before = (F.tile(pad_before, _tuple_setitem((1,) * ndim, i, pad_width[i][0])),)
+        pad_after = stat_op(_slice_along_axis(arr, i, shape[i] - stat_length[i][1], shape[i]), i)
+        pad_after = (F.tile(pad_after, _tuple_setitem((1,) * ndim, i, pad_width[i][1])),)
         tensor_with_pad = pad_before + (arr,) + pad_after
         arr = concatenate(tensor_with_pad, axis=i)
     return arr
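_pad_statistic pads each axis with a statistic (maximum, minimum, mean, median) computed over the stat_length edge samples, tiled out to the pad width. A minimal NumPy sketch of the same idea on one axis, using mean:

```python
# Hedged NumPy sketch of statistic padding: take the mean of the first/last
# stat_length entries and repeat it pad_width times on each side.
import numpy as np

arr = np.array([1., 2., 3., 4.])
pad_width, stat_length = (2, 1), 2
before = np.repeat(arr[:stat_length].mean(), pad_width[0])
after = np.repeat(arr[-stat_length:].mean(), pad_width[1])
padded = np.concatenate((before, arr, after))
assert np.array_equal(padded, np.pad(arr, pad_width, mode="mean", stat_length=stat_length))
```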
@@ -2180,8 +2180,10 @@ def _pad_statistic(arr, pad_width, stat_length, stat_op):
 def _pad_edge(arr, pad_width):
     """pad_edge is equivalent to pad_statistic with stat_lenght=1, used in mode:"edge"."""
+
     def identity_op(arr, axis):
         return arr
+
     return _pad_statistic(arr, pad_width, 1, identity_op)
@@ -2196,8 +2198,8 @@ def _pad_wrap(arr, pad_width):
         tensor_with_pad = ()
         # To avoid any memory issues, we don't make tensor with 0s in their shapes
         if padsize_before > 0:
-            tensor_with_pad += (_slice_along_axis(arr, i, shape[i]-padsize_before, shape[i]),)
-        tensor_with_pad += (F.tile(arr, _tuple_setitem((1,)*ndim, i, total_repeats)),)
+            tensor_with_pad += (_slice_along_axis(arr, i, shape[i] - padsize_before, shape[i]),)
+        tensor_with_pad += (F.tile(arr, _tuple_setitem((1,) * ndim, i, total_repeats)),)
         if padsize_after > 0:
             tensor_with_pad += (_slice_along_axis(arr, i, 0, padsize_after),)
         arr = concatenate(tensor_with_pad, axis=i)
@@ -2212,16 +2214,16 @@ def _pad_linear(arr, pad_width, end_values):
     end_values = _convert_pad_to_nd(end_values, ndim)
     for i in range(ndim):
         left_value = _slice_along_axis(arr, i, 0, 1)
-        right_value = _slice_along_axis(arr, i, shape[i]-1, shape[i])
+        right_value = _slice_along_axis(arr, i, shape[i] - 1, shape[i])
         pad_before = ()
         pad_after = ()
         if pad_width[i][0] > 0:
             pad_before = (linspace(end_values[i][0], left_value, num=pad_width[i][0],
-                                   endpoint=False, dtype=dtype, axis=i).squeeze(i+1),)
+                                   endpoint=False, dtype=dtype, axis=i).squeeze(i + 1),)
         if pad_width[i][1] > 0:
-            pad_after = linspace(right_value, end_values[i][1], num=pad_width[i][1]+1,
-                                 endpoint=True, dtype=dtype, axis=i).squeeze(i+1)
-            pad_after = (_slice_along_axis(pad_after, i, 1, pad_width[i][1]+1),)
+            pad_after = linspace(right_value, end_values[i][1], num=pad_width[i][1] + 1,
+                                 endpoint=True, dtype=dtype, axis=i).squeeze(i + 1)
+            pad_after = (_slice_along_axis(pad_after, i, 1, pad_width[i][1] + 1),)
         tensor_with_pad = pad_before + (arr,) + pad_after
         arr = concatenate(tensor_with_pad, axis=i)
     return arr
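The linear_ramp path reuses linspace to ramp between the edge sample and the requested end value, generating one extra point on the trailing side so the duplicated edge can be sliced off. A minimal NumPy sketch:

```python
# Hedged NumPy sketch of the linear-ramp construction above: ramp from the
# requested end value up to the edge sample, excluding the duplicated edge.
import numpy as np

arr = np.array([3., 4., 5.])
pad, end_value = 4, 0.0
before = np.linspace(end_value, arr[0], num=pad, endpoint=False)
after = np.linspace(arr[-1], end_value, num=pad + 1, endpoint=True)[1:]
padded = np.concatenate((before, arr, after))
assert np.allclose(padded, np.pad(arr, pad, mode="linear_ramp", end_values=end_value))
```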
@@ -2258,7 +2260,7 @@ def _add_pads_after(arr, pad_args, mode):
     idx, array_length, times_to_pad_after, additional_pad_after, reflect_type = pad_args
     curr_pad = None
     endpoint_adder = None
-    edge_end = _slice_along_axis(arr, idx, arr.shape[idx]-1, arr.shape[idx])
+    edge_end = _slice_along_axis(arr, idx, arr.shape[idx] - 1, arr.shape[idx])
     if mode == "reflect":
         endpoint_adder = 1
     else:
@@ -2275,7 +2277,7 @@ def _add_pads_after(arr, pad_args, mode):
         if reflect_type == "odd":
             curr_pad = 2 * edge_end - curr_pad
         arr = P.Concat(idx)((arr, curr_pad))
-        edge_end = _slice_along_axis(arr, idx, arr.shape[idx]-1, arr.shape[idx])
+        edge_end = _slice_along_axis(arr, idx, arr.shape[idx] - 1, arr.shape[idx])
     return arr
@@ -2311,7 +2313,7 @@ def _pad_reflect(arr, pad_width, reflect_type):
         array_length = arr.shape[i]
         if array_length == 1:
             total_repeats = pad_width[i][0] + pad_width[i][1] + 1
-            arr = F.tile(arr, _tuple_setitem((1,)*arr.ndim, i, total_repeats))
+            arr = F.tile(arr, _tuple_setitem((1,) * arr.ndim, i, total_repeats))
         else:
             has_pad_before = (pad_width[i][0] > 0)
             has_pad_after = (pad_width[i][1] > 0)
@@ -2476,7 +2478,7 @@ def pad(arr, pad_width, mode="constant", stat_length=None, constant_values=0,
     if mode not in ("constant", "maximum", "minimum", "mean", "median", "edge",
                     "wrap", "linear_ramp", "symmetric", "reflect", "empty") and \
-        not _callable(arr, mode):
+            not _callable(arr, mode):
         _raise_value_error("Input mode not supported.")
     if mode == "constant":

Changed file 2 of 2

@@ -16,7 +16,6 @@
 from .ops import Cholesky
 from .ops import EighNet
 from .ops import LU
-from .ops import LUSolver
 from .ops import SolveTriangular
 from .utils import _nd_transpose
 from .utils_const import _raise_value_error, _raise_type_error, _type_check
@@ -561,7 +560,6 @@ def lu_pivots_to_permutation(pivots, permutation_size: int):
     permutation = mnp.array(permutation)
     if permutation_size == 0:
         return permutation
-
     for i in range(k):
         j = pivots[..., i]
         loc = mnp.ix_(*(mnp.arange(0, b) for b in batch_dims))
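lu_pivots_to_permutation turns LAPACK-style pivots (row i was swapped with row pivots[i] at step i) into an explicit permutation vector; the batched mnp.ix_ indexing handles this per batch. A minimal unbatched NumPy sketch of the conversion:

```python
# Hedged NumPy sketch of the pivots-to-permutation conversion: applying the
# recorded row swaps, in order, to the identity permutation yields the
# explicit permutation later used to reorder b.
import numpy as np

def pivots_to_permutation(pivots, n):
    perm = np.arange(n)
    for i, j in enumerate(pivots):
        perm[i], perm[j] = perm[j], perm[i]
    return perm

# e.g. pivots from a 3x3 factorization that swapped row 0 with row 2
print(pivots_to_permutation([2, 1, 2], 3))  # [2 1 0]
```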
@@ -572,29 +570,6 @@ def lu_pivots_to_permutation(pivots, permutation_size: int):
     return permutation


-def lu_solve_core(in_lu, permutation, b, trans):
-    """ core implementation of lu solve"""
-    m = in_lu.shape[0]
-    res_shape = b.shape[1:]
-    prod_result = 1
-    for sh in res_shape:
-        prod_result *= sh
-    x = mnp.reshape(b, (m, prod_result))
-    trans_str = None
-    if trans == 0:
-        trans_str = "N"
-        x = x[permutation, :]
-    elif trans == 1:
-        trans_str = "T"
-    elif trans == 2:
-        trans_str = "C"
-    else:
-        _raise_value_error("trans error, it's value must be 0, 1, 2")
-    ms_lu_solve = LUSolver(trans_str)
-    output = ms_lu_solve(in_lu, x)
-    return mnp.reshape(output, b.shape)
-
-
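The removed helper flattened b and dispatched to the dedicated LUSolver kernel. After this change, lu_solve expresses the same computation as two triangular solves: with the rows of a permuted so that P @ a = l @ u, the system a @ x = b becomes a forward solve l @ y = P @ b with the unit-diagonal L, followed by a back solve u @ x = y. A minimal NumPy/SciPy sketch of that two-phase substitution (stock SciPy, not the MindSpore ops):

```python
# Hedged sketch of the two-phase substitution that replaces lu_solve_core:
# scipy.linalg.lu returns a = p @ l @ u, so a @ x = b reduces to a forward
# solve with unit-lower L and a back solve with upper U.
import numpy as np
from scipy.linalg import lu, solve_triangular

a = np.array([[0., 2., 1.], [1., 1., 1.], [4., 0., 1.]])
b = np.array([5., 6., 9.])
p, l, u = lu(a)                                   # a = p @ l @ u
y = solve_triangular(l, p.T @ b, lower=True, unit_diagonal=True)
x = solve_triangular(u, y, lower=False)
assert np.allclose(a @ x, b)
```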
 def check_lu_shape(in_lu, b):
     """ check lu input shape"""
     if len(in_lu.shape) < 2 or in_lu.shape[-1] != in_lu.shape[-2]:
@@ -612,8 +587,6 @@ def check_lu_shape(in_lu, b):
     if b.shape[-2] != in_lu.shape[-1]:
         _raise_value_error("LU decomposition: lu matrix and b must have same number of dimensions")
     return True
-
-

 def lu_factor(a, overwrite_a=False, check_finite=True):
     """
@@ -751,10 +724,10 @@ def lu(a, permute_l=False, overwrite_a=False, check_finite=True):
     _type_check('permute_l', permute_l, [bool], 'lu')
     a_type = F.dtype(a)
     if len(a.shape) < 2:
-        _raise_value_error("input matrix dimension of lu must larger than 2D.")
+        _raise_value_error("mindspore.scipy.linalg.lu input a's dimension must larger than 2D.")
     if a_type not in (mstype.int32, mstype.int64, mstype.float32, mstype.float64):
         _raise_type_error(
-            "mindspore.scipy.linalg.lu only support (Tensor[int32], Tensor[int64], Tensor[float32], "
+            "mindspore.scipy.linalg.lu input a only support (Tensor[int32], Tensor[int64], Tensor[float32], "
             "Tensor[float64]).")
     if a_type not in (mstype.float32, mstype.float64):
         a = F.cast(a, mstype.float64)
@@ -765,8 +738,7 @@ def lu(a, permute_l=False, overwrite_a=False, check_finite=True):
     if m > n:
         _raise_value_error("last two dimensions of LU decomposition must be row less or equal to col.")
     k = min(m, n)
-    a_dtype = a.dtype
-    l = mnp.tril(m_lu, -1)[..., :k] + mnp.eye(m, k, dtype=a_dtype)
+    l = mnp.tril(m_lu, -1)[..., :k] + mnp.eye(m, k, dtype=a_type)
     u = mnp.triu(m_lu)[:k, :]
     if permute_l:
         return mnp.dot(p, l), u
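Here m_lu is the packed factor matrix: its strictly lower triangle plus an implied unit diagonal holds L, and its upper triangle holds U; the change simply reuses the already-computed a_type instead of re-reading a.dtype. A NumPy sketch of the same unpacking, assuming SciPy's lu_factor packing:

```python
# Hedged NumPy sketch of unpacking a packed LU factor as the code above does:
# strictly-lower part plus a unit diagonal gives L, the upper triangle gives U.
import numpy as np
from scipy.linalg import lu_factor

a = np.random.default_rng(0).random((4, 4))
m_lu, piv = lu_factor(a)
m, n = m_lu.shape
k = min(m, n)
l = np.tril(m_lu, -1)[:, :k] + np.eye(m, k)
u = np.triu(m_lu)[:k, :]
# l @ u reproduces a with its rows permuted by the recorded pivots
perm = np.arange(m)
for i, j in enumerate(piv):
    perm[i], perm[j] = perm[j], perm[i]
assert np.allclose(l @ u, a[perm])
```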
@@ -819,22 +791,29 @@ def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
     _type_check('trans', trans, [int], 'lu_solve')
     m_lu, pivots = lu_and_piv
     m_lu_type = F.dtype(m_lu)
     if len(m_lu.shape) < 2:
         _raise_value_error("input matrix dimension of lu_solve must larger than 2D.")
     if m_lu_type not in (mstype.int32, mstype.int64, mstype.float32, mstype.float64):
         _raise_type_error(
             "mindspore.scipy.linalg.lu_solve only support (Tensor[int32], Tensor[int64], Tensor[float32], "
             "Tensor[float64]).")
     if m_lu_type not in (mstype.float32, mstype.float64):
         m_lu = F.cast(m_lu, mstype.float64)
-    # 1. check shape
+    # 1. Check shape
     check_lu_shape(m_lu, b)
-    # here permutation array has been calculated, just use it.
-    # 2. calculate permutation
+    # 2. Calculate permutation
     permutation = lu_pivots_to_permutation(pivots, pivots.size)
-    # 3. rhs_vector
+    # 3. Get rhs_vector
     rhs_vector = m_lu.ndim == b.ndim + 1
-    x = lu_solve_core(m_lu, permutation, b, trans)
+    x = b[permutation, :]
+    if trans == 0:
+        x = SolveTriangular(lower=True, unit_diagonal=True, trans='N')(m_lu, x)
+        x = SolveTriangular(lower=False, unit_diagonal=False, trans='N')(m_lu, x)
+    elif trans in (1, 2):
+        x = SolveTriangular(lower=False, unit_diagonal=False, trans='T')(m_lu, x)
+        x = SolveTriangular(lower=True, unit_diagonal=True, trans='T')(m_lu, x)
+    else:
+        _raise_value_error("mindspore.scipy.linalg.lu_solve input trans must be 0,1 or 2, but got ", trans)
+    x = mnp.reshape(x, b.shape)
     return x[..., 0] if rhs_vector else x
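Per the SciPy convention this module follows, trans selects which system is solved: 0 for a @ x = b, 1 for a.T @ x = b, and 2 for a.conj().T @ x = b; that is why the transposed cases above reverse the solve order, since a.T factors as u.T @ l.T (up to the permutation). A usage sketch against stock SciPy for comparison:

```python
# Hedged usage sketch of the lu_factor/lu_solve API with the trans options,
# run against stock SciPy rather than the MindSpore module.
import numpy as np
from scipy.linalg import lu_factor, lu_solve

a = np.array([[2., 5., 8., 7.],
              [5., 2., 2., 8.],
              [7., 5., 6., 6.],
              [5., 4., 4., 8.]])
b = np.array([1., 1., 1., 1.])
lu_piv = lu_factor(a)
x0 = lu_solve(lu_piv, b, trans=0)   # solves a @ x = b
x1 = lu_solve(lu_piv, b, trans=1)   # solves a.T @ x = b
assert np.allclose(a @ x0, b) and np.allclose(a.T @ x1, b)
```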