Remove global constant tensors, float_types, and _SafeNormalize from the mindspore.scipy module.

hezhenhao1 2022-02-16 14:24:56 +08:00
parent c8c29da927
commit 5b182d6b14
5 changed files with 48 additions and 47 deletions
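The five diffs below share one refactoring pattern: tensor constants such as _INT_ZERO or _BOOL_FALSE are no longer created once at import time in the shared utils module, but are rebuilt with _to_tensor inside each construct. A minimal sketch of why that matters, assuming graph mode, with CountDown and the _to_tensor body as illustrative stand-ins rather than the repository's exact code: a while-loop whose condition compares tensors stays a dynamic graph loop instead of being unrolled against Python scalars.

    from mindspore import nn, Tensor

    def _to_tensor(arg, dtype=None):
        # Assumed helper mirroring mindspore.scipy's _to_tensor utility.
        return Tensor(arg, dtype) if dtype is not None else Tensor(arg)

    class CountDown(nn.Cell):
        def construct(self, n):
            # Constant tensors built per call, not at module import time.
            _INT_ZERO = _to_tensor(0)
            _INT_ONE = _to_tensor(1)
            k = n
            while k > _INT_ZERO:  # tensor comparison keeps the loop dynamic
                k = k - _INT_ONE
            return k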

View File

@@ -24,7 +24,6 @@ from .ops import EighNet
 from ..ops import operations as P
 from ..ops import functional as F
 from ..common import dtype as mstype
-from .utils import float_types
 from .utils_const import _raise_value_error
 
 __all__ = ['block_diag', 'inv', 'eigh', 'lu_factor', 'lu']
@@ -195,7 +194,7 @@ def inv(a, overwrite_a=False, check_finite=True):
         [[ 1.00000000e+00, 0.00000000e+00],
         [ 8.88178420e-16, 1.00000000e+00]])
     """
-    if F.dtype(a) not in float_types:
+    if F.dtype(a) not in (mstype.float32, mstype.float64):
         a = F.cast(a, mstype.float32)
     matrix_inverse = P.MatrixInverse(adjoint=False)
@@ -249,7 +248,7 @@ def cho_factor(a, lower=False, overwrite_a=False, check_finite=True):
         [ 1.00000000e+00, 5.00000000e+00, 2.29330778e+00, 8.55952621e-01],
         [ 5.00000000e+00, 1.00000000e+00, 2.00000000e+00, 1.55418575e+00]])
     """
-    if F.dtype(a) not in float_types:
+    if F.dtype(a) not in (mstype.float32, mstype.float64):
         a = F.cast(a, mstype.float32)
     cholesky_net = Cholesky(lower=lower, clean=False)
     c = cholesky_net(a)
@@ -292,7 +291,7 @@ def cholesky(a, lower=False, overwrite_a=False, check_finite=True):
         [[ 1.00000000e+00, 0.00000000e+00],
         [ 2.00000000e+00, 1.00000000e+00]])
     """
-    if F.dtype(a) not in float_types:
+    if F.dtype(a) not in (mstype.float32, mstype.float64):
         a = F.cast(a, mstype.float32)
     cholesky_net = Cholesky(lower=lower, clean=True)
     c = cholesky_net(a)
@@ -516,7 +515,7 @@ def lu_factor(a, overwrite_a=False, check_finite=True):
         >>> piv
         Tensor(shape=[4], dtype=Int32, value= [2, 2, 3, 3])
     """
-    if F.dtype(a) not in float_types:
+    if F.dtype(a) not in (mstype.float32, mstype.float64):
         a = F.cast(a, mstype.float32)
     if len(a.shape) < 2 or (a.shape[-1] != a.shape[-2]):
         _raise_value_error("input of lu matrix must be square.")
@@ -586,7 +585,7 @@ def lu(a, permute_l=False, overwrite_a=False, check_finite=True):
         [ 0.00000000e+00, 0.00000000e+00, -1.03999996e+00, 3.07999992e+00],
         [ 0.00000000e+00, -0.00000000e+00, -0.00000000e+00, 7.46153831e+00]])
     """
-    if F.dtype(a) not in float_types:
+    if F.dtype(a) not in (mstype.float32, mstype.float64):
         a = F.cast(a, mstype.float32)
     msp_lu = LU()
     m_lu, _, p = msp_lu(a)

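The dtype hunks in this first file all make the same substitution: the shared float_types tuple is replaced by an inline dtype check. A hedged sketch of that guard on its own (the _promote_to_float name is illustrative, not part of the module):

    from mindspore import Tensor
    from mindspore.common import dtype as mstype
    from mindspore.ops import functional as F

    def _promote_to_float(a):
        # Anything other than float32/float64 (e.g. an integer tensor)
        # is cast to float32 before the linear-algebra kernel runs.
        if F.dtype(a) not in (mstype.float32, mstype.float64):
            a = F.cast(a, mstype.float32)
        return a

    x = Tensor([[2, 0], [0, 4]])       # integer input
    print(_promote_to_float(x).dtype)  # Float32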
View File

@@ -20,8 +20,7 @@ from ... import numpy as mnp
 from ...common import Tensor
 from .line_search import LineSearch
-from ..utils import _to_scalar, grad
-from ..utils import _INT_ZERO, _INT_ONE, _BOOL_FALSE
+from ..utils import _to_scalar, _to_tensor, grad
 
 
 class _BFGSResults(NamedTuple):
@@ -74,6 +73,11 @@ class MinimizeBfgs(nn.Cell):
         self.line_search = LineSearch(func)
 
     def construct(self, x0, maxiter=None, norm=mnp.inf, gtol=1e-5, line_search_maxiter=10):
+        # Constant tensors which avoid loop unrolling
+        _BOOL_FALSE = _to_tensor(False)
+        _INT_ZERO = _to_tensor(0)
+        _INT_ONE = _to_tensor(1)
+
         def _my_norm(x, ord_=None):
             if ord_ == mnp.inf:
                 res = mnp.max(mnp.abs(x))

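The last hunk above is cut off inside the _my_norm helper. For reference, a self-contained version with the same branching (the non-inf branch is my assumption of a plain 2-norm, not taken from the diff):

    import mindspore.numpy as mnp
    from mindspore import Tensor

    def _my_norm(x, ord_=None):
        # Infinity norm is max(|x_i|); otherwise fall back to the 2-norm.
        if ord_ == mnp.inf:
            return mnp.max(mnp.abs(x))
        return mnp.sqrt(mnp.sum(x ** 2))

    print(_my_norm(Tensor([3.0, -4.0])))           # 5.0
    print(_my_norm(Tensor([3.0, -4.0]), mnp.inf))  # 4.0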
View File

@@ -20,8 +20,7 @@ from ... import numpy as mnp
 from ...common import dtype as mstype
 from ...common import Tensor
-from ..utils import _to_scalar, grad
-from ..utils import _to_tensor, _INT_ZERO, _INT_ONE, _BOOL_FALSE
+from ..utils import _to_scalar, _to_tensor, grad
 
 
 class _LineSearchResults(NamedTuple):
@@ -92,7 +91,10 @@ def _zoom(fn, a_low, phi_low, dphi_low, a_high, phi_high, dphi_high, phi_0, g_0,
     Algorithm 3.6 from Wright and Nocedal, 'Numerical Optimization', 1999, pg. 59-61.
     Tries cubic, quadratic, and bisection methods of zooming.
     """
+    # Constant tensors which avoid loop unrolling
+    _FLOAT_ONE = _to_tensor(1., dtype=a_low.dtype)
+    _BOOL_FALSE = _to_tensor(False)
+    _INT_ZERO = _to_tensor(0)
     state = {
         "done": _BOOL_FALSE,
         "failed": _BOOL_FALSE,
@@ -200,8 +202,12 @@ class LineSearch(nn.Cell):
             gkk = grad(self.func)(xkk)
             return fkk, gkk, mnp.dot(gkk, pk)
 
+        # Constant tensors which avoid loop unrolling
+        _FLOAT_ZERO = _to_tensor(0., dtype=xk.dtype)
+        _FLOAT_ONE = _to_tensor(1., dtype=xk.dtype)
+        _BOOL_FALSE = _to_tensor(False)
+        _INT_ZERO = _to_tensor(0)
+        _INT_ONE = _to_tensor(1)
+
         if old_fval is None or gfk is None:
             nfev, ngev = _INT_ONE, _INT_ONE

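Unlike the old module-level globals, the float constants above inherit their dtype from the incoming point (xk, a_low), so float32 and float64 problems both run without implicit promotion. A short illustration, with the _to_tensor body again assumed to mirror the utils helper:

    from mindspore import Tensor
    from mindspore.common import dtype as mstype

    def _to_tensor(arg, dtype=None):
        # Assumed helper mirroring mindspore.scipy's _to_tensor utility.
        return Tensor(arg, dtype) if dtype is not None else Tensor(arg)

    xk = Tensor([1.0, 2.0], mstype.float64)
    _FLOAT_ONE = _to_tensor(1., dtype=xk.dtype)
    print(_FLOAT_ONE.dtype)  # Float64, matching xk; the old global was float32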
View File

@@ -16,9 +16,10 @@
 from ... import nn, ms_function
 from ... import numpy as mnp
 from ...ops import functional as F
+from ...common import dtype as mstype
 from ..linalg import solve_triangular
 from ..linalg import cho_factor, cho_solve
-from ..utils import _INT_ZERO, _INT_NEG_ONE, _normalize_matvec, _to_tensor, _safe_normalize, _eps, float_types
+from ..utils import _normalize_matvec, _to_tensor, _safe_normalize, _eps
 from ..utils_const import _raise_value_error, _raise_type_error
@@ -71,6 +72,9 @@ class BatchedGmres(nn.Cell):
         self.M = M
 
     def construct(self, b, x0=None, tol=1e-5, atol=0.0, restart=20, maxiter=None):
+        # Constant tensor which avoids loop unrolling
+        _INT_ZERO = _to_tensor(0)
+
         A = _normalize_matvec(self.A)
         M = _normalize_matvec(self.M)
         dtype = b.dtype
@@ -117,6 +121,9 @@ class IterativeGmres(nn.Cell):
         self.M = M
 
     def construct(self, b, x0, tol, atol, restart, maxiter):
+        # Constant tensor which avoids loop unrolling
+        _INT_ZERO = _to_tensor(0)
+
         A = _normalize_matvec(self.A)
         M = _normalize_matvec(self.M)
@@ -269,7 +276,7 @@ def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
         _raise_value_error("solve_method should be in ('incremental' or 'batched'), but got {}."
                            .format(solve_method))
     _, x_norm = _safe_normalize(x)
-    info = mnp.where(mnp.isnan(x_norm), _INT_NEG_ONE, _INT_ZERO)
+    info = mnp.where(mnp.isnan(x_norm), _to_tensor(-1), _to_tensor(0))
     return x, info
@@ -284,6 +291,9 @@ class CG(nn.Cell):
         self.M = M
 
     def construct(self, b, x0, tol, atol, maxiter):
+        # Constant tensor which avoids loop unrolling
+        _INT_ZERO = _to_tensor(0)
+
         A = _normalize_matvec(self.A)
         M = _normalize_matvec(self.M)
@@ -386,7 +396,8 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
         _raise_value_error(
             'Input x0 and b must have matching shapes: {} vs {}'.format(x0.shape, b.shape))
-    if (F.dtype(b) not in float_types) or (F.dtype(b) != F.dtype(x0)) or (F.dtype(b) != F.dtype(A)):
+    if (F.dtype(b) not in (mstype.float32, mstype.float64)) or (F.dtype(b) != F.dtype(x0)) or (
+            F.dtype(b) != F.dtype(A)):
         _raise_type_error('Input A, x0 and b must have same float types')
     x = CG(A, M)(b, x0, tol, atol, maxiter)

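The gmres hunk above inlines the -1/0 status tensors that used to be the _INT_NEG_ONE and _INT_ZERO globals; the flag keeps SciPy's convention of 0 for success and a nonzero value for breakdown. A small sketch of the same NaN check in isolation:

    import mindspore.numpy as mnp
    from mindspore import Tensor

    x_norm = Tensor(float('nan'))
    info = mnp.where(mnp.isnan(x_norm), Tensor(-1), Tensor(0))
    print(int(info))  # -1: the solve produced NaNs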
View File

@@ -14,7 +14,7 @@
 # ============================================================================
 """internal utility functions"""
 import numpy as onp
-from .. import nn, ops
+from .. import ops
 from ..numpy import where, zeros_like, dot, greater
 from ..ops import functional as F
 from ..common import Tensor
@@ -74,40 +74,21 @@ def _eps(x):
     return _eps_net(x[(0,) * x.ndim])
 
 
-class _SafeNormalize(nn.Cell):
+def _safe_normalize(x, threshold=None):
     """Normalize method that cast very small results to zero."""
-
-    def __init__(self):
-        """Initialize LineSearch."""
-        super(_SafeNormalize, self).__init__()
-
-    def construct(self, x, threshold=None):
-        x_sum2 = F.reduce_sum(F.pows(x, 2.0))
-        norm = F.pows(x_sum2, 1. / 2.0)
-        if threshold is None:
-            if x.dtype in mstype.float_type:
-                # pick the first element of x to get the eps
-                threshold = _eps(x)
-            else:
-                threshold = 0
-        use_norm = greater(norm, threshold)
-        x_norm = x / norm
-        normalized_x = where(use_norm, x_norm, zeros_like(x))
-        norm = where(use_norm, norm, zeros_like(norm))
-        return normalized_x, norm
-
-
-_safe_normalize = _SafeNormalize()
-
-_INT_ZERO = _to_tensor(0)
-_INT_ONE = _to_tensor(1)
-_INT_NEG_ONE = _to_tensor(-1)
-_FLOAT_ONE = _to_tensor(1.0)
-_FLOAT_TWO = _to_tensor(2.0, dtype=float)
-_BOOL_TRUE = _to_tensor(True)
-_BOOL_FALSE = _to_tensor(False)
-
-float_types = (mstype.float32, mstype.float64)
+    x_sum2 = F.reduce_sum(F.pows(x, 2.0))
+    norm = F.pows(x_sum2, 1. / 2.0)
+    if threshold is None:
+        if x.dtype in (mstype.float32, mstype.float64):
+            # pick the first element of x to get the eps
+            threshold = _eps(x)
+        else:
+            threshold = 0
+    use_norm = greater(norm, threshold)
+    x_norm = x / norm
+    normalized_x = where(use_norm, x_norm, zeros_like(x))
+    norm = where(use_norm, norm, zeros_like(norm))
+    return normalized_x, norm
 
 
 @constexpr
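After this refactor _safe_normalize is a plain function rather than an nn.Cell that callers instantiate. Its semantics restated with plain NumPy as a sketch (not the MindSpore implementation): a vector whose norm does not exceed the threshold, by default the dtype's machine eps, normalizes to zeros instead of dividing by a value near zero.

    import numpy as onp

    def safe_normalize(x, threshold=None):
        # Euclidean norm via the sum of squares, as in the diff above.
        norm = onp.sqrt(onp.sum(x ** 2))
        if threshold is None:
            # Float inputs use machine eps as the cutoff, like _eps(x).
            threshold = onp.finfo(x.dtype).eps if onp.issubdtype(x.dtype, onp.floating) else 0
        if norm > threshold:
            return x / norm, norm
        return onp.zeros_like(x), onp.zeros_like(norm)

    print(safe_normalize(onp.array([3.0, 4.0])))    # ([0.6, 0.8], 5.0)
    print(safe_normalize(onp.array([0.0, 1e-20])))  # ([0., 0.], 0.0)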