forked from mindspore-Ecosystem/mindspore
Add BFGS, line_search and block_diag algorithms in scipy.
This commit is contained in:
parent
3452d6e55d
commit
dd829f53f9
|
@ -278,6 +278,7 @@ install(
|
||||||
${CMAKE_SOURCE_DIR}/mindspore/parallel
|
${CMAKE_SOURCE_DIR}/mindspore/parallel
|
||||||
${CMAKE_SOURCE_DIR}/mindspore/mindrecord
|
${CMAKE_SOURCE_DIR}/mindspore/mindrecord
|
||||||
${CMAKE_SOURCE_DIR}/mindspore/numpy
|
${CMAKE_SOURCE_DIR}/mindspore/numpy
|
||||||
|
${CMAKE_SOURCE_DIR}/mindspore/scipy
|
||||||
${CMAKE_SOURCE_DIR}/mindspore/train
|
${CMAKE_SOURCE_DIR}/mindspore/train
|
||||||
${CMAKE_SOURCE_DIR}/mindspore/boost
|
${CMAKE_SOURCE_DIR}/mindspore/boost
|
||||||
${CMAKE_SOURCE_DIR}/mindspore/common
|
${CMAKE_SOURCE_DIR}/mindspore/common
|
||||||
|
|
|
@ -13,3 +13,13 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""Scipy-like interfaces in mindspore."""
|
"""Scipy-like interfaces in mindspore."""
|
||||||
|
|
||||||
|
from . import optimize, linalg
|
||||||
|
from .optimize import *
|
||||||
|
from .linalg import *
|
||||||
|
|
||||||
|
__all__ = []
|
||||||
|
__all__.extend(optimize.__all__)
|
||||||
|
__all__.extend(linalg.__all__)
|
||||||
|
|
||||||
|
__all__.sort()
|
||||||
|
|
|
@ -0,0 +1,84 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Linear algebra submodule"""
|
||||||
|
from .. import numpy as mnp
|
||||||
|
from .. import ops
|
||||||
|
|
||||||
|
__all__ = ['block_diag']
|
||||||
|
|
||||||
|
|
||||||
|
def block_diag(*arrs):
|
||||||
|
"""
|
||||||
|
Create a block diagonal matrix from provided arrays.
|
||||||
|
|
||||||
|
Given the inputs `A`, `B` and `C`, the output will have these
|
||||||
|
Tensor arranged on the diagonal::
|
||||||
|
|
||||||
|
[[A, 0, 0],
|
||||||
|
[0, B, 0],
|
||||||
|
[0, 0, C]]
|
||||||
|
|
||||||
|
Args:
|
||||||
|
A, B, C, ... (Tensor): up to 2-D
|
||||||
|
Input Tensors. A 1-D Tensor or a 2-D Tensor with shape ``(1,n)``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
D (Tesnor): Tensor with `A`, `B`, `C`, ... on the diagonal. `D` has
|
||||||
|
the same dtype as `A`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If there are tensors with dimensions higher than 2 in all arguments.
|
||||||
|
|
||||||
|
Supported Platforms:
|
||||||
|
``CPU`` ``GPU``
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> import numpy as onp
|
||||||
|
>>> from mindspore.common import Tensor
|
||||||
|
>>> from mindspore.scipy.linalg import block_diag
|
||||||
|
>>> A = Tensor(onp.array([[1, 0], [0, 1]]))
|
||||||
|
>>> B = Tensor(onp.array([[3, 4, 5], [6, 7, 8]]))
|
||||||
|
>>> C = Tensor(onp.array([[7]]))
|
||||||
|
>>> P = Tensor(onp.zeros((2, ), dtype='int32'))
|
||||||
|
>>> block_diag(A, B, C)
|
||||||
|
[[1, 0, 0, 0, 0, 0],
|
||||||
|
[0, 1, 0, 0, 0, 0],
|
||||||
|
[0, 0, 3, 4, 5, 0],
|
||||||
|
[0, 0, 6, 7, 8, 0],
|
||||||
|
[0, 0, 0, 0, 0, 7]]
|
||||||
|
>>> block_diag(A, P, B, C)
|
||||||
|
[[1, 0, 0, 0, 0, 0],
|
||||||
|
[0, 1, 0, 0, 0, 0],
|
||||||
|
[0, 0, 0, 0, 0, 0],
|
||||||
|
[0, 0, 0, 0, 0, 0],
|
||||||
|
[0, 0, 3, 4, 5, 0],
|
||||||
|
[0, 0, 6, 7, 8, 0],
|
||||||
|
[0, 0, 0, 0, 0, 7]]
|
||||||
|
"""
|
||||||
|
if not arrs:
|
||||||
|
return mnp.zeros((1, 0))
|
||||||
|
bad_shapes = [i for i, a in enumerate(arrs) if a.ndim > 2]
|
||||||
|
if bad_shapes:
|
||||||
|
raise ValueError("Arguments to mindspore.scipy.linalg.block_diag must have at "
|
||||||
|
"most 2 dimensions, got {} at argument {}."
|
||||||
|
.format(arrs[bad_shapes[0]], bad_shapes[0]))
|
||||||
|
arrs = [mnp.atleast_2d(a) for a in arrs]
|
||||||
|
accum = arrs[0]
|
||||||
|
for arr in arrs[1:]:
|
||||||
|
_, c = arr.shape
|
||||||
|
arr = ops.Pad(((0, 0), (accum.shape[-1], 0)))(arr)
|
||||||
|
accum = ops.Pad(((0, 0), (0, c)))(accum)
|
||||||
|
accum = mnp.concatenate([accum, arr], axis=0)
|
||||||
|
return accum
|
|
@ -0,0 +1,19 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Optimize submodule"""
|
||||||
|
from .minimize import minimize
|
||||||
|
from .line_search import line_search
|
||||||
|
|
||||||
|
__all__ = ["minimize", "line_search"]
|
|
@ -0,0 +1,228 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""BFGS"""
|
||||||
|
from typing import NamedTuple
|
||||||
|
|
||||||
|
from ... import nn
|
||||||
|
from ... import numpy as mnp
|
||||||
|
from ...common import Tensor
|
||||||
|
from ...ops import functional as F
|
||||||
|
|
||||||
|
from .line_search import LineSearch
|
||||||
|
from ..utils import _to_scalar
|
||||||
|
from ..utils import _INT_ZERO, _INT_ONE, _BOOL_FALSE
|
||||||
|
|
||||||
|
|
||||||
|
class _BFGSResults(NamedTuple):
|
||||||
|
"""Results from BFGS optimization.
|
||||||
|
|
||||||
|
Arg
|
||||||
|
converged (bool): `True`` if minimization converged.
|
||||||
|
failed (bool): `True`` if line search failed.
|
||||||
|
k (int): the number of iterations of the BFGS update.
|
||||||
|
nfev (int): total number of objective evaluations performed.
|
||||||
|
ngev (int): total number of jacobian evaluations
|
||||||
|
nhev (int): total number of hessian evaluations
|
||||||
|
x_k (Tensor): containing the last argument value found during the search. If
|
||||||
|
the search converged, then this value is the argmin of the objective
|
||||||
|
function.
|
||||||
|
f_k (float): containing the value of the objective function at `x_k`. If the
|
||||||
|
search converged, then this is the (local) minimum of the objective
|
||||||
|
function.
|
||||||
|
g_k (Tensor): containing the gradient of the objective function at `x_k`. If
|
||||||
|
the search converged the l2-norm of this tensor should be below the
|
||||||
|
tolerance.
|
||||||
|
H_k (Tensor): containing the inverse of the estimated Hessian.
|
||||||
|
old_old_fval (float): Function value for the point preceding x=x_k.
|
||||||
|
status (int): describing end state.
|
||||||
|
line_search_status (int): describing line search end state (only means
|
||||||
|
something if line search fails).
|
||||||
|
"""
|
||||||
|
converged: bool
|
||||||
|
failed: bool
|
||||||
|
k: int
|
||||||
|
nfev: int
|
||||||
|
ngev: int
|
||||||
|
nhev: int
|
||||||
|
x_k: Tensor
|
||||||
|
f_k: float
|
||||||
|
g_k: Tensor
|
||||||
|
H_k: Tensor
|
||||||
|
old_old_fval: float
|
||||||
|
status: int
|
||||||
|
line_search_status: int
|
||||||
|
|
||||||
|
|
||||||
|
class MinimizeBfgs(nn.Cell):
|
||||||
|
"""minimize bfgs"""
|
||||||
|
|
||||||
|
def __init__(self, func):
|
||||||
|
"""Initialize MinimizeBfgs."""
|
||||||
|
super(MinimizeBfgs, self).__init__()
|
||||||
|
self.func = func
|
||||||
|
self.line_search = LineSearch(func)
|
||||||
|
|
||||||
|
def construct(self, x0, maxiter=None, norm=mnp.inf, gtol=1e-5, line_search_maxiter=10):
|
||||||
|
def _my_norm(x, ord_=None):
|
||||||
|
if ord_ == mnp.inf:
|
||||||
|
res = mnp.max(mnp.abs(x))
|
||||||
|
else:
|
||||||
|
res = mnp.sqrt(mnp.sum(x ** 2))
|
||||||
|
return res
|
||||||
|
|
||||||
|
if maxiter is None:
|
||||||
|
maxiter = mnp.size(x0) * 200
|
||||||
|
|
||||||
|
d = x0.shape[0]
|
||||||
|
initial_H = mnp.eye(d, dtype=x0.dtype)
|
||||||
|
f_0 = self.func(x0)
|
||||||
|
g_0 = F.grad(self.func, grad_first_param=True)(x0)
|
||||||
|
|
||||||
|
state = {
|
||||||
|
"converged": _my_norm(g_0, ord_=mnp.inf) < gtol,
|
||||||
|
"failed": _BOOL_FALSE,
|
||||||
|
"k": _INT_ZERO,
|
||||||
|
"nfev": _INT_ONE,
|
||||||
|
"ngev": _INT_ONE,
|
||||||
|
"nhev": _INT_ZERO,
|
||||||
|
"x_k": x0,
|
||||||
|
"f_k": f_0,
|
||||||
|
"g_k": g_0,
|
||||||
|
"H_k": initial_H,
|
||||||
|
"old_old_fval": f_0 + _my_norm(g_0) / 2,
|
||||||
|
"status": _INT_ZERO,
|
||||||
|
"line_search_status": _INT_ZERO
|
||||||
|
}
|
||||||
|
|
||||||
|
while state["k"] < maxiter:
|
||||||
|
p_k = -1 * mnp.dot(state["H_k"], state["g_k"])
|
||||||
|
line_search_results = self.line_search(state["x_k"],
|
||||||
|
p_k,
|
||||||
|
old_fval=state["f_k"],
|
||||||
|
old_old_fval=state["old_old_fval"],
|
||||||
|
gfk=state["g_k"],
|
||||||
|
maxiter=line_search_maxiter)
|
||||||
|
state["nfev"] += line_search_results["nfev"]
|
||||||
|
state["ngev"] += line_search_results["ngev"]
|
||||||
|
state["failed"] = line_search_results["failed"] or mnp.logical_not(line_search_results["done"])
|
||||||
|
state["line_search_status"] = line_search_results["status"]
|
||||||
|
|
||||||
|
if state["failed"]:
|
||||||
|
break
|
||||||
|
|
||||||
|
s_k = line_search_results["a_star"] * p_k
|
||||||
|
x_kp1 = state["x_k"] + s_k
|
||||||
|
f_kp1 = line_search_results["phi_star"]
|
||||||
|
g_kp1 = line_search_results["g_star"]
|
||||||
|
y_k = g_kp1 - state["g_k"]
|
||||||
|
|
||||||
|
state["old_old_fval"] = state["f_k"]
|
||||||
|
state["converged"] = _my_norm(g_kp1, ord_=norm) < gtol
|
||||||
|
state["x_k"] = x_kp1
|
||||||
|
state["f_k"] = f_kp1
|
||||||
|
state["g_k"] = g_kp1
|
||||||
|
|
||||||
|
if state["converged"]:
|
||||||
|
break
|
||||||
|
|
||||||
|
rho_k = mnp.reciprocal(mnp.dot(y_k, s_k))
|
||||||
|
sy_k = mnp.expand_dims(s_k, axis=1) * mnp.expand_dims(y_k, axis=0)
|
||||||
|
term1 = rho_k * sy_k
|
||||||
|
sy_k_2 = mnp.expand_dims(y_k, axis=1) * mnp.expand_dims(s_k, axis=0)
|
||||||
|
term2 = rho_k * sy_k_2
|
||||||
|
I = mnp.eye(d)
|
||||||
|
term3 = mnp.matmul(I - term1, state["H_k"])
|
||||||
|
term4 = mnp.matmul(term3, I - term2)
|
||||||
|
term5 = rho_k * (mnp.expand_dims(s_k, axis=1) * mnp.expand_dims(s_k, axis=0))
|
||||||
|
H_kp1 = term4 + term5
|
||||||
|
state["H_k"] = H_kp1
|
||||||
|
|
||||||
|
# next iteration
|
||||||
|
state["k"] = state["k"] + 1
|
||||||
|
|
||||||
|
status = mnp.where(
|
||||||
|
state["converged"],
|
||||||
|
0, # converged
|
||||||
|
mnp.where(
|
||||||
|
state["k"] == maxiter,
|
||||||
|
1, # max iters reached
|
||||||
|
mnp.where(
|
||||||
|
state["failed"],
|
||||||
|
2 + state["line_search_status"], # ls failed (+ reason)
|
||||||
|
-1, # undefined
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
state["status"] = status
|
||||||
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
def minimize_bfgs(func, x0, maxiter=None, norm=mnp.inf, gtol=1e-5, line_search_maxiter=10):
|
||||||
|
"""Minimize a function using BFGS.
|
||||||
|
|
||||||
|
Implements the BFGS algorithm from
|
||||||
|
Algorithm 6.1 from Wright and Nocedal, 'Numerical Optimization', 1999, pg.
|
||||||
|
136-143.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fun (Callable): function of the form f(x) where x is a flat Tensor and returns a real
|
||||||
|
scalar. The function should be composed of operations with vjp defined.
|
||||||
|
x0 (Tensor): initial guess.
|
||||||
|
maxiter (int, optional): maximum number of iterations.
|
||||||
|
norm (float): order of norm for convergence check. Default inf.
|
||||||
|
gtol (float): terminates minimization when |grad|_norm < g_tol.
|
||||||
|
line_search_maxiter (int): maximum number of linesearch iterations.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optimization result.
|
||||||
|
|
||||||
|
Supported Platforms:
|
||||||
|
``CPU`` ``GPU``
|
||||||
|
"""
|
||||||
|
if maxiter is None:
|
||||||
|
maxiter = mnp.size(x0) * 200
|
||||||
|
|
||||||
|
state = MinimizeBfgs(func)(x0, maxiter, norm, gtol)
|
||||||
|
# If running in graph mode, the state is a tuple.
|
||||||
|
if isinstance(state, tuple):
|
||||||
|
state = _BFGSResults(converged=_to_scalar(state[0]),
|
||||||
|
failed=_to_scalar(state[1]),
|
||||||
|
k=_to_scalar(state[2]),
|
||||||
|
nfev=_to_scalar(state[3]),
|
||||||
|
ngev=_to_scalar(state[4]),
|
||||||
|
nhev=_to_scalar(state[5]),
|
||||||
|
x_k=state[6],
|
||||||
|
f_k=_to_scalar(state[7]),
|
||||||
|
g_k=state[8],
|
||||||
|
H_k=state[9],
|
||||||
|
old_old_fval=_to_scalar(state[10]),
|
||||||
|
status=_to_scalar(state[11]),
|
||||||
|
line_search_status=_to_scalar(state[12]))
|
||||||
|
else:
|
||||||
|
state = _BFGSResults(converged=_to_scalar(state["converged"]),
|
||||||
|
failed=_to_scalar(state["failed"]),
|
||||||
|
k=_to_scalar(state["k"]),
|
||||||
|
nfev=_to_scalar(state["nfev"]),
|
||||||
|
ngev=_to_scalar(state["ngev"]),
|
||||||
|
nhev=_to_scalar(state["nhev"]),
|
||||||
|
x_k=state["x_k"],
|
||||||
|
f_k=_to_scalar(state["f_k"]),
|
||||||
|
g_k=state["g_k"],
|
||||||
|
H_k=state["H_k"],
|
||||||
|
old_old_fval=_to_scalar(state["old_old_fval"]),
|
||||||
|
status=_to_scalar(state["status"]),
|
||||||
|
line_search_status=_to_scalar(state["line_search_status"]))
|
||||||
|
|
||||||
|
return state
|
|
@ -0,0 +1,343 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""line search"""
|
||||||
|
from typing import NamedTuple
|
||||||
|
|
||||||
|
from ... import nn
|
||||||
|
from ... import numpy as mnp
|
||||||
|
from ...common import dtype as mstype
|
||||||
|
from ...common import Tensor
|
||||||
|
from ...ops import functional as F
|
||||||
|
|
||||||
|
from ..utils import _to_scalar
|
||||||
|
from ..utils import _to_tensor, _FLOAT_ZERO, _FLOAT_ONE, _INT_ZERO, _INT_ONE, _BOOL_FALSE
|
||||||
|
|
||||||
|
|
||||||
|
class _LineSearchResults(NamedTuple):
|
||||||
|
"""Results of line search results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
failed (bool): `True`` if the strong Wolfe criteria were satisfied
|
||||||
|
nit (int): number of iterations
|
||||||
|
nfev (int): number of functions evaluations
|
||||||
|
ngev (int): number of gradients evaluations
|
||||||
|
k (int): number of iterations
|
||||||
|
a_k (float): step size
|
||||||
|
f_k (float): final function value
|
||||||
|
g_k (Tensor): final gradient value
|
||||||
|
status (int): end status
|
||||||
|
"""
|
||||||
|
failed: bool
|
||||||
|
nit: int
|
||||||
|
nfev: int
|
||||||
|
ngev: int
|
||||||
|
k: int
|
||||||
|
a_k: float
|
||||||
|
f_k: float
|
||||||
|
g_k: Tensor
|
||||||
|
status: int
|
||||||
|
|
||||||
|
|
||||||
|
def _cubicmin(a, fa, fpa, b, fb, c, fc):
|
||||||
|
"""Finds the minimizer for a cubic polynomial that goes through the
|
||||||
|
points (a,fa), (b,fb), and (c,fc) with derivative at a of fpa.
|
||||||
|
"""
|
||||||
|
C = fpa
|
||||||
|
db = b - a
|
||||||
|
dc = c - a
|
||||||
|
denom = (db * dc) ** 2 * (db - dc)
|
||||||
|
|
||||||
|
d1 = mnp.zeros((2, 2))
|
||||||
|
d1[0, 0] = dc ** 2
|
||||||
|
d1[0, 1] = -db ** 2
|
||||||
|
d1[1, 0] = -dc ** 3
|
||||||
|
d1[1, 1] = db ** 3
|
||||||
|
|
||||||
|
d2 = mnp.zeros((2,))
|
||||||
|
d2[0] = fb - fa - C * db
|
||||||
|
d2[1] = fc - fa - C * dc
|
||||||
|
|
||||||
|
A, B = mnp.dot(d1, d2.flatten()) / denom
|
||||||
|
|
||||||
|
radical = B * B - 3. * A * C
|
||||||
|
xmin = a + (-B + mnp.sqrt(radical)) / (3. * A)
|
||||||
|
return xmin
|
||||||
|
|
||||||
|
|
||||||
|
def _quadmin(a, fa, fpa, b, fb):
|
||||||
|
"""Finds the minimizer for a quadratic polynomial that goes through
|
||||||
|
the points (a,fa), (b,fb) with derivative at a of fpa.
|
||||||
|
"""
|
||||||
|
D = fa
|
||||||
|
C = fpa
|
||||||
|
db = b - a
|
||||||
|
B = (fb - D - C * db) / (db ** 2)
|
||||||
|
xmin = a - C / (2. * B)
|
||||||
|
return xmin
|
||||||
|
|
||||||
|
|
||||||
|
def _zoom(fn, a_low, phi_low, dphi_low, a_high, phi_high, dphi_high, phi_0, g_0, dphi_0, c1, c2, is_run):
|
||||||
|
"""Implementation of zoom algorithm.
|
||||||
|
Algorithm 3.6 from Wright and Nocedal, 'Numerical Optimization', 1999, pg. 59-61.
|
||||||
|
Tries cubic, quadratic, and bisection methods of zooming.
|
||||||
|
"""
|
||||||
|
state = {
|
||||||
|
"done": _BOOL_FALSE,
|
||||||
|
"failed": _BOOL_FALSE,
|
||||||
|
"j": _INT_ZERO,
|
||||||
|
"a_low": a_low,
|
||||||
|
"phi_low": phi_low,
|
||||||
|
"dphi_low": dphi_low,
|
||||||
|
"a_high": a_high,
|
||||||
|
"phi_high": phi_high,
|
||||||
|
"dphi_high": dphi_high,
|
||||||
|
"a_rec": (a_low + a_high) / 2.,
|
||||||
|
"phi_rec": (phi_low + phi_high) / 2.,
|
||||||
|
"a_star": _FLOAT_ONE,
|
||||||
|
"phi_star": phi_low,
|
||||||
|
"dphi_star": dphi_low,
|
||||||
|
"g_star": g_0,
|
||||||
|
"nfev": _INT_ZERO,
|
||||||
|
"ngev": _INT_ZERO,
|
||||||
|
}
|
||||||
|
|
||||||
|
if mnp.logical_not(is_run):
|
||||||
|
return state
|
||||||
|
|
||||||
|
delta1 = 0.2
|
||||||
|
delta2 = 0.1
|
||||||
|
maxiter = 10 # scipy: 10 jax: 30
|
||||||
|
while mnp.logical_not(state["done"]) and state["j"] < maxiter:
|
||||||
|
dalpha = state["a_high"] - state["a_low"]
|
||||||
|
a = mnp.minimum(state["a_low"], state["a_high"])
|
||||||
|
b = mnp.maximum(state["a_low"], state["a_high"])
|
||||||
|
|
||||||
|
cchk = delta1 * dalpha
|
||||||
|
qchk = delta2 * dalpha
|
||||||
|
|
||||||
|
a_j_cubic = _cubicmin(state["a_low"], state["phi_low"], state["dphi_low"], state["a_high"],
|
||||||
|
state["phi_high"], state["a_rec"], state["phi_rec"])
|
||||||
|
use_cubic = state["j"] > 0 and mnp.isfinite(a_j_cubic) and \
|
||||||
|
mnp.logical_and(a_j_cubic > a + cchk, a_j_cubic < b - cchk)
|
||||||
|
|
||||||
|
a_j_quad = _quadmin(state["a_low"], state["phi_low"], state["dphi_low"], state["a_high"],
|
||||||
|
state["phi_high"])
|
||||||
|
use_quad = mnp.logical_not(use_cubic) and mnp.isfinite(a_j_quad) and \
|
||||||
|
mnp.logical_and(a_j_quad > a + qchk, a_j_quad < b - qchk)
|
||||||
|
|
||||||
|
a_j_bisection = (state["a_low"] + state["a_high"]) / 2.0
|
||||||
|
use_bisection = mnp.logical_not(use_cubic) and mnp.logical_not(use_quad)
|
||||||
|
|
||||||
|
a_j = mnp.where(use_cubic, a_j_cubic, state["a_rec"])
|
||||||
|
a_j = mnp.where(use_quad, a_j_quad, a_j)
|
||||||
|
a_j = mnp.where(use_bisection, a_j_bisection, a_j)
|
||||||
|
|
||||||
|
phi_j, g_j, dphi_j = fn(a_j)
|
||||||
|
state["nfev"] += 1
|
||||||
|
state["ngev"] += 1
|
||||||
|
|
||||||
|
j_to_high = (phi_j > phi_0 + c1 * a_j * dphi_0) or (phi_j >= state["phi_low"])
|
||||||
|
state["a_rec"] = mnp.where(j_to_high, state["a_high"], state["a_rec"])
|
||||||
|
state["phi_rec"] = mnp.where(j_to_high, state["phi_high"], state["phi_rec"])
|
||||||
|
state["a_high"] = mnp.where(j_to_high, a_j, state["a_high"])
|
||||||
|
state["phi_high"] = mnp.where(j_to_high, phi_j, state["phi_high"])
|
||||||
|
state["dphi_high"] = mnp.where(j_to_high, dphi_j, state["dphi_high"])
|
||||||
|
|
||||||
|
j_to_star = mnp.logical_not(j_to_high) and mnp.abs(dphi_j) <= -c2 * dphi_0
|
||||||
|
state["done"] = j_to_star
|
||||||
|
state["a_star"] = mnp.where(j_to_star, a_j, state["a_star"])
|
||||||
|
state["phi_star"] = mnp.where(j_to_star, phi_j, state["phi_star"])
|
||||||
|
state["g_star"] = mnp.where(j_to_star, g_j, state["g_star"])
|
||||||
|
state["dphi_star"] = mnp.where(j_to_star, dphi_j, state["dphi_star"])
|
||||||
|
|
||||||
|
low_to_high = mnp.logical_not(j_to_high) and mnp.logical_not(j_to_star) and \
|
||||||
|
dphi_j * (state["a_high"] - state["a_low"]) >= 0.
|
||||||
|
state["a_rec"] = mnp.where(low_to_high, state["a_high"], state["a_rec"])
|
||||||
|
state["phi_rec"] = mnp.where(low_to_high, state["phi_high"], state["phi_rec"])
|
||||||
|
state["a_high"] = mnp.where(low_to_high, a_low, state["a_high"])
|
||||||
|
state["phi_high"] = mnp.where(low_to_high, phi_low, state["phi_high"])
|
||||||
|
state["dphi_high"] = mnp.where(low_to_high, dphi_low, state["dphi_high"])
|
||||||
|
|
||||||
|
j_to_low = mnp.logical_not(j_to_high) and mnp.logical_not(j_to_star)
|
||||||
|
state["a_rec"] = mnp.where(j_to_low, state["a_low"], state["a_rec"])
|
||||||
|
state["phi_rec"] = mnp.where(j_to_low, state["phi_low"], state["phi_rec"])
|
||||||
|
state["a_low"] = mnp.where(j_to_low, a_j, state["a_low"])
|
||||||
|
state["phi_low"] = mnp.where(j_to_low, phi_j, state["phi_low"])
|
||||||
|
state["dphi_low"] = mnp.where(j_to_low, dphi_j, state["dphi_low"])
|
||||||
|
|
||||||
|
# next iteration
|
||||||
|
state["j"] = state["j"] + 1
|
||||||
|
|
||||||
|
state["failed"] = state["j"] == maxiter
|
||||||
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
class LineSearch(nn.Cell):
|
||||||
|
"""Line Search that satisfies strong Wolfe conditions."""
|
||||||
|
|
||||||
|
def __init__(self, func):
|
||||||
|
"""Initialize LineSearch."""
|
||||||
|
super(LineSearch, self).__init__()
|
||||||
|
self.func = func
|
||||||
|
|
||||||
|
def construct(self, xk, pk, old_fval=None, old_old_fval=None, gfk=None,
|
||||||
|
c1=1e-4, c2=0.9, maxiter=3):
|
||||||
|
def fval_and_grad(alpha):
|
||||||
|
xkk = xk + alpha * pk
|
||||||
|
fkk = self.func(xkk)
|
||||||
|
gkk = F.grad(self.func, grad_first_param=True)(xkk)
|
||||||
|
return fkk, gkk, mnp.dot(gkk, pk)
|
||||||
|
|
||||||
|
if old_fval is None or gfk is None:
|
||||||
|
nfev, ngev = _INT_ONE, _INT_ONE
|
||||||
|
phi_0, g_0, dphi_0 = fval_and_grad(_FLOAT_ZERO)
|
||||||
|
else:
|
||||||
|
nfev, ngev = _INT_ZERO, _INT_ZERO
|
||||||
|
phi_0, g_0 = old_fval, gfk
|
||||||
|
dphi_0 = mnp.dot(g_0, pk)
|
||||||
|
|
||||||
|
if old_old_fval is None:
|
||||||
|
start_value = _FLOAT_ONE
|
||||||
|
else:
|
||||||
|
old_phi0 = old_old_fval
|
||||||
|
candidate_start_value = 1.01 * 2 * (phi_0 - old_phi0) / dphi_0
|
||||||
|
start_value = mnp.minimum(candidate_start_value, _FLOAT_ONE)
|
||||||
|
|
||||||
|
state = {
|
||||||
|
"done": _BOOL_FALSE,
|
||||||
|
"failed": _BOOL_FALSE,
|
||||||
|
"i": _INT_ONE,
|
||||||
|
"a_i": _FLOAT_ZERO,
|
||||||
|
"phi_i": phi_0,
|
||||||
|
"dphi_i": dphi_0,
|
||||||
|
"nfev": nfev,
|
||||||
|
"ngev": ngev,
|
||||||
|
"a_star": _FLOAT_ZERO,
|
||||||
|
"phi_star": phi_0,
|
||||||
|
"dphi_star": dphi_0,
|
||||||
|
"g_star": g_0,
|
||||||
|
}
|
||||||
|
|
||||||
|
while mnp.logical_not(state["done"]) and state["i"] <= maxiter:
|
||||||
|
a_i = mnp.where(state["i"] > 1, state["a_i"] * 2.0, start_value)
|
||||||
|
phi_i, g_i, dphi_i = fval_and_grad(a_i)
|
||||||
|
state["nfev"] += 1
|
||||||
|
state["ngev"] += 1
|
||||||
|
|
||||||
|
# Armijo condition
|
||||||
|
cond1 = (phi_i > phi_0 + c1 * a_i * dphi_0) or \
|
||||||
|
(phi_i >= state["phi_i"] and state["i"] > 1)
|
||||||
|
zoom1 = _zoom(fval_and_grad, state["a_i"], state["phi_i"], state["dphi_i"],
|
||||||
|
a_i, phi_i, dphi_i, phi_0, g_0, dphi_0, c1, c2, cond1)
|
||||||
|
state["nfev"] += zoom1["nfev"]
|
||||||
|
state["ngev"] += zoom1["ngev"]
|
||||||
|
state["done"] = cond1
|
||||||
|
state["failed"] = cond1 and zoom1["failed"]
|
||||||
|
state["a_star"] = mnp.where(cond1, zoom1["a_star"], state["a_star"])
|
||||||
|
state["phi_star"] = mnp.where(cond1, zoom1["phi_star"], state["phi_star"])
|
||||||
|
state["g_star"] = mnp.where(cond1, zoom1["g_star"], state["g_star"])
|
||||||
|
state["dphi_star"] = mnp.where(cond1, zoom1["dphi_star"], state["dphi_star"])
|
||||||
|
|
||||||
|
# curvature condition
|
||||||
|
cond2 = mnp.logical_not(cond1) and mnp.abs(dphi_i) <= -c2 * dphi_0
|
||||||
|
state["done"] = state["done"] or cond2
|
||||||
|
state["a_star"] = mnp.where(cond2, a_i, state["a_star"])
|
||||||
|
state["phi_star"] = mnp.where(cond2, phi_i, state["phi_star"])
|
||||||
|
state["g_star"] = mnp.where(cond2, g_i, state["g_star"])
|
||||||
|
state["dphi_star"] = mnp.where(cond2, dphi_i, state["dphi_star"])
|
||||||
|
|
||||||
|
# satisfying the strong wolf conditions
|
||||||
|
cond3 = mnp.logical_not(cond1) and mnp.logical_not(cond2) and dphi_i >= 0.
|
||||||
|
zoom2 = _zoom(fval_and_grad, a_i, phi_i, dphi_i, state["a_i"], state["phi_i"],
|
||||||
|
state["dphi_i"], phi_0, g_0, dphi_0, c1, c2, cond3)
|
||||||
|
state["nfev"] += zoom2["nfev"]
|
||||||
|
state["ngev"] += zoom2["ngev"]
|
||||||
|
state["done"] = state["done"] or cond3
|
||||||
|
state["failed"] = state["failed"] or (cond3 and zoom2["failed"])
|
||||||
|
state["a_star"] = mnp.where(cond3, zoom2["a_star"], state["a_star"])
|
||||||
|
state["phi_star"] = mnp.where(cond3, zoom2["phi_star"], state["phi_star"])
|
||||||
|
state["g_star"] = mnp.where(cond3, zoom2["g_star"], state["g_star"])
|
||||||
|
state["dphi_star"] = mnp.where(cond3, zoom2["dphi_star"], state["dphi_star"])
|
||||||
|
|
||||||
|
# next iteration
|
||||||
|
state["i"] += 1
|
||||||
|
state["a_i"] = a_i
|
||||||
|
state["phi_i"] = phi_i
|
||||||
|
state["dphi_i"] = dphi_i
|
||||||
|
|
||||||
|
state["status"] = mnp.where(
|
||||||
|
state["failed"],
|
||||||
|
1, # zoom failed
|
||||||
|
mnp.where(
|
||||||
|
state["i"] > maxiter,
|
||||||
|
3, # maxiter reached
|
||||||
|
0, # passed (should be)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
state["a_star"] = mnp.where(
|
||||||
|
_to_tensor(state["a_star"].dtype != mstype.float64)
|
||||||
|
and (mnp.abs(state["a_star"]) < 1e-8),
|
||||||
|
mnp.sign(state["a_star"]) * 1e-8,
|
||||||
|
state["a_star"],
|
||||||
|
)
|
||||||
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
def line_search(f, xk, pk, old_fval=None, old_old_fval=None, gfk=None, c1=1e-4,
|
||||||
|
c2=0.9, maxiter=20) -> _LineSearchResults:
|
||||||
|
"""Inexact line search that satisfies strong Wolfe conditions.
|
||||||
|
|
||||||
|
Algorithm 3.5 from Wright and Nocedal, 'Numerical Optimization', 1999, pg. 59-61
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fun (function): function of the form f(x) where x is a flat Tensor and returns a real
|
||||||
|
scalar. The function should be composed of operations with vjp defined.
|
||||||
|
x0 (Tensor): initial guess.
|
||||||
|
pk (Tensor): direction to search in. Assumes the direction is a descent direction.
|
||||||
|
old_fval, gfk (Tensor): initial value of value_and_gradient as position.
|
||||||
|
old_old_fval (Tensor): unused argument, only for scipy API compliance.
|
||||||
|
maxiter (int): maximum number of iterations to search
|
||||||
|
c1, c2 (float): Wolfe criteria constant, see ref.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
LineSearchResults
|
||||||
|
|
||||||
|
Supported Platforms:
|
||||||
|
``CPU`` ``GPU``
|
||||||
|
"""
|
||||||
|
state = LineSearch(f)(xk, pk, old_fval, old_old_fval, gfk, c1, c2, maxiter)
|
||||||
|
# If running in graph mode, the state is a tuple.
|
||||||
|
if isinstance(state, tuple):
|
||||||
|
state = _LineSearchResults(failed=_to_scalar(state[0] or not state[1]),
|
||||||
|
nit=_to_scalar(state[2] - 1),
|
||||||
|
nfev=_to_scalar(state[6]),
|
||||||
|
ngev=_to_scalar(state[7]),
|
||||||
|
k=_to_scalar(state[2]),
|
||||||
|
a_k=_to_scalar(state[8]),
|
||||||
|
f_k=_to_scalar(state[9]),
|
||||||
|
g_k=state[11],
|
||||||
|
status=_to_scalar(state[12]))
|
||||||
|
else:
|
||||||
|
state = _LineSearchResults(failed=_to_scalar(state["failed"] or not state["done"]),
|
||||||
|
nit=_to_scalar(state["i"] - 1),
|
||||||
|
nfev=_to_scalar(state["nfev"]),
|
||||||
|
ngev=_to_scalar(state["ngev"]),
|
||||||
|
k=_to_scalar(state["i"]),
|
||||||
|
a_k=_to_scalar(state["a_star"]),
|
||||||
|
f_k=_to_scalar(state["phi_star"]),
|
||||||
|
g_k=state["g_star"],
|
||||||
|
status=_to_scalar(state["status"]))
|
||||||
|
|
||||||
|
return state
|
|
@ -0,0 +1,132 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""minimize"""
|
||||||
|
from typing import Optional
|
||||||
|
from typing import NamedTuple
|
||||||
|
|
||||||
|
from ...common import Tensor
|
||||||
|
|
||||||
|
from ._bfgs import minimize_bfgs
|
||||||
|
|
||||||
|
|
||||||
|
class OptimizeResults(NamedTuple):
|
||||||
|
"""Object holding optimization results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Tensor): final solution.
|
||||||
|
success (bool): ``True`` if optimization succeeded.
|
||||||
|
status (int): solver specific return code. 0 means converged (nominal),
|
||||||
|
1=max BFGS iters reached, 3=zoom failed, 4=saddle point reached,
|
||||||
|
5=max line search iters reached, -1=undefined
|
||||||
|
fun (float): final function value.
|
||||||
|
jac (Tensor): final jacobian array.
|
||||||
|
hess_inv (Tensor, optional): final inverse Hessian estimate.
|
||||||
|
nfev (int): number of function calls used.
|
||||||
|
njev (int): number of gradient evaluations.
|
||||||
|
nit (int): number of iterations of the optimization algorithm.
|
||||||
|
"""
|
||||||
|
x: Tensor
|
||||||
|
success: bool
|
||||||
|
status: int
|
||||||
|
fun: float
|
||||||
|
jac: Tensor
|
||||||
|
hess_inv: Optional[Tensor]
|
||||||
|
nfev: int
|
||||||
|
njev: int
|
||||||
|
nit: int
|
||||||
|
|
||||||
|
|
||||||
|
def minimize(func, x0, args=(), *, method, tol=None, options=None) -> OptimizeResults:
|
||||||
|
"""Minimization of scalar function of one or more variables.
|
||||||
|
|
||||||
|
This API for this function matches SciPy with some minor deviations:
|
||||||
|
|
||||||
|
- Gradients of ``fun`` are calculated automatically using MindSpore's autodiff
|
||||||
|
support when required.
|
||||||
|
- The ``method`` argument is required. You must specify a solver.
|
||||||
|
- Various optional arguments in the SciPy interface have not yet been
|
||||||
|
implemented.
|
||||||
|
- Optimization results may differ from SciPy due to differences in the line
|
||||||
|
search implementation.
|
||||||
|
|
||||||
|
It does not yet support differentiation or arguments in the form of
|
||||||
|
multi-dimensional Tensor, but support for both is planned.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
On CPU, the supported dtypes is float32.
|
||||||
|
On GPU, the supported dtypes is float32.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fun (Callable): the objective function to be minimized, ``fun(x, *args) -> float``,
|
||||||
|
where ``x`` is a 1-D array with shape ``(n,)`` and ``args`` is a tuple
|
||||||
|
of the fixed parameters needed to completely specify the function.
|
||||||
|
``fun`` must support differentiation.
|
||||||
|
x0 (Tensor): initial guess. Array of real elements of size ``(n,)``, where ``n`` is
|
||||||
|
the number of independent variables.
|
||||||
|
args (Tuple): extra arguments passed to the objective function.
|
||||||
|
method (str): solver type. Currently only ``"BFGS"`` is supported.
|
||||||
|
tol (float, optional): tolerance for termination. For detailed control, use solver-specific
|
||||||
|
options.
|
||||||
|
options (Mapping[str, Any], optional): a dictionary of solver options. All methods accept the following
|
||||||
|
generic options:
|
||||||
|
|
||||||
|
- maxiter (int): Maximum number of iterations to perform. Depending on the
|
||||||
|
method each iteration may use several function evaluations.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An (OptimizeResults): class:`OptimizeResults` object.
|
||||||
|
|
||||||
|
Supported Platforms:
|
||||||
|
``CPU`` ``GPU``
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> import numpy as onp
|
||||||
|
>>> from mindspore.scipy.optimize import minimize
|
||||||
|
>>> from mindspore.common import Tensor
|
||||||
|
>>> x0 = Tensor(onp.zeros(2).astype(onp.float32))
|
||||||
|
>>> def func(p):
|
||||||
|
>>> x, y = p
|
||||||
|
>>> return (x ** 2 + y - 11.) ** 2 + (x + y ** 2 - 7.) ** 2
|
||||||
|
>>> res = minimize(func, x0, method='BFGS', options=dict(maxiter=None, gtol=1e-6))
|
||||||
|
>>> res.x
|
||||||
|
[3. 2.]
|
||||||
|
"""
|
||||||
|
if options is None:
|
||||||
|
options = {}
|
||||||
|
|
||||||
|
if not isinstance(args, tuple):
|
||||||
|
msg = "args argument to mindspore.scipy.optimize.minimize must be a tuple, got {}"
|
||||||
|
raise TypeError(msg.format(args))
|
||||||
|
|
||||||
|
def fun_with_args(args):
|
||||||
|
def inner_func(x):
|
||||||
|
return func(x, *args)
|
||||||
|
|
||||||
|
return inner_func
|
||||||
|
|
||||||
|
if method.lower() == 'bfgs':
|
||||||
|
results = minimize_bfgs(fun_with_args(args), x0, **options)
|
||||||
|
success = results.converged and results.failed
|
||||||
|
return OptimizeResults(x=results.x_k,
|
||||||
|
success=success,
|
||||||
|
status=results.status,
|
||||||
|
fun=results.f_k,
|
||||||
|
jac=results.g_k,
|
||||||
|
hess_inv=results.H_k,
|
||||||
|
nfev=results.nfev,
|
||||||
|
njev=results.ngev,
|
||||||
|
nit=results.k)
|
||||||
|
|
||||||
|
raise ValueError("Method {} not recognized".format(method))
|
|
@ -0,0 +1,66 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""internal utility functions"""
|
||||||
|
import numpy as onp
|
||||||
|
|
||||||
|
from ..common import Tensor
|
||||||
|
from ..common import dtype as mstype
|
||||||
|
from .utils_const import _type_convert, _raise_type_error
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_64_to_32(tensor):
|
||||||
|
"""Convert tensor with float64/int64 types to float32/int32."""
|
||||||
|
if tensor.dtype == mstype.float64:
|
||||||
|
return tensor.astype("float32")
|
||||||
|
if tensor.dtype == mstype.int64:
|
||||||
|
return tensor.astype("int32")
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
def _to_tensor(*args):
|
||||||
|
"""Returns each input as Tensor"""
|
||||||
|
res = ()
|
||||||
|
for arg in args:
|
||||||
|
if isinstance(arg, (int, float, bool, list, tuple)):
|
||||||
|
arg = _convert_64_to_32(_type_convert(Tensor, arg))
|
||||||
|
elif not isinstance(arg, Tensor):
|
||||||
|
_raise_type_error("Expect input to be array like.")
|
||||||
|
res += (arg,)
|
||||||
|
if len(res) == 1:
|
||||||
|
return res[0]
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def _to_scalar(arr):
|
||||||
|
"""Convert a scalar Tensor or ndarray to a scalar."""
|
||||||
|
if isinstance(arr, (int, float, bool)):
|
||||||
|
return arr
|
||||||
|
if isinstance(arr, Tensor):
|
||||||
|
if arr.shape:
|
||||||
|
return arr
|
||||||
|
arr = arr.asnumpy()
|
||||||
|
if isinstance(arr, onp.ndarray):
|
||||||
|
if arr.shape:
|
||||||
|
return arr
|
||||||
|
return arr.item()
|
||||||
|
raise ValueError("{} are not supported.".format(type(arr)))
|
||||||
|
|
||||||
|
|
||||||
|
_FLOAT_ONE = _to_tensor(1.0)
|
||||||
|
_FLOAT_ZERO = _to_tensor(0.0)
|
||||||
|
_INT_ZERO = _to_tensor(0)
|
||||||
|
_INT_ONE = _to_tensor(1)
|
||||||
|
_BOOL_TRUE = _to_tensor(True)
|
||||||
|
_BOOL_FALSE = _to_tensor(False)
|
|
@ -0,0 +1,40 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""internal graph-compatible utility functions"""
|
||||||
|
from ..ops.primitive import constexpr
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def _type_convert(new_type, obj):
|
||||||
|
"""
|
||||||
|
Convert type of `obj` to `force`.
|
||||||
|
"""
|
||||||
|
return new_type(obj)
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def _raise_type_error(info, param=None):
|
||||||
|
"""
|
||||||
|
Raise TypeError in both graph/pynative mode
|
||||||
|
|
||||||
|
Args:
|
||||||
|
info(str): info string to display
|
||||||
|
param(python obj): any object that can be recognized by graph mode. If is
|
||||||
|
not None, then param's type information will be extracted and displayed.
|
||||||
|
Default is None.
|
||||||
|
"""
|
||||||
|
if param is None:
|
||||||
|
raise TypeError(info)
|
||||||
|
raise TypeError(info + f"{type(param)}")
|
|
@ -0,0 +1,21 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""setup for pytest in mindspore.scipy"""
|
||||||
|
import mindspore.context as context
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: disable=unused-argument
|
||||||
|
def setup_module(module):
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE)
|
|
@ -0,0 +1,44 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""st for linalg."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import numpy as onp
|
||||||
|
import scipy as osp
|
||||||
|
|
||||||
|
from mindspore import context, Tensor
|
||||||
|
import mindspore.scipy as msp
|
||||||
|
from .utils import match_array
|
||||||
|
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.parametrize('args', [(), (1,), (7, -1), (3, 4, 5),
|
||||||
|
(onp.ones((3, 4), dtype=onp.float32), 5, onp.random.randn(5, 2).astype(onp.float32))])
|
||||||
|
def test_block_diag(args):
|
||||||
|
"""
|
||||||
|
Feature: ALL TO ALL
|
||||||
|
Description: test cases for block_diag
|
||||||
|
Expectation: the result match scipy
|
||||||
|
"""
|
||||||
|
tensor_args = tuple([Tensor(arg) for arg in args])
|
||||||
|
ms_res = msp.linalg.block_diag(*tensor_args)
|
||||||
|
|
||||||
|
scipy_res = osp.linalg.block_diag(*args)
|
||||||
|
match_array(ms_res.asnumpy(), scipy_res)
|
|
@ -0,0 +1,142 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""st for line_search."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import numpy as onp
|
||||||
|
from scipy.optimize.linesearch import line_search_wolfe2 as osp_line_search
|
||||||
|
|
||||||
|
import mindspore.numpy as mnp
|
||||||
|
from mindspore import context
|
||||||
|
from mindspore.scipy.optimize.line_search import line_search as msp_line_search
|
||||||
|
|
||||||
|
from .utils import match_array
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
|
||||||
|
|
||||||
|
def _scalar_func_1(np):
|
||||||
|
def f(x):
|
||||||
|
return -x - x ** 3 + x ** 4
|
||||||
|
|
||||||
|
def fprime(x):
|
||||||
|
return -1 - 3 * x ** 2 + 4 * x ** 3
|
||||||
|
|
||||||
|
return f, fprime
|
||||||
|
|
||||||
|
|
||||||
|
def _scalar_func_2(np):
|
||||||
|
def f(x):
|
||||||
|
return np.exp(-4 * x) + x ** 2
|
||||||
|
|
||||||
|
def fprime(x):
|
||||||
|
return -4 * np.exp(-4 * x) + 2 * x
|
||||||
|
|
||||||
|
return f, fprime
|
||||||
|
|
||||||
|
|
||||||
|
def _scalar_func_3(np):
|
||||||
|
def f(x):
|
||||||
|
return -np.sin(10 * x)
|
||||||
|
|
||||||
|
def fprime(x):
|
||||||
|
return -10 * np.cos(10 * x)
|
||||||
|
|
||||||
|
return f, fprime
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.parametrize('maxiter, func, x, p', [(10, _scalar_func_1, 0., 1.),
|
||||||
|
(10, _scalar_func_2, 0., 1.),
|
||||||
|
(10, _scalar_func_3, 0., 1.)])
|
||||||
|
def test_scalar_search(maxiter, func, x, p):
|
||||||
|
"""
|
||||||
|
Feature: ALL TO ALL
|
||||||
|
Description: test cases for 1-d function
|
||||||
|
Expectation: the result match scipy
|
||||||
|
"""
|
||||||
|
osp_f, osp_fp = func(onp)
|
||||||
|
osp_x, osp_p = onp.array(x), onp.array(p)
|
||||||
|
osp_res = osp_line_search(osp_f, osp_fp, osp_x, osp_p, maxiter=maxiter)
|
||||||
|
|
||||||
|
msp_f, _ = func(mnp)
|
||||||
|
msp_x, msp_p = mnp.array(x), mnp.array(p)
|
||||||
|
msp_res = msp_line_search(msp_f, msp_x, msp_p, maxiter=maxiter)
|
||||||
|
|
||||||
|
match_array(msp_res.a_k, osp_res[0], error=5)
|
||||||
|
match_array(msp_res.f_k, osp_res[3], error=5)
|
||||||
|
|
||||||
|
|
||||||
|
def _line_func_1(np, *args):
|
||||||
|
def f(x):
|
||||||
|
return np.dot(x, x)
|
||||||
|
|
||||||
|
def fprime(x):
|
||||||
|
return 2 * x
|
||||||
|
|
||||||
|
return f, fprime
|
||||||
|
|
||||||
|
|
||||||
|
def _line_func_2(np, *args):
|
||||||
|
def f(x):
|
||||||
|
A = args[0]
|
||||||
|
return np.dot(x, np.dot(A, x)) + 1
|
||||||
|
|
||||||
|
def fprime(x):
|
||||||
|
A = args[0]
|
||||||
|
return np.dot(A + A.T, x)
|
||||||
|
|
||||||
|
return f, fprime
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.parametrize('maxiter, func, x, p',
|
||||||
|
[(10, _line_func_1, [1.13689136, 0.09772497, 0.58295368, -0.39944903, 0.37005589],
|
||||||
|
[-1.30652685, 1.65813068, -0.11816405, -0.6801782, 0.66638308]),
|
||||||
|
(10, _line_func_1, [-0.52118931, -1.84306955, -0.477974, -0.47965581, 0.6203583],
|
||||||
|
[0.69845715, 0.00377089, 0.93184837, 0.33996498, -0.01568211]),
|
||||||
|
(10, _line_func_2, [0.15634897, 1.23029068, 1.20237985, -0.38732682, -0.30230275],
|
||||||
|
[-1.04855297, -1.42001794, -1.70627019, 1.9507754, -0.50965218]),
|
||||||
|
(10, _line_func_2, [0.42833187, 0.06651722, 0.3024719, -0.63432209, -0.36274117],
|
||||||
|
[-0.67246045, -0.35955316, -0.81314628, -1.7262826, 0.17742614])])
|
||||||
|
def test_line_search(maxiter, func, x, p):
|
||||||
|
"""
|
||||||
|
Feature: ALL TO ALL
|
||||||
|
Description: test cases for n-d function
|
||||||
|
Expectation: the result match scipy
|
||||||
|
"""
|
||||||
|
A = [[1.76405235, 0.40015721, 0.97873798, 2.2408932, 1.86755799],
|
||||||
|
[-0.97727788, 0.95008842, -0.15135721, -0.10321885, 0.4105985],
|
||||||
|
[0.14404357, 1.45427351, 0.76103773, 0.12167502, 0.44386323],
|
||||||
|
[0.33367433, 1.49407907, -0.20515826, 0.3130677, -0.85409574],
|
||||||
|
[-2.55298982, 0.6536186, 0.8644362, -0.74216502, 2.26975462]]
|
||||||
|
|
||||||
|
osp_x, osp_p, osp_A = onp.array(x), onp.array(p), onp.array(A)
|
||||||
|
osp_f, osp_fp = func(onp, osp_A)
|
||||||
|
osp_res = osp_line_search(osp_f, osp_fp, osp_x, osp_p, maxiter=maxiter)
|
||||||
|
|
||||||
|
msp_x, msp_p, msp_A = mnp.array(x), mnp.array(p), mnp.array(A)
|
||||||
|
msp_f, _ = func(mnp, msp_A)
|
||||||
|
msp_res = msp_line_search(msp_f, msp_x, msp_p, maxiter=maxiter)
|
||||||
|
|
||||||
|
match_array(msp_res.a_k, osp_res[0], error=5)
|
||||||
|
match_array(msp_res.f_k, osp_res[3], error=5)
|
|
@ -0,0 +1,107 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""st for optimize."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import numpy as onp
|
||||||
|
import scipy as osp
|
||||||
|
|
||||||
|
import mindspore.numpy as mnp
|
||||||
|
import mindspore.scipy as msp
|
||||||
|
from mindspore import context
|
||||||
|
from mindspore.common import Tensor
|
||||||
|
|
||||||
|
from .utils import match_array
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
|
||||||
|
|
||||||
|
def rosenbrock(np):
|
||||||
|
def func(x):
|
||||||
|
return np.sum(100. * np.diff(x) ** 2 + (1. - x[:-1]) ** 2)
|
||||||
|
|
||||||
|
return func
|
||||||
|
|
||||||
|
|
||||||
|
def himmelblau(np):
|
||||||
|
def func(p):
|
||||||
|
x, y = p
|
||||||
|
return (x ** 2 + y - 11.) ** 2 + (x + y ** 2 - 7.) ** 2
|
||||||
|
|
||||||
|
return func
|
||||||
|
|
||||||
|
|
||||||
|
def matyas(np):
|
||||||
|
def func(p):
|
||||||
|
x, y = p
|
||||||
|
return 0.26 * (x ** 2 + y ** 2) - 0.48 * x * y
|
||||||
|
|
||||||
|
return func
|
||||||
|
|
||||||
|
|
||||||
|
def eggholder(np):
|
||||||
|
def func(p):
|
||||||
|
x, y = p
|
||||||
|
return - (y + 47) * np.sin(np.sqrt(np.abs(x / 2. + y + 47.))) - x * np.sin(
|
||||||
|
np.sqrt(np.abs(x - (y + 47.))))
|
||||||
|
|
||||||
|
return func
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.parametrize('dtype', [onp.float32])
|
||||||
|
@pytest.mark.parametrize('func_x0', [(rosenbrock, onp.zeros(2)),
|
||||||
|
(himmelblau, onp.zeros(2)),
|
||||||
|
(himmelblau, onp.array([92, 0.001])),
|
||||||
|
(matyas, onp.ones(2)),
|
||||||
|
(eggholder, onp.ones(2) * 100.)])
|
||||||
|
def test_bfgs(dtype, func_x0):
|
||||||
|
"""
|
||||||
|
Feature: ALL TO ALL
|
||||||
|
Description: test cases for bfgs
|
||||||
|
Expectation: the result match scipy
|
||||||
|
"""
|
||||||
|
func, x0 = func_x0
|
||||||
|
x0 = x0.astype(dtype)
|
||||||
|
x0_tensor = Tensor(x0)
|
||||||
|
ms_res = msp.optimize.minimize(func(mnp), x0_tensor, method='BFGS',
|
||||||
|
options=dict(maxiter=None, gtol=1e-6)).x
|
||||||
|
scipy_res = osp.optimize.minimize(func(onp), x0, method='BFGS').x
|
||||||
|
match_array(ms_res.asnumpy(), scipy_res, error=5)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.parametrize('dtype', [onp.float32])
|
||||||
|
def test_bfgs_fixes4594(dtype):
|
||||||
|
"""
|
||||||
|
Feature: ALL TO ALL
|
||||||
|
Description: test cases for bfgs
|
||||||
|
Expectation: the result match scipy
|
||||||
|
"""
|
||||||
|
n = 2
|
||||||
|
A = Tensor(onp.eye(n, dtype=dtype)) * 1e4
|
||||||
|
|
||||||
|
def func(x):
|
||||||
|
return mnp.mean((mnp.dot(A, x)) ** 2)
|
||||||
|
|
||||||
|
results = msp.optimize.minimize(func, Tensor(onp.ones(n, dtype=dtype)), method='BFGS',
|
||||||
|
options=dict(maxiter=None, gtol=1e-6)).x
|
||||||
|
onp.testing.assert_allclose(results.asnumpy(), onp.zeros(n, dtype=dtype), rtol=1e-6, atol=1e-6)
|
|
@ -0,0 +1,43 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""utility functions for mindspore.scipy st tests"""
|
||||||
|
import numpy as onp
|
||||||
|
from mindspore import Tensor
|
||||||
|
import mindspore.numpy as mnp
|
||||||
|
|
||||||
|
|
||||||
|
def to_tensor(obj, dtype=None):
|
||||||
|
if dtype is None:
|
||||||
|
res = Tensor(obj)
|
||||||
|
if res.dtype == mnp.float64:
|
||||||
|
res = res.astype(mnp.float32)
|
||||||
|
if res.dtype == mnp.int64:
|
||||||
|
res = res.astype(mnp.int32)
|
||||||
|
else:
|
||||||
|
res = Tensor(obj, dtype)
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def match_array(actual, expected, error=0):
|
||||||
|
if isinstance(actual, int):
|
||||||
|
actual = onp.asarray(actual)
|
||||||
|
|
||||||
|
if isinstance(expected, (int, tuple)):
|
||||||
|
expected = onp.asarray(expected)
|
||||||
|
|
||||||
|
if error > 0:
|
||||||
|
onp.testing.assert_almost_equal(actual, expected, decimal=error)
|
||||||
|
else:
|
||||||
|
onp.testing.assert_equal(actual, expected)
|
Loading…
Reference in New Issue