Remove global constant tensors, float_types, and _SafeNormalize from the mindspore.scipy module.
This commit is contained in:
parent c8c29da927
commit 5b182d6b14
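Note: the recurring change below moves constant Tensors out of module scope and into each Cell's construct. A minimal sketch of the pattern, runnable in PyNative mode (the CountDown cell and the local _to_tensor stand-in are illustrative, not part of this commit):

import mindspore.nn as nn
from mindspore import Tensor


def _to_tensor(value, dtype=None):
    # Illustrative stand-in for the _to_tensor helper in mindspore.scipy.utils.
    return Tensor(value) if dtype is None else Tensor(value, dtype)


class CountDown(nn.Cell):
    def construct(self, n):
        # Constant tensor which avoids loop unrolling: a Python int bound
        # would be unrolled at graph compile time, while a Tensor keeps the
        # while loop as real control flow. Building it here, per call,
        # replaces the module-level global that this commit deletes.
        _INT_ZERO = _to_tensor(0)
        while n > _INT_ZERO:
            n = n - 1
        return n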
@@ -24,7 +24,6 @@ from .ops import EighNet
 from ..ops import operations as P
 from ..ops import functional as F
 from ..common import dtype as mstype
-from .utils import float_types
 from .utils_const import _raise_value_error

 __all__ = ['block_diag', 'inv', 'eigh', 'lu_factor', 'lu']
@@ -195,7 +194,7 @@ def inv(a, overwrite_a=False, check_finite=True):
         [[ 1.00000000e+00, 0.00000000e+00],
          [ 8.88178420e-16, 1.00000000e+00]])
     """
-    if F.dtype(a) not in float_types:
+    if F.dtype(a) not in (mstype.float32, mstype.float64):
         a = F.cast(a, mstype.float32)

     matrix_inverse = P.MatrixInverse(adjoint=False)
@@ -249,7 +248,7 @@ def cho_factor(a, lower=False, overwrite_a=False, check_finite=True):
         [ 1.00000000e+00, 5.00000000e+00, 2.29330778e+00, 8.55952621e-01],
         [ 5.00000000e+00, 1.00000000e+00, 2.00000000e+00, 1.55418575e+00]])
     """
-    if F.dtype(a) not in float_types:
+    if F.dtype(a) not in (mstype.float32, mstype.float64):
         a = F.cast(a, mstype.float32)
     cholesky_net = Cholesky(lower=lower, clean=False)
     c = cholesky_net(a)
@@ -292,7 +291,7 @@ def cholesky(a, lower=False, overwrite_a=False, check_finite=True):
         [[ 1.00000000e+00, 0.00000000e+00],
          [ 2.00000000e+00, 1.00000000e+00]])
     """
-    if F.dtype(a) not in float_types:
+    if F.dtype(a) not in (mstype.float32, mstype.float64):
         a = F.cast(a, mstype.float32)
     cholesky_net = Cholesky(lower=lower, clean=True)
     c = cholesky_net(a)
@@ -516,7 +515,7 @@ def lu_factor(a, overwrite_a=False, check_finite=True):
     >>> piv
     Tensor(shape=[4], dtype=Int32, value= [2, 2, 3, 3])
     """
-    if F.dtype(a) not in float_types:
+    if F.dtype(a) not in (mstype.float32, mstype.float64):
         a = F.cast(a, mstype.float32)
     if len(a.shape) < 2 or (a.shape[-1] != a.shape[-2]):
         _raise_value_error("input of lu matrix must be square.")
@@ -586,7 +585,7 @@ def lu(a, permute_l=False, overwrite_a=False, check_finite=True):
         [ 0.00000000e+00, 0.00000000e+00, -1.03999996e+00, 3.07999992e+00],
         [ 0.00000000e+00, -0.00000000e+00, -0.00000000e+00, 7.46153831e+00]])
     """
-    if F.dtype(a) not in float_types:
+    if F.dtype(a) not in (mstype.float32, mstype.float64):
         a = F.cast(a, mstype.float32)
     msp_lu = LU()
     m_lu, _, p = msp_lu(a)
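Note: all five linalg hunks above make the same substitution, inlining the deleted float_types tuple. A minimal standalone check of the inlined dtype test (the integer tensor a is illustrative):

from mindspore import Tensor
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype

a = Tensor([[1, 2], [3, 4]])  # integer-typed input
if F.dtype(a) not in (mstype.float32, mstype.float64):
    a = F.cast(a, mstype.float32)  # coerced exactly as in the hunks above
assert F.dtype(a) == mstype.float32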
@@ -20,8 +20,7 @@ from ... import numpy as mnp
 from ...common import Tensor

 from .line_search import LineSearch
-from ..utils import _to_scalar, grad
-from ..utils import _INT_ZERO, _INT_ONE, _BOOL_FALSE
+from ..utils import _to_scalar, _to_tensor, grad


 class _BFGSResults(NamedTuple):
@@ -74,6 +73,11 @@ class MinimizeBfgs(nn.Cell):
         self.line_search = LineSearch(func)

     def construct(self, x0, maxiter=None, norm=mnp.inf, gtol=1e-5, line_search_maxiter=10):
+        # Constant tensors which avoid loop unrolling
+        _BOOL_FALSE = _to_tensor(False)
+        _INT_ZERO = _to_tensor(0)
+        _INT_ONE = _to_tensor(1)
+
         def _my_norm(x, ord_=None):
             if ord_ == mnp.inf:
                 res = mnp.max(mnp.abs(x))
@@ -20,8 +20,7 @@ from ... import numpy as mnp
 from ...common import dtype as mstype
 from ...common import Tensor

-from ..utils import _to_scalar, grad
-from ..utils import _to_tensor, _INT_ZERO, _INT_ONE, _BOOL_FALSE
+from ..utils import _to_scalar, _to_tensor, grad


 class _LineSearchResults(NamedTuple):
@@ -92,7 +91,10 @@ def _zoom(fn, a_low, phi_low, dphi_low, a_high, phi_high, dphi_high, phi_0, g_0,
     Algorithm 3.6 from Wright and Nocedal, 'Numerical Optimization', 1999, pg. 59-61.
     Tries cubic, quadratic, and bisection methods of zooming.
     """
+    # Constant tensors which avoid loop unrolling
     _FLOAT_ONE = _to_tensor(1., dtype=a_low.dtype)
+    _BOOL_FALSE = _to_tensor(False)
+    _INT_ZERO = _to_tensor(0)
     state = {
         "done": _BOOL_FALSE,
         "failed": _BOOL_FALSE,
@@ -200,8 +202,12 @@ class LineSearch(nn.Cell):
             gkk = grad(self.func)(xkk)
             return fkk, gkk, mnp.dot(gkk, pk)

+        # Constant tensors which avoid loop unrolling
         _FLOAT_ZERO = _to_tensor(0., dtype=xk.dtype)
         _FLOAT_ONE = _to_tensor(1., dtype=xk.dtype)
+        _BOOL_FALSE = _to_tensor(False)
+        _INT_ZERO = _to_tensor(0)
+        _INT_ONE = _to_tensor(1)

         if old_fval is None or gfk is None:
             nfev, ngev = _INT_ONE, _INT_ONE
@@ -16,9 +16,10 @@
 from ... import nn, ms_function
 from ... import numpy as mnp
 from ...ops import functional as F
+from ...common import dtype as mstype
 from ..linalg import solve_triangular
 from ..linalg import cho_factor, cho_solve
-from ..utils import _INT_ZERO, _INT_NEG_ONE, _normalize_matvec, _to_tensor, _safe_normalize, _eps, float_types
+from ..utils import _normalize_matvec, _to_tensor, _safe_normalize, _eps
 from ..utils_const import _raise_value_error, _raise_type_error

@@ -71,6 +72,9 @@ class BatchedGmres(nn.Cell):
         self.M = M

     def construct(self, b, x0=None, tol=1e-5, atol=0.0, restart=20, maxiter=None):
+        # Constant tensor which avoids loop unrolling
+        _INT_ZERO = _to_tensor(0)
+
         A = _normalize_matvec(self.A)
         M = _normalize_matvec(self.M)
         dtype = b.dtype
@@ -117,6 +121,9 @@ class IterativeGmres(nn.Cell):
         self.M = M

     def construct(self, b, x0, tol, atol, restart, maxiter):
+        # Constant tensor which avoids loop unrolling
+        _INT_ZERO = _to_tensor(0)
+
         A = _normalize_matvec(self.A)
         M = _normalize_matvec(self.M)

@@ -269,7 +276,7 @@ def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
         _raise_value_error("solve_method should be in ('incremental' or 'batched'), but got {}."
                            .format(solve_method))
     _, x_norm = _safe_normalize(x)
-    info = mnp.where(mnp.isnan(x_norm), _INT_NEG_ONE, _INT_ZERO)
+    info = mnp.where(mnp.isnan(x_norm), _to_tensor(-1), _to_tensor(0))
     return x, info

@@ -284,6 +291,9 @@ class CG(nn.Cell):
         self.M = M

     def construct(self, b, x0, tol, atol, maxiter):
+        # Constant tensor which avoids loop unrolling
+        _INT_ZERO = _to_tensor(0)
+
         A = _normalize_matvec(self.A)
         M = _normalize_matvec(self.M)

@@ -386,7 +396,8 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
         _raise_value_error(
             'Input x0 and b must have matching shapes: {} vs {}'.format(x0.shape, b.shape))

-    if (F.dtype(b) not in float_types) or (F.dtype(b) != F.dtype(x0)) or (F.dtype(b) != F.dtype(A)):
+    if (F.dtype(b) not in (mstype.float32, mstype.float64)) or (F.dtype(b) != F.dtype(x0)) or (
+            F.dtype(b) != F.dtype(A)):
         _raise_type_error('Input A, x0 and b must have same float types')

     x = CG(A, M)(b, x0, tol, atol, maxiter)
@@ -14,7 +14,7 @@
 # ============================================================================
 """internal utility functions"""
 import numpy as onp
-from .. import nn, ops
+from .. import ops
 from ..numpy import where, zeros_like, dot, greater
 from ..ops import functional as F
 from ..common import Tensor
@@ -74,40 +74,21 @@ def _eps(x):
     return _eps_net(x[(0,) * x.ndim])


-class _SafeNormalize(nn.Cell):
+def _safe_normalize(x, threshold=None):
     """Normalize method that cast very small results to zero."""
-
-    def __init__(self):
-        """Initialize LineSearch."""
-        super(_SafeNormalize, self).__init__()
-
-    def construct(self, x, threshold=None):
-        x_sum2 = F.reduce_sum(F.pows(x, 2.0))
-        norm = F.pows(x_sum2, 1. / 2.0)
-        if threshold is None:
-            if x.dtype in mstype.float_type:
-                # pick the first element of x to get the eps
-                threshold = _eps(x)
-            else:
-                threshold = 0
-        use_norm = greater(norm, threshold)
-        x_norm = x / norm
-        normalized_x = where(use_norm, x_norm, zeros_like(x))
-        norm = where(use_norm, norm, zeros_like(norm))
-        return normalized_x, norm
-
-
-_safe_normalize = _SafeNormalize()
-
-_INT_ZERO = _to_tensor(0)
-_INT_ONE = _to_tensor(1)
-_INT_NEG_ONE = _to_tensor(-1)
-_FLOAT_ONE = _to_tensor(1.0)
-_FLOAT_TWO = _to_tensor(2.0, dtype=float)
-_BOOL_TRUE = _to_tensor(True)
-_BOOL_FALSE = _to_tensor(False)
-
-float_types = (mstype.float32, mstype.float64)
+    x_sum2 = F.reduce_sum(F.pows(x, 2.0))
+    norm = F.pows(x_sum2, 1. / 2.0)
+    if threshold is None:
+        if x.dtype in (mstype.float32, mstype.float64):
+            # pick the first element of x to get the eps
+            threshold = _eps(x)
+        else:
+            threshold = 0
+    use_norm = greater(norm, threshold)
+    x_norm = x / norm
+    normalized_x = where(use_norm, x_norm, zeros_like(x))
+    norm = where(use_norm, norm, zeros_like(norm))
+    return normalized_x, norm


 @constexpr
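Note: a hedged usage sketch of the rewritten helper above (the import path follows the mindspore.scipy.utils module this commit edits; the tensor values are illustrative):

from mindspore import Tensor
from mindspore.scipy.utils import _safe_normalize  # private helper rewritten above

v = Tensor([3.0, 4.0])
unit, norm = _safe_normalize(v)  # unit ~= [0.6, 0.8], norm ~= 5.0

tiny = Tensor([1e-30, 0.0])
# The norm falls below the eps threshold, so both outputs are zeroed
# instead of producing inf/NaN from the division.
zeros, zero_norm = _safe_normalize(tiny)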