forked from mindspore-Ecosystem/mindspore
!30187 optimize lu and switch the lu_solve implementation to solve_triangular
Merge pull request !30187 from zhuzhongrui/pub_master2
This commit is contained in:
commit 8010714c1c
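At its core, this change drops the dedicated LUSolver primitive and rebuilds lu_solve as a row permutation followed by two SolveTriangular solves: forward-substitute through the unit-diagonal L factor, then back-substitute through U. A minimal sketch of that idea in plain NumPy/SciPy (lu_factor and solve_triangular stand in for the MindSpore ops; this is not the commit's code itself):

    import numpy as np
    from scipy.linalg import lu_factor, solve_triangular

    a = np.array([[3., 1.], [1., 2.]])
    b = np.array([9., 8.])
    lu, piv = lu_factor(a)  # packed L/U factors plus row-swap indices

    # Turn the sequential row swaps in piv into a flat permutation,
    # the same job lu_pivots_to_permutation does in mindspore.scipy.
    perm = np.arange(len(piv))
    for i, p in enumerate(piv):
        perm[i], perm[p] = perm[p], perm[i]

    y = solve_triangular(lu, b[perm], lower=True, unit_diagonal=True)  # forward solve with L
    x = solve_triangular(lu, y, lower=False)                           # back solve with U
    print(np.allclose(a @ x, b))  # True

The mindspore.numpy hunks below are mechanical formatting cleanups (spaces around operators, blank lines after docstrings) from the same change set.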
@@ -425,13 +425,13 @@ def arange(start, stop=None, step=None, dtype=None):
     # infer the dtype
     if dtype is None:
         dtype = _get_dtype_from_scalar(start, stop, step)
-    if stop is None and step is None: # (start, stop, step) -> (0, start, 1)
+    if stop is None and step is None:  # (start, stop, step) -> (0, start, 1)
         num = _ceil(start)
         out = _iota(mstype.float32, num)
     elif step is None:  # (start, stop, step) -> (start, stop, 1)
         num = _ceil(stop - start)
         out = _iota(mstype.float32, num) + start
-    elif stop is None: # (start, stop, step) -> (0, start, step)
+    elif stop is None:  # (start, stop, step) -> (0, start, step)
         num = _ceil((start + 0.0) / step)
         out = _iota(mstype.float32, num) * step
     else:
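For reference, the dispatch being reformatted above fills in defaults positionally: arange(5) means arange(0, 5, 1) and arange(2, 8) means arange(2, 8, 1), with the output length computed by a ceiling division. A small stand-alone sketch of the same rules (hypothetical helper, not the MindSpore code; the keyword-only (0, start, step) branch is omitted for brevity):

    import math

    def arange_length(start, stop=None, step=None):
        if stop is None and step is None:  # (start,) -> (0, start, 1)
            start, stop, step = 0, start, 1
        elif step is None:                 # (start, stop) -> (start, stop, 1)
            step = 1
        return max(0, math.ceil((stop - start) / step))

    assert arange_length(5) == 5
    assert arange_length(2, 8) == 6
    assert arange_length(0, 1, 0.25) == 4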
@@ -466,8 +466,8 @@ def _compute_shapes(start, axis, num, endpoint):
     """Computes shapes for local variables for np.linspace"""
     bounds_shape = start.shape
     bounds_shape = _tuple_slice(bounds_shape, None, axis) + (1,) + _tuple_slice(bounds_shape, axis, None)
-    iota_shape = _list_comprehensions(start.ndim+1, 1, True)
-    iota_shape = _tuple_slice(iota_shape, None, axis) + (num,) + _tuple_slice(iota_shape, axis+1, None)
+    iota_shape = _list_comprehensions(start.ndim + 1, 1, True)
+    iota_shape = _tuple_slice(iota_shape, None, axis) + (num,) + _tuple_slice(iota_shape, axis + 1, None)
     num_tensor = _type_convert(Tensor, num).astype(mstype.float32)
     div = (num_tensor - 1) if endpoint else num_tensor
     return bounds_shape, iota_shape, div
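The two shapes computed above exist purely for broadcasting: the bounds get a length-1 axis inserted at `axis`, while the iota is non-trivial only along `axis`. A concrete illustration in plain Python, assuming start has shape (2, 3) and axis=1:

    start_shape = (2, 3)
    axis, num = 1, 5
    bounds_shape = start_shape[:axis] + (1,) + start_shape[axis:]                    # (2, 1, 3)
    iota_shape = tuple(num if i == axis else 1 for i in range(len(start_shape) + 1))  # (1, 5, 1)
    print(bounds_shape, iota_shape)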
@@ -515,7 +515,7 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0):
     """
     # This implementation was inspired by jax.numpy.linspace and numpy.linspace
     start, stop, num, endpoint, dtype = _type_checking_for_xspace(start, stop, num, endpoint, dtype)
-    axis = _canonicalize_axis(axis, start.ndim+1)
+    axis = _canonicalize_axis(axis, start.ndim + 1)
     if not isinstance(retstep, bool):
         _raise_type_error("retstep should be an boolean, but got ", retstep)
     bounds_shape, iota_shape, div = _compute_shapes(start, axis, num, endpoint)
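_canonicalize_axis(axis, start.ndim + 1) wraps a possibly negative axis into range; the + 1 accounts for the new axis the output gains. A sketch of the equivalent check (hypothetical stand-in, not the MindSpore helper):

    def canonicalize_axis(axis, ndim):
        if not -ndim <= axis < ndim:
            raise ValueError("axis out of range")
        return axis % ndim

    print(canonicalize_axis(-1, 3))  # 2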
@@ -588,7 +588,7 @@ def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0):
     """
     # This implementation was inspired by jax.numpy.linspace and numpy.linspace
     start, stop, num, endpoint, dtype = _type_checking_for_xspace(start, stop, num, endpoint, dtype)
-    axis = _canonicalize_axis(axis, start.ndim+1)
+    axis = _canonicalize_axis(axis, start.ndim + 1)
     if not isinstance(base, (int, float, bool)):
         _raise_type_error("base should be a number, but got ", base)
     linspace_res = linspace(start, stop, num, endpoint=endpoint, retstep=False, dtype=None, axis=axis)
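logspace itself then reduces to raising `base` to a linspace of exponents, which is easy to mirror in NumPy:

    import numpy as np

    start, stop, num, base = 0., 3., 4, 10.0
    print(base ** np.linspace(start, stop, num))  # [   1.   10.  100. 1000.]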
@@ -637,11 +637,11 @@ def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0):
     [ 1. 2. 4. 8. 16. 32. 64. 128.]
     """
     start, stop, num, endpoint, dtype = _type_checking_for_xspace(start, stop, num, endpoint, dtype)
-    axis = _canonicalize_axis(axis, start.ndim+1)
+    axis = _canonicalize_axis(axis, start.ndim + 1)
     root = num
     if endpoint:
         root -= 1
-    bases = F.tensor_pow(F.tensor_div(stop, start), asarray_const(1./(root)))
+    bases = F.tensor_pow(F.tensor_div(stop, start), asarray_const(1. / (root)))
     exponents = linspace(zeros(F.shape(bases)), F.fill(F.dtype(bases), F.shape(bases), root),
                          num, endpoint=endpoint, dtype=dtype, axis=axis)
     shape = F.shape(bases)
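The geomspace construction above boils down to a common ratio bases = (stop / start) ** (1 / root) raised to integer exponents. A NumPy sketch reproducing the docstring example:

    import numpy as np

    start, stop, num, endpoint = 1., 128., 8, True
    root = num - 1 if endpoint else num
    base = (stop / start) ** (1. / root)
    print(start * base ** np.arange(num))  # [  1.   2.   4.   8.  16.  32.  64. 128.]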
@@ -703,11 +703,11 @@ def eye(N, M=None, k=0, dtype=mstype.float32):
         out = out.astype(mstype.float32)
     if k > 0:
         out_left = full((N, k), 0, dtype)
-        out_right = out[..., 0:M-k:1]
+        out_right = out[..., 0:M - k:1]
         return concatenate((out_left, out_right), 1).astype(dtype)
     if k < 0:
         out_upper = full((-k, M), 0, dtype)
-        out_lower = out[0:N+k:1, ...]
+        out_lower = out[0:N + k:1, ...]
         return concatenate((out_upper, out_lower), 0).astype(dtype)
     return out

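The k-offset handling above shifts the k=0 identity by gluing on a zero block: on the left for k > 0, on top for k < 0. The same trick in NumPy:

    import numpy as np

    N, M, k = 3, 4, 1
    out = np.eye(N, M)
    shifted = np.concatenate([np.zeros((N, k)), out[..., 0:M - k:1]], axis=1)
    print(np.array_equal(shifted, np.eye(N, M, k=1)))  # True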
@@ -1041,8 +1041,7 @@ def tril(m, k=0):
     if not isinstance(m, Tensor):
         m = asarray_const(m)
     dtype = m.dtype
-    m = m.astype(mstype.float32)
-    assist = nn_tril(m.shape, mstype.float32, k)
+    assist = nn_tril(m.shape, dtype, k)
     return F.tensor_mul(assist, m).astype(dtype)


@@ -1079,8 +1078,7 @@ def triu(m, k=0):
     if not isinstance(m, Tensor):
         m = asarray_const(m)
     dtype = m.dtype
-    m = m.astype(mstype.float32)
-    assist = nn_triu(m.shape, mstype.float32, k)
+    assist = nn_triu(m.shape, dtype, k)
     return F.tensor_mul(assist, m).astype(dtype)


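Both tril and triu work by building a triangular mask (the `assist` tensor) and multiplying elementwise; the fix above builds the mask in the input dtype rather than detouring through float32. In NumPy terms:

    import numpy as np

    m = np.arange(9, dtype=np.int64).reshape(3, 3)
    mask = np.tril(np.ones(m.shape, dtype=m.dtype))  # assist tensor for k=0
    print(mask * m)  # lower triangle of m, dtype preserved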
@@ -1321,6 +1319,7 @@ class NdGrid:
     tensors are all of the same dimensions; and if ``sparse=True``,
     returns tensors with only one dimension not equal to `1`.
     """
+
     def __init__(self, sparse=False):
         self.sparse = sparse

@@ -1401,6 +1400,7 @@ class MGridClass(NdGrid):
     >>> print(output)
     [-1. -0.5 0. 0.5 1. ]
     """
+
     def __init__(self):
         super(MGridClass, self).__init__(sparse=False)

@@ -1443,13 +1443,13 @@ class OGridClass(NdGrid):
     >>> print(output)
     [-1. -0.5 0. 0.5 1. ]
     """
+
     def __init__(self):
         super(OGridClass, self).__init__(sparse=True)


 mgrid = MGridClass()


 ogrid = OGridClass()

-
@@ -1801,7 +1801,7 @@ def bartlett(M):
         return ones(_max(0, M))
     n = _iota(mstype.float32, M)
     m_minus_one = _to_tensor(M - 1)
-    return _to_tensor(1) - F.absolute(_to_tensor(2)*n - m_minus_one)/m_minus_one
+    return _to_tensor(1) - F.absolute(_to_tensor(2) * n - m_minus_one) / m_minus_one


 def blackman(M):
@@ -1835,8 +1835,8 @@ def blackman(M):
     if not _check_window_size(M):
         return ones(_max(0, M))
     n_doubled = arange(1 - M, M, 2, dtype=mstype.float32)
-    return (_to_tensor(0.42) + _to_tensor(0.5)*F.cos(_to_tensor(pi/(M - 1))*n_doubled) +
-            _to_tensor(0.08)*F.cos(_to_tensor(2*pi/(M - 1))*n_doubled))
+    return (_to_tensor(0.42) + _to_tensor(0.5) * F.cos(_to_tensor(pi / (M - 1)) * n_doubled) +
+            _to_tensor(0.08) * F.cos(_to_tensor(2 * pi / (M - 1)) * n_doubled))


 def hamming(M):
@@ -1867,7 +1867,7 @@ def hamming(M):
     if not _check_window_size(M):
         return ones(_max(0, M))
     n = _iota(mstype.float32, M)
-    return _to_tensor(0.54) - _to_tensor(0.46)*F.cos(_to_tensor(2*pi/(M - 1))*n)
+    return _to_tensor(0.54) - _to_tensor(0.46) * F.cos(_to_tensor(2 * pi / (M - 1)) * n)


 def hanning(M):
@@ -1898,7 +1898,7 @@ def hanning(M):
     if not _check_window_size(M):
         return ones(_max(0, M))
     n = _iota(mstype.float32, M)
-    return _to_tensor(0.5) - _to_tensor(0.5)*F.cos(_to_tensor(2*pi/(M - 1))*n)
+    return _to_tensor(0.5) - _to_tensor(0.5) * F.cos(_to_tensor(2 * pi / (M - 1)) * n)


 @constexpr
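The window hunks above all format variants of the same cosine-sum pattern; for the two simplest, w[n] = a0 - a1*cos(2*pi*n/(M-1)) with (a0, a1) = (0.5, 0.5) for hanning and (0.54, 0.46) for hamming. Checking the formulas against NumPy's references:

    import numpy as np

    M = 8
    n = np.arange(M)
    hanning = 0.5 - 0.5 * np.cos(2 * np.pi * n / (M - 1))
    hamming = 0.54 - 0.46 * np.cos(2 * np.pi * n / (M - 1))
    print(np.allclose(hanning, np.hanning(M)), np.allclose(hamming, np.hamming(M)))  # True True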
@@ -2050,7 +2050,7 @@ def tril_indices_from(arr, k=0):
     return tril_indices(arr.shape[-2], k=k, m=arr.shape[-1])


-def histogram_bin_edges(a, bins=10, range=None, weights=None): # pylint: disable=redefined-builtin
+def histogram_bin_edges(a, bins=10, range=None, weights=None):  # pylint: disable=redefined-builtin
     """
     Function to calculate only the edges of the bins used by the histogram function.

@@ -2170,9 +2170,9 @@ def _pad_statistic(arr, pad_width, stat_length, stat_op):
     stat_length = _limit_stat_length(stat_length, shape)
     for i in range(ndim):
         pad_before = stat_op(_slice_along_axis(arr, i, 0, stat_length[i][0]), i)
-        pad_before = (F.tile(pad_before, _tuple_setitem((1,)*ndim, i, pad_width[i][0])),)
-        pad_after = stat_op(_slice_along_axis(arr, i, shape[i]-stat_length[i][1], shape[i]), i)
-        pad_after = (F.tile(pad_after, _tuple_setitem((1,)*ndim, i, pad_width[i][1])),)
+        pad_before = (F.tile(pad_before, _tuple_setitem((1,) * ndim, i, pad_width[i][0])),)
+        pad_after = stat_op(_slice_along_axis(arr, i, shape[i] - stat_length[i][1], shape[i]), i)
+        pad_after = (F.tile(pad_after, _tuple_setitem((1,) * ndim, i, pad_width[i][1])),)
         tensor_with_pad = pad_before + (arr,) + pad_after
         arr = concatenate(tensor_with_pad, axis=i)
     return arr
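_pad_statistic's loop, reduced to 1-D NumPy: take a statistic over an edge slice of length stat_length, repeat it pad-width times, and concatenate. A minimal sketch for mode="mean":

    import numpy as np

    def pad_mean_1d(arr, before, after, stat_length=2):
        pad_b = np.repeat(arr[:stat_length].mean(), before)
        pad_a = np.repeat(arr[-stat_length:].mean(), after)
        return np.concatenate([pad_b, arr, pad_a])

    print(pad_mean_1d(np.array([1., 3., 5., 7.]), 2, 1))  # [2. 2. 1. 3. 5. 7. 6.]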
@@ -2180,8 +2180,10 @@ def _pad_statistic(arr, pad_width, stat_length, stat_op):

 def _pad_edge(arr, pad_width):
     """pad_edge is equivalent to pad_statistic with stat_lenght=1, used in mode:"edge"."""
+
     def identity_op(arr, axis):
         return arr
+
     return _pad_statistic(arr, pad_width, 1, identity_op)

@@ -2196,8 +2198,8 @@ def _pad_wrap(arr, pad_width):
         tensor_with_pad = ()
         # To avoid any memory issues, we don't make tensor with 0s in their shapes
         if padsize_before > 0:
-            tensor_with_pad += (_slice_along_axis(arr, i, shape[i]-padsize_before, shape[i]),)
-        tensor_with_pad += (F.tile(arr, _tuple_setitem((1,)*ndim, i, total_repeats)),)
+            tensor_with_pad += (_slice_along_axis(arr, i, shape[i] - padsize_before, shape[i]),)
+        tensor_with_pad += (F.tile(arr, _tuple_setitem((1,) * ndim, i, total_repeats)),)
         if padsize_after > 0:
             tensor_with_pad += (_slice_along_axis(arr, i, 0, padsize_after),)
         arr = concatenate(tensor_with_pad, axis=i)
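The wrap-mode logic above tiles the array and then slices off just the needed head and tail, skipping zero-width pieces entirely. A 1-D NumPy sketch, simplified to pad widths no larger than the array:

    import numpy as np

    def pad_wrap_1d(arr, before, after):
        parts = []
        if before > 0:
            parts.append(arr[len(arr) - before:])
        parts.append(arr)  # total_repeats collapses to a single tile here
        if after > 0:
            parts.append(arr[:after])
        return np.concatenate(parts)

    print(pad_wrap_1d(np.array([1, 2, 3]), 2, 1))  # [2 3 1 2 3 1], matches np.pad(..., mode="wrap")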
@@ -2212,16 +2214,16 @@ def _pad_linear(arr, pad_width, end_values):
     end_values = _convert_pad_to_nd(end_values, ndim)
     for i in range(ndim):
         left_value = _slice_along_axis(arr, i, 0, 1)
-        right_value = _slice_along_axis(arr, i, shape[i]-1, shape[i])
+        right_value = _slice_along_axis(arr, i, shape[i] - 1, shape[i])
         pad_before = ()
         pad_after = ()
         if pad_width[i][0] > 0:
             pad_before = (linspace(end_values[i][0], left_value, num=pad_width[i][0],
-                                   endpoint=False, dtype=dtype, axis=i).squeeze(i+1),)
+                                   endpoint=False, dtype=dtype, axis=i).squeeze(i + 1),)
         if pad_width[i][1] > 0:
-            pad_after = linspace(right_value, end_values[i][1], num=pad_width[i][1]+1,
-                                 endpoint=True, dtype=dtype, axis=i).squeeze(i+1)
-            pad_after = (_slice_along_axis(pad_after, i, 1, pad_width[i][1]+1),)
+            pad_after = linspace(right_value, end_values[i][1], num=pad_width[i][1] + 1,
+                                 endpoint=True, dtype=dtype, axis=i).squeeze(i + 1)
+            pad_after = (_slice_along_axis(pad_after, i, 1, pad_width[i][1] + 1),)
         tensor_with_pad = pad_before + (arr,) + pad_after
         arr = concatenate(tensor_with_pad, axis=i)
     return arr
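linear_ramp in miniature: each pad region is a linspace between the user's end value and the edge value, mirroring the pad_before/pad_after branches above (the num=...+1 followed by slicing one element off keeps the edge sample out of the pad):

    import numpy as np

    arr, before, after, end = np.array([4., 6.]), 3, 2, 0.
    pad_b = np.linspace(end, arr[0], num=before, endpoint=False)
    pad_a = np.linspace(arr[-1], end, num=after + 1, endpoint=True)[1:]
    print(np.concatenate([pad_b, arr, pad_a]))  # [0. 1.333 2.667 4. 6. 3. 0.]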
@@ -2258,7 +2260,7 @@ def _add_pads_after(arr, pad_args, mode):
     idx, array_length, times_to_pad_after, additional_pad_after, reflect_type = pad_args
     curr_pad = None
     endpoint_adder = None
-    edge_end = _slice_along_axis(arr, idx, arr.shape[idx]-1, arr.shape[idx])
+    edge_end = _slice_along_axis(arr, idx, arr.shape[idx] - 1, arr.shape[idx])
     if mode == "reflect":
         endpoint_adder = 1
     else:
@@ -2275,7 +2277,7 @@ def _add_pads_after(arr, pad_args, mode):
         if reflect_type == "odd":
             curr_pad = 2 * edge_end - curr_pad
         arr = P.Concat(idx)((arr, curr_pad))
-        edge_end = _slice_along_axis(arr, idx, arr.shape[idx]-1, arr.shape[idx])
+        edge_end = _slice_along_axis(arr, idx, arr.shape[idx] - 1, arr.shape[idx])
     return arr

@@ -2311,7 +2313,7 @@ def _pad_reflect(arr, pad_width, reflect_type):
         array_length = arr.shape[i]
         if array_length == 1:
             total_repeats = pad_width[i][0] + pad_width[i][1] + 1
-            arr = F.tile(arr, _tuple_setitem((1,)*arr.ndim, i, total_repeats))
+            arr = F.tile(arr, _tuple_setitem((1,) * arr.ndim, i, total_repeats))
         else:
             has_pad_before = (pad_width[i][0] > 0)
             has_pad_after = (pad_width[i][1] > 0)
@@ -2476,7 +2478,7 @@ def pad(arr, pad_width, mode="constant", stat_length=None, constant_values=0,

     if mode not in ("constant", "maximum", "minimum", "mean", "median", "edge",
                     "wrap", "linear_ramp", "symmetric", "reflect", "empty") and \
-        not _callable(arr, mode):
+            not _callable(arr, mode):
         _raise_value_error("Input mode not supported.")

     if mode == "constant":
@@ -16,7 +16,6 @@
 from .ops import Cholesky
 from .ops import EighNet
 from .ops import LU
-from .ops import LUSolver
 from .ops import SolveTriangular
 from .utils import _nd_transpose
 from .utils_const import _raise_value_error, _raise_type_error, _type_check
@@ -561,7 +560,6 @@ def lu_pivots_to_permutation(pivots, permutation_size: int):
     permutation = mnp.array(permutation)
     if permutation_size == 0:
         return permutation
-
     for i in range(k):
         j = pivots[..., i]
         loc = mnp.ix_(*(mnp.arange(0, b) for b in batch_dims))
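lu_pivots_to_permutation, restated for the unbatched case: pivot i records a swap of rows i and pivots[i], and applying the swaps in order to arange(n) yields the full row permutation. A NumPy sketch:

    import numpy as np

    def pivots_to_permutation(pivots, n):
        perm = np.arange(n)
        for i, p in enumerate(pivots):
            perm[i], perm[p] = perm[p], perm[i]
        return perm

    print(pivots_to_permutation(np.array([2, 2, 2]), 3))  # [2 0 1]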
@@ -572,29 +570,6 @@ def lu_pivots_to_permutation(pivots, permutation_size: int):
     return permutation


-def lu_solve_core(in_lu, permutation, b, trans):
-    """ core implementation of lu solve"""
-    m = in_lu.shape[0]
-    res_shape = b.shape[1:]
-    prod_result = 1
-    for sh in res_shape:
-        prod_result *= sh
-    x = mnp.reshape(b, (m, prod_result))
-    trans_str = None
-    if trans == 0:
-        trans_str = "N"
-        x = x[permutation, :]
-    elif trans == 1:
-        trans_str = "T"
-    elif trans == 2:
-        trans_str = "C"
-    else:
-        _raise_value_error("trans error, it's value must be 0, 1, 2")
-    ms_lu_solve = LUSolver(trans_str)
-    output = ms_lu_solve(in_lu, x)
-    return mnp.reshape(output, b.shape)
-
-
 def check_lu_shape(in_lu, b):
     """ check lu input shape"""
     if len(in_lu.shape) < 2 or in_lu.shape[-1] != in_lu.shape[-2]:
@@ -612,8 +587,6 @@ def check_lu_shape(in_lu, b):
     if b.shape[-2] != in_lu.shape[-1]:
         _raise_value_error("LU decomposition: lu matrix and b must have same number of dimensions")
-
     return True
-

 def lu_factor(a, overwrite_a=False, check_finite=True):
     """
@@ -751,10 +724,10 @@ def lu(a, permute_l=False, overwrite_a=False, check_finite=True):
     _type_check('permute_l', permute_l, [bool], 'lu')
     a_type = F.dtype(a)
     if len(a.shape) < 2:
-        _raise_value_error("input matrix dimension of lu must larger than 2D.")
+        _raise_value_error("mindspore.scipy.linalg.lu input a's dimension must larger than 2D.")
     if a_type not in (mstype.int32, mstype.int64, mstype.float32, mstype.float64):
         _raise_type_error(
-            "mindspore.scipy.linalg.lu only support (Tensor[int32], Tensor[int64], Tensor[float32], "
+            "mindspore.scipy.linalg.lu input a only support (Tensor[int32], Tensor[int64], Tensor[float32], "
             "Tensor[float64]).")
     if a_type not in (mstype.float32, mstype.float64):
         a = F.cast(a, mstype.float64)
@@ -765,8 +738,7 @@ def lu(a, permute_l=False, overwrite_a=False, check_finite=True):
     if m > n:
         _raise_value_error("last two dimensions of LU decomposition must be row less or equal to col.")
     k = min(m, n)
-    a_dtype = a.dtype
-    l = mnp.tril(m_lu, -1)[..., :k] + mnp.eye(m, k, dtype=a_dtype)
+    l = mnp.tril(m_lu, -1)[..., :k] + mnp.eye(m, k, dtype=a_type)
     u = mnp.triu(m_lu)[:k, :]
     if permute_l:
         return mnp.dot(p, l), u
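The packed LU matrix is split exactly as the hunk above shows: the strict lower triangle plus an identity gives the unit-diagonal L, the upper triangle gives U. Verifying that identity with SciPy's factorization (L @ U reproduces the pivoted rows of a):

    import numpy as np
    from scipy.linalg import lu_factor

    a = np.array([[2., 5., 8.], [4., 2., 1.], [9., 7., 3.]])
    m_lu, piv = lu_factor(a)
    k = min(a.shape)
    l = np.tril(m_lu, -1)[..., :k] + np.eye(a.shape[0], k)
    u = np.triu(m_lu)[:k, :]

    perm = np.arange(len(piv))
    for i, p in enumerate(piv):
        perm[i], perm[p] = perm[p], perm[i]
    print(np.allclose(l @ u, a[perm]))  # True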
@@ -819,22 +791,29 @@ def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
     _type_check('trans', trans, [int], 'lu_solve')
     m_lu, pivots = lu_and_piv
     m_lu_type = F.dtype(m_lu)
     if len(m_lu.shape) < 2:
         _raise_value_error("input matrix dimension of lu_solve must larger than 2D.")
     if m_lu_type not in (mstype.int32, mstype.int64, mstype.float32, mstype.float64):
         _raise_type_error(
             "mindspore.scipy.linalg.lu_solve only support (Tensor[int32], Tensor[int64], Tensor[float32], "
             "Tensor[float64]).")
     if m_lu_type not in (mstype.float32, mstype.float64):
         m_lu = F.cast(m_lu, mstype.float64)
-    # 1. check shape
+    # 1. Check shape
     check_lu_shape(m_lu, b)
     # here permutation array has been calculated, just use it.
-    # 2. calculate permutation
+    # 2. Calculate permutation
     permutation = lu_pivots_to_permutation(pivots, pivots.size)
-    # 3. rhs_vector
+    # 3. Get rhs_vector
     rhs_vector = m_lu.ndim == b.ndim + 1
-    x = lu_solve_core(m_lu, permutation, b, trans)
+    x = b[permutation, :]
+    if trans == 0:
+        x = SolveTriangular(lower=True, unit_diagonal=True, trans='N')(m_lu, x)
+        x = SolveTriangular(lower=False, unit_diagonal=False, trans='N')(m_lu, x)
+    elif trans in (1, 2):
+        x = SolveTriangular(lower=False, unit_diagonal=False, trans='T')(m_lu, x)
+        x = SolveTriangular(lower=True, unit_diagonal=True, trans='T')(m_lu, x)
+    else:
+        _raise_value_error("mindspore.scipy.linalg.lu_solve input trans must be 0,1 or 2, but got ", trans)
+    x = mnp.reshape(x, b.shape)
     return x[..., 0] if rhs_vector else x
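For comparison, SciPy's lu_solve exposes the same trans switch the new branches mirror: 0 solves a x = b, 1 solves a^T x = b, 2 solves a^H x = b:

    import numpy as np
    from scipy.linalg import lu_factor, lu_solve

    a = np.array([[3., 1.], [1., 2.]])
    b = np.array([9., 8.])
    lu_piv = lu_factor(a)
    x0 = lu_solve(lu_piv, b, trans=0)  # a x = b
    x1 = lu_solve(lu_piv, b, trans=1)  # a.T x = b
    print(np.allclose(a @ x0, b), np.allclose(a.T @ x1, b))  # True True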