Add debug message for BFGS method.

This commit is contained in:
hezhenhao1 2021-11-19 15:12:06 +08:00
parent d4e77ed507
commit a527f43699
2 changed files with 5 additions and 6 deletions

View File

@ -16,7 +16,6 @@
from typing import Optional
from typing import NamedTuple
from ... import numpy as mnp
from ...common import Tensor
from ._bfgs import minimize_bfgs
@ -119,7 +118,7 @@ def minimize(func, x0, args=(), *, method, tol=None, options=None) -> OptimizeRe
if method.lower() == 'bfgs':
results = minimize_bfgs(fun_with_args(args), x0, **options)
success = results.converged and mnp.logical_not(results.failed)
success = results.converged and not results.failed
return OptimizeResults(x=results.x_k,
success=success,
status=results.status,
@ -128,6 +127,6 @@ def minimize(func, x0, args=(), *, method, tol=None, options=None) -> OptimizeRe
hess_inv=results.H_k,
nfev=results.nfev,
njev=results.ngev,
nit=results.k)
nit=results.k), results
raise ValueError("Method {} not recognized".format(method))

View File

@ -70,7 +70,7 @@ def test_bfgs(dtype, func_x0):
ms_res = msp.optimize.minimize(func(mnp), x0_tensor, method='BFGS',
options=dict(maxiter=None, gtol=1e-6))
scipy_res = osp.optimize.minimize(func(onp), x0, method='BFGS')
match_array(ms_res.x.asnumpy(), scipy_res.x, error=5, err_msg=str(ms_res))
match_array(ms_res[0].x.asnumpy(), scipy_res.x, error=5, err_msg=str(ms_res[1]))
@pytest.mark.level0
@ -91,7 +91,7 @@ def test_bfgs_fixes4594(dtype):
return mnp.mean((mnp.dot(A, x)) ** 2)
results = msp.optimize.minimize(func, Tensor(onp.ones(n, dtype=dtype)), method='BFGS',
options=dict(maxiter=None, gtol=1e-6)).x
options=dict(maxiter=None, gtol=1e-6))[0].x
onp.testing.assert_allclose(results.asnumpy(), onp.zeros(n, dtype=dtype), rtol=1e-6, atol=1e-6)
@ -113,7 +113,7 @@ def test_bfgs_graph(dtype, func_x0):
x0 = x0.astype(dtype)
x0_tensor = Tensor(x0)
ms_res = msp.optimize.minimize(func(mnp), x0_tensor, method='BFGS',
options=dict(maxiter=None, gtol=1e-6)).x
options=dict(maxiter=None, gtol=1e-6))[0].x
scipy_res = osp.optimize.minimize(func(onp), x0, method='BFGS').x
match_array(ms_res.asnumpy(), scipy_res, error=5)