forked from mindspore-Ecosystem/mindspore
fix cpu graph mode
This commit is contained in:
parent
bbe9606f1c
commit
142480b08e
|
@ -18,7 +18,7 @@ from ... import numpy as mnp
|
||||||
from ...ops import functional as F
|
from ...ops import functional as F
|
||||||
from ..linalg import solve_triangular
|
from ..linalg import solve_triangular
|
||||||
from ..linalg import cho_factor, cho_solve
|
from ..linalg import cho_factor, cho_solve
|
||||||
from ..utils import _INT_ZERO, _INT_ONE, _INT_NEG_ONE, _normalize_matvec, _to_tensor, _safe_normalize, _eps
|
from ..utils import _INT_ZERO, _INT_NEG_ONE, _normalize_matvec, _to_tensor, _safe_normalize, _eps
|
||||||
|
|
||||||
|
|
||||||
def gram_schmidt(Q, q):
|
def gram_schmidt(Q, q):
|
||||||
|
@ -57,48 +57,6 @@ def rotate_vectors(H, i, cs, sn):
|
||||||
return H
|
return H
|
||||||
|
|
||||||
|
|
||||||
class GivensRotation(nn.Cell):
|
|
||||||
""" do the Givens Rotation"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super(GivensRotation, self).__init__()
|
|
||||||
self.tensor_0 = Tensor(0.0)
|
|
||||||
self.tensor_1 = Tensor(1.0)
|
|
||||||
|
|
||||||
def construct(self, H_row, givens, k):
|
|
||||||
"""
|
|
||||||
Appliy each of the Givens rotations stored in givens[:, :k] to H_row.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
H_row (Tensor): The kth row in (n, n+1) Matrix H
|
|
||||||
givens (Tensor): a (n, 2) Matrix which stores cs, sn for Givens Rotation
|
|
||||||
k (Tensor): the row number, must smaller than n
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
R_row (Tensor): Rotated Vector from H_row
|
|
||||||
givens (Tensor): a (n, 2) Matrix which stores the kth cs, sn values
|
|
||||||
"""
|
|
||||||
i = _INT_ZERO
|
|
||||||
|
|
||||||
while i < k:
|
|
||||||
H_row = rotate_vectors(H_row, i, givens[i, 0], givens[i, 1])
|
|
||||||
i = i + _INT_ONE
|
|
||||||
|
|
||||||
if H_row[k + 1] == self.tensor_0:
|
|
||||||
givens[k, 0] = self.tensor_1
|
|
||||||
givens[k, 1] = self.tensor_0
|
|
||||||
else:
|
|
||||||
increase = mnp.absolute(H_row[k]) < mnp.absolute(H_row[k + 1])
|
|
||||||
t = mnp.where(increase, -H_row[k] /
|
|
||||||
H_row[k + 1], -H_row[k + 1] / H_row[k])
|
|
||||||
r = 1 / F.sqrt(1 + mnp.absolute(t) ** 2)
|
|
||||||
givens[k, 0] = mnp.where(increase, r * t, r)
|
|
||||||
givens[k, 1] = mnp.where(increase, r, r * t)
|
|
||||||
|
|
||||||
R_row = rotate_vectors(H_row, k, givens[k, 0], givens[k, 1])
|
|
||||||
return R_row, givens
|
|
||||||
|
|
||||||
|
|
||||||
class BatchedGmres(nn.Cell):
|
class BatchedGmres(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Implements a single restart of GMRES. The ``restart``-dimensional Krylov subspace
|
Implements a single restart of GMRES. The ``restart``-dimensional Krylov subspace
|
||||||
|
@ -156,7 +114,6 @@ class IterativeGmres(nn.Cell):
|
||||||
super(IterativeGmres, self).__init__()
|
super(IterativeGmres, self).__init__()
|
||||||
self.A = A
|
self.A = A
|
||||||
self.M = M
|
self.M = M
|
||||||
self.givens_rotation = GivensRotation()
|
|
||||||
|
|
||||||
def construct(self, b, x0, tol, atol, restart, maxiter):
|
def construct(self, b, x0, tol, atol, restart, maxiter):
|
||||||
A = _normalize_matvec(self.A)
|
A = _normalize_matvec(self.A)
|
||||||
|
@ -186,15 +143,30 @@ class IterativeGmres(nn.Cell):
|
||||||
k = _INT_ZERO
|
k = _INT_ZERO
|
||||||
err = r_norm
|
err = r_norm
|
||||||
while mnp.logical_and(mnp.less(k, restart), mnp.less(ptol, err)):
|
while mnp.logical_and(mnp.less(k, restart), mnp.less(ptol, err)):
|
||||||
V, H, _ = arnoldi_iteration(k, A, M, V, R)
|
V, R, _ = arnoldi_iteration(k, A, M, V, R)
|
||||||
R[k, :], givens = self.givens_rotation(H[k, :], givens, k)
|
# givens rotation
|
||||||
beta_vec = rotate_vectors(
|
row_k = R[k, :].copy()
|
||||||
beta_vec, k, givens[k, 0], givens[k, 1])
|
i = _INT_ZERO
|
||||||
|
while i < k:
|
||||||
|
row_k = rotate_vectors(row_k, i, givens[i, 0], givens[i, 1])
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
if row_k[k + 1] == 0:
|
||||||
|
givens[k, 0] = 1
|
||||||
|
givens[k, 1] = 0
|
||||||
|
else:
|
||||||
|
increase = mnp.absolute(row_k[k]) < mnp.absolute(row_k[k + 1])
|
||||||
|
t = mnp.where(increase, -row_k[k] / row_k[k + 1], -row_k[k + 1] / row_k[k])
|
||||||
|
r = 1 / F.sqrt(1 + mnp.absolute(t) ** 2)
|
||||||
|
givens[k, 0] = mnp.where(increase, r * t, r)
|
||||||
|
givens[k, 1] = mnp.where(increase, r, r * t)
|
||||||
|
|
||||||
|
R[k, :] = rotate_vectors(row_k, k, givens[k, 0], givens[k, 1])
|
||||||
|
beta_vec = rotate_vectors(beta_vec, k, givens[k, 0], givens[k, 1])
|
||||||
err = mnp.absolute(beta_vec[k + 1])
|
err = mnp.absolute(beta_vec[k + 1])
|
||||||
k += 1
|
k += 1
|
||||||
|
|
||||||
y = solve_triangular(
|
y = solve_triangular(R[:, :-1], beta_vec[:-1], trans='T', lower=True)
|
||||||
R[:, :-1], beta_vec[:-1], trans='T', lower=True)
|
|
||||||
dx = mnp.dot(V[:, :-1], y)
|
dx = mnp.dot(V[:, :-1], y)
|
||||||
|
|
||||||
x = x0 + dx
|
x = x0 + dx
|
||||||
|
|
|
@ -139,6 +139,7 @@ def test_gmres_incremental_against_scipy(n, dtype, preconditioner):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.level0
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
@pytest.mark.platform_x86_gpu_training
|
@pytest.mark.platform_x86_gpu_training
|
||||||
@pytest.mark.env_onecard
|
@pytest.mark.env_onecard
|
||||||
@pytest.mark.parametrize('n', [5])
|
@pytest.mark.parametrize('n', [5])
|
||||||
|
|
Loading…
Reference in New Issue