fix cpu graph mode

This commit is contained in:
zhujingxuan 2021-11-23 15:15:02 +08:00
parent bbe9606f1c
commit 142480b08e
2 changed files with 23 additions and 50 deletions

View File

@ -18,7 +18,7 @@ from ... import numpy as mnp
from ...ops import functional as F from ...ops import functional as F
from ..linalg import solve_triangular from ..linalg import solve_triangular
from ..linalg import cho_factor, cho_solve from ..linalg import cho_factor, cho_solve
from ..utils import _INT_ZERO, _INT_ONE, _INT_NEG_ONE, _normalize_matvec, _to_tensor, _safe_normalize, _eps from ..utils import _INT_ZERO, _INT_NEG_ONE, _normalize_matvec, _to_tensor, _safe_normalize, _eps
def gram_schmidt(Q, q): def gram_schmidt(Q, q):
@ -57,48 +57,6 @@ def rotate_vectors(H, i, cs, sn):
return H return H
class GivensRotation(nn.Cell):
"""Replay stored Givens rotations on one row, then compute and apply the rotation for step k."""
def __init__(self):
super(GivensRotation, self).__init__()
# Tensor constants used by construct for the zero test and the identity rotation.
self.tensor_0 = Tensor(0.0)
self.tensor_1 = Tensor(1.0)
def construct(self, H_row, givens, k):
"""
Apply each of the Givens rotations stored in givens[:k, :] to H_row, then
compute the kth rotation, store it in givens[k, :] and apply it to H_row as well.
Args:
H_row (Tensor): The kth row in (n, n+1) Matrix H
givens (Tensor): a (n, 2) Matrix which stores cs, sn for Givens Rotation
k (Tensor): the row number, must be smaller than n
Returns:
R_row (Tensor): Rotated Vector from H_row
givens (Tensor): a (n, 2) Matrix, with row k set to the kth cs, sn values
"""
# Replay the rotations recorded for steps 0 .. k-1 on this row.
i = _INT_ZERO
while i < k:
H_row = rotate_vectors(H_row, i, givens[i, 0], givens[i, 1])
i = i + _INT_ONE
# Entry k+1 is exactly zero: store the identity rotation (cs=1, sn=0).
if H_row[k + 1] == self.tensor_0:
givens[k, 0] = self.tensor_1
givens[k, 1] = self.tensor_0
else:
# Divide by the larger-magnitude entry of the pair so that |t| <= 1.
# NOTE(review): both quotients are passed to mnp.where; the unselected one
# divides by H_row[k], which can be zero when `increase` is true -- confirm
# this is harmless on every backend.
increase = mnp.absolute(H_row[k]) < mnp.absolute(H_row[k + 1])
t = mnp.where(increase, -H_row[k] /
H_row[k + 1], -H_row[k + 1] / H_row[k])
# r = 1 / sqrt(1 + |t|^2) normalises the (cs, sn) pair.
r = 1 / F.sqrt(1 + mnp.absolute(t) ** 2)
givens[k, 0] = mnp.where(increase, r * t, r)
givens[k, 1] = mnp.where(increase, r, r * t)
# Apply the new rotation to the row; presumably this zeroes entry k+1,
# depending on rotate_vectors (body not visible here) -- TODO confirm.
R_row = rotate_vectors(H_row, k, givens[k, 0], givens[k, 1])
return R_row, givens
class BatchedGmres(nn.Cell): class BatchedGmres(nn.Cell):
""" """
Implements a single restart of GMRES. The ``restart``-dimensional Krylov subspace Implements a single restart of GMRES. The ``restart``-dimensional Krylov subspace
@ -156,7 +114,6 @@ class IterativeGmres(nn.Cell):
super(IterativeGmres, self).__init__() super(IterativeGmres, self).__init__()
self.A = A self.A = A
self.M = M self.M = M
self.givens_rotation = GivensRotation()
def construct(self, b, x0, tol, atol, restart, maxiter): def construct(self, b, x0, tol, atol, restart, maxiter):
A = _normalize_matvec(self.A) A = _normalize_matvec(self.A)
@ -186,15 +143,30 @@ class IterativeGmres(nn.Cell):
k = _INT_ZERO k = _INT_ZERO
err = r_norm err = r_norm
while mnp.logical_and(mnp.less(k, restart), mnp.less(ptol, err)): while mnp.logical_and(mnp.less(k, restart), mnp.less(ptol, err)):
V, H, _ = arnoldi_iteration(k, A, M, V, R) V, R, _ = arnoldi_iteration(k, A, M, V, R)
R[k, :], givens = self.givens_rotation(H[k, :], givens, k) # givens rotation
beta_vec = rotate_vectors( row_k = R[k, :].copy()
beta_vec, k, givens[k, 0], givens[k, 1]) i = _INT_ZERO
while i < k:
row_k = rotate_vectors(row_k, i, givens[i, 0], givens[i, 1])
i += 1
if row_k[k + 1] == 0:
givens[k, 0] = 1
givens[k, 1] = 0
else:
increase = mnp.absolute(row_k[k]) < mnp.absolute(row_k[k + 1])
t = mnp.where(increase, -row_k[k] / row_k[k + 1], -row_k[k + 1] / row_k[k])
r = 1 / F.sqrt(1 + mnp.absolute(t) ** 2)
givens[k, 0] = mnp.where(increase, r * t, r)
givens[k, 1] = mnp.where(increase, r, r * t)
R[k, :] = rotate_vectors(row_k, k, givens[k, 0], givens[k, 1])
beta_vec = rotate_vectors(beta_vec, k, givens[k, 0], givens[k, 1])
err = mnp.absolute(beta_vec[k + 1]) err = mnp.absolute(beta_vec[k + 1])
k += 1 k += 1
y = solve_triangular( y = solve_triangular(R[:, :-1], beta_vec[:-1], trans='T', lower=True)
R[:, :-1], beta_vec[:-1], trans='T', lower=True)
dx = mnp.dot(V[:, :-1], y) dx = mnp.dot(V[:, :-1], y)
x = x0 + dx x = x0 + dx

View File

@ -139,6 +139,7 @@ def test_gmres_incremental_against_scipy(n, dtype, preconditioner):
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
@pytest.mark.parametrize('n', [5]) @pytest.mark.parametrize('n', [5])