forked from mindspore-Ecosystem/mindspore
fix cpu graph mode
This commit is contained in:
parent
bbe9606f1c
commit
142480b08e
|
@ -18,7 +18,7 @@ from ... import numpy as mnp
|
|||
from ...ops import functional as F
|
||||
from ..linalg import solve_triangular
|
||||
from ..linalg import cho_factor, cho_solve
|
||||
from ..utils import _INT_ZERO, _INT_ONE, _INT_NEG_ONE, _normalize_matvec, _to_tensor, _safe_normalize, _eps
|
||||
from ..utils import _INT_ZERO, _INT_NEG_ONE, _normalize_matvec, _to_tensor, _safe_normalize, _eps
|
||||
|
||||
|
||||
def gram_schmidt(Q, q):
|
||||
|
@ -57,48 +57,6 @@ def rotate_vectors(H, i, cs, sn):
|
|||
return H
|
||||
|
||||
|
||||
class GivensRotation(nn.Cell):
|
||||
""" do the Givens Rotation"""
|
||||
|
||||
def __init__(self):
|
||||
super(GivensRotation, self).__init__()
|
||||
self.tensor_0 = Tensor(0.0)
|
||||
self.tensor_1 = Tensor(1.0)
|
||||
|
||||
def construct(self, H_row, givens, k):
|
||||
"""
|
||||
Appliy each of the Givens rotations stored in givens[:, :k] to H_row.
|
||||
|
||||
Args:
|
||||
H_row (Tensor): The kth row in (n, n+1) Matrix H
|
||||
givens (Tensor): a (n, 2) Matrix which stores cs, sn for Givens Rotation
|
||||
k (Tensor): the row number, must smaller than n
|
||||
|
||||
Returns:
|
||||
R_row (Tensor): Rotated Vector from H_row
|
||||
givens (Tensor): a (n, 2) Matrix which stores the kth cs, sn values
|
||||
"""
|
||||
i = _INT_ZERO
|
||||
|
||||
while i < k:
|
||||
H_row = rotate_vectors(H_row, i, givens[i, 0], givens[i, 1])
|
||||
i = i + _INT_ONE
|
||||
|
||||
if H_row[k + 1] == self.tensor_0:
|
||||
givens[k, 0] = self.tensor_1
|
||||
givens[k, 1] = self.tensor_0
|
||||
else:
|
||||
increase = mnp.absolute(H_row[k]) < mnp.absolute(H_row[k + 1])
|
||||
t = mnp.where(increase, -H_row[k] /
|
||||
H_row[k + 1], -H_row[k + 1] / H_row[k])
|
||||
r = 1 / F.sqrt(1 + mnp.absolute(t) ** 2)
|
||||
givens[k, 0] = mnp.where(increase, r * t, r)
|
||||
givens[k, 1] = mnp.where(increase, r, r * t)
|
||||
|
||||
R_row = rotate_vectors(H_row, k, givens[k, 0], givens[k, 1])
|
||||
return R_row, givens
|
||||
|
||||
|
||||
class BatchedGmres(nn.Cell):
|
||||
"""
|
||||
Implements a single restart of GMRES. The ``restart``-dimensional Krylov subspace
|
||||
|
@ -156,7 +114,6 @@ class IterativeGmres(nn.Cell):
|
|||
super(IterativeGmres, self).__init__()
|
||||
self.A = A
|
||||
self.M = M
|
||||
self.givens_rotation = GivensRotation()
|
||||
|
||||
def construct(self, b, x0, tol, atol, restart, maxiter):
|
||||
A = _normalize_matvec(self.A)
|
||||
|
@ -186,15 +143,30 @@ class IterativeGmres(nn.Cell):
|
|||
k = _INT_ZERO
|
||||
err = r_norm
|
||||
while mnp.logical_and(mnp.less(k, restart), mnp.less(ptol, err)):
|
||||
V, H, _ = arnoldi_iteration(k, A, M, V, R)
|
||||
R[k, :], givens = self.givens_rotation(H[k, :], givens, k)
|
||||
beta_vec = rotate_vectors(
|
||||
beta_vec, k, givens[k, 0], givens[k, 1])
|
||||
V, R, _ = arnoldi_iteration(k, A, M, V, R)
|
||||
# givens rotation
|
||||
row_k = R[k, :].copy()
|
||||
i = _INT_ZERO
|
||||
while i < k:
|
||||
row_k = rotate_vectors(row_k, i, givens[i, 0], givens[i, 1])
|
||||
i += 1
|
||||
|
||||
if row_k[k + 1] == 0:
|
||||
givens[k, 0] = 1
|
||||
givens[k, 1] = 0
|
||||
else:
|
||||
increase = mnp.absolute(row_k[k]) < mnp.absolute(row_k[k + 1])
|
||||
t = mnp.where(increase, -row_k[k] / row_k[k + 1], -row_k[k + 1] / row_k[k])
|
||||
r = 1 / F.sqrt(1 + mnp.absolute(t) ** 2)
|
||||
givens[k, 0] = mnp.where(increase, r * t, r)
|
||||
givens[k, 1] = mnp.where(increase, r, r * t)
|
||||
|
||||
R[k, :] = rotate_vectors(row_k, k, givens[k, 0], givens[k, 1])
|
||||
beta_vec = rotate_vectors(beta_vec, k, givens[k, 0], givens[k, 1])
|
||||
err = mnp.absolute(beta_vec[k + 1])
|
||||
k += 1
|
||||
|
||||
y = solve_triangular(
|
||||
R[:, :-1], beta_vec[:-1], trans='T', lower=True)
|
||||
y = solve_triangular(R[:, :-1], beta_vec[:-1], trans='T', lower=True)
|
||||
dx = mnp.dot(V[:, :-1], y)
|
||||
|
||||
x = x0 + dx
|
||||
|
|
|
@ -139,6 +139,7 @@ def test_gmres_incremental_against_scipy(n, dtype, preconditioner):
|
|||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('n', [5])
|
||||
|
|
Loading…
Reference in New Issue