diff --git a/cmake/package.cmake b/cmake/package.cmake
index 6f5e3594797..dc3585ebcfd 100644
--- a/cmake/package.cmake
+++ b/cmake/package.cmake
@@ -278,6 +278,7 @@ install(
         ${CMAKE_SOURCE_DIR}/mindspore/parallel
         ${CMAKE_SOURCE_DIR}/mindspore/mindrecord
         ${CMAKE_SOURCE_DIR}/mindspore/numpy
+        ${CMAKE_SOURCE_DIR}/mindspore/scipy
         ${CMAKE_SOURCE_DIR}/mindspore/train
         ${CMAKE_SOURCE_DIR}/mindspore/boost
         ${CMAKE_SOURCE_DIR}/mindspore/common
diff --git a/mindspore/scipy/__init__.py b/mindspore/scipy/__init__.py
index e9a2f4d1f01..26adff44d55 100644
--- a/mindspore/scipy/__init__.py
+++ b/mindspore/scipy/__init__.py
@@ -13,3 +13,13 @@
 # limitations under the License.
 # ============================================================================
 """Scipy-like interfaces in mindspore."""
+
+from . import optimize, linalg
+from .optimize import *
+from .linalg import *
+
+__all__ = []
+__all__.extend(optimize.__all__)
+__all__.extend(linalg.__all__)
+
+__all__.sort()
diff --git a/mindspore/scipy/linalg.py b/mindspore/scipy/linalg.py
new file mode 100755
index 00000000000..442e26f9527
--- /dev/null
+++ b/mindspore/scipy/linalg.py
@@ -0,0 +1,84 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Linear algebra submodule"""
+from .. import numpy as mnp
+from .. import ops
+
+__all__ = ['block_diag']
+
+
+def block_diag(*arrs):
+    """
+    Create a block diagonal matrix from provided arrays.
+
+    Given the inputs `A`, `B` and `C`, the output will have these
+    tensors arranged on the diagonal::
+
+        [[A, 0, 0],
+         [0, B, 0],
+         [0, 0, C]]
+
+    Args:
+        A, B, C, ... (Tensor): Input Tensors, up to 2-D. A 1-D Tensor of
+            length ``n`` is treated as a 2-D Tensor with shape ``(1, n)``.
+
+    Returns:
+        D (Tensor): Tensor with `A`, `B`, `C`, ... on the diagonal. `D` has
+        the same dtype as `A`.
+
+    Raises:
+        ValueError: If any input Tensor has more than 2 dimensions.
+
+    Supported Platforms:
+        ``CPU`` ``GPU``
+
+    Examples:
+        >>> import numpy as onp
+        >>> from mindspore.common import Tensor
+        >>> from mindspore.scipy.linalg import block_diag
+        >>> A = Tensor(onp.array([[1, 0], [0, 1]]))
+        >>> B = Tensor(onp.array([[3, 4, 5], [6, 7, 8]]))
+        >>> C = Tensor(onp.array([[7]]))
+        >>> P = Tensor(onp.zeros((2, ), dtype='int32'))
+        >>> block_diag(A, B, C)
+        [[1, 0, 0, 0, 0, 0],
+         [0, 1, 0, 0, 0, 0],
+         [0, 0, 3, 4, 5, 0],
+         [0, 0, 6, 7, 8, 0],
+         [0, 0, 0, 0, 0, 7]]
+        >>> block_diag(A, P, B, C)
+        [[1, 0, 0, 0, 0, 0, 0, 0],
+         [0, 1, 0, 0, 0, 0, 0, 0],
+         [0, 0, 0, 0, 0, 0, 0, 0],
+         [0, 0, 0, 0, 3, 4, 5, 0],
+         [0, 0, 0, 0, 6, 7, 8, 0],
+         [0, 0, 0, 0, 0, 0, 0, 7]]
+    """
+    if not arrs:
+        return mnp.zeros((1, 0))
+    bad_shapes = [i for i, a in enumerate(arrs) if a.ndim > 2]
+    if bad_shapes:
+        raise ValueError("Arguments to mindspore.scipy.linalg.block_diag must have at "
+                         "most 2 dimensions, got {} at argument {}."
+                         .format(arrs[bad_shapes[0]], bad_shapes[0]))
+    arrs = [mnp.atleast_2d(a) for a in arrs]
+    accum = arrs[0]
+    for arr in arrs[1:]:
+        _, c = arr.shape
+        arr = ops.Pad(((0, 0), (accum.shape[-1], 0)))(arr)
+        accum = ops.Pad(((0, 0), (0, c)))(accum)
+        accum = mnp.concatenate([accum, arr], axis=0)
+    return accum
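As a quick cross-check of `block_diag` against its SciPy counterpart (a minimal sketch, not part of the patch; assumes a working MindSpore install and uses the `onp` alias as in the st tests further down):

```python
import numpy as onp
import scipy.linalg as osp_linalg
from mindspore import Tensor
from mindspore.scipy.linalg import block_diag

# Feed the same blocks to both implementations.
blocks = [onp.eye(2, dtype=onp.float32),
          onp.arange(6, dtype=onp.float32).reshape(2, 3)]
ms_out = block_diag(*[Tensor(b) for b in blocks]).asnumpy()
sp_out = osp_linalg.block_diag(*blocks)
onp.testing.assert_array_equal(ms_out, sp_out)  # both are (4, 5)
```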
diff --git a/mindspore/scipy/optimize/__init__.py b/mindspore/scipy/optimize/__init__.py
new file mode 100644
index 00000000000..ee80dcf9d5e
--- /dev/null
+++ b/mindspore/scipy/optimize/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Optimize submodule"""
+from .minimize import minimize
+from .line_search import line_search
+
+__all__ = ["minimize", "line_search"]
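With this `__init__.py`, both solvers are re-exported at the subpackage level; a quick import check (illustrative only, assumes the package is installed):

```python
import mindspore.scipy as msp
from mindspore.scipy.optimize import minimize, line_search

# The subpackage exposes exactly the two entry points declared in __all__.
assert set(msp.optimize.__all__) == {"minimize", "line_search"}
```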
diff --git a/mindspore/scipy/optimize/_bfgs.py b/mindspore/scipy/optimize/_bfgs.py
new file mode 100644
index 00000000000..01517bfdfa2
--- /dev/null
+++ b/mindspore/scipy/optimize/_bfgs.py
@@ -0,0 +1,228 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""BFGS"""
+from typing import NamedTuple
+
+from ... import nn
+from ... import numpy as mnp
+from ...common import Tensor
+from ...ops import functional as F
+
+from .line_search import LineSearch
+from ..utils import _to_scalar
+from ..utils import _INT_ZERO, _INT_ONE, _BOOL_FALSE
+
+
+class _BFGSResults(NamedTuple):
+    """Results from BFGS optimization.
+
+    Args:
+        converged (bool): ``True`` if minimization converged.
+        failed (bool): ``True`` if line search failed.
+        k (int): the number of iterations of the BFGS update.
+        nfev (int): total number of objective evaluations performed.
+        ngev (int): total number of jacobian evaluations.
+        nhev (int): total number of hessian evaluations.
+        x_k (Tensor): containing the last argument value found during the search. If
+            the search converged, then this value is the argmin of the objective
+            function.
+        f_k (float): containing the value of the objective function at `x_k`. If the
+            search converged, then this is the (local) minimum of the objective
+            function.
+        g_k (Tensor): containing the gradient of the objective function at `x_k`. If
+            the search converged, the l2-norm of this tensor should be below the
+            tolerance.
+        H_k (Tensor): containing the inverse of the estimated Hessian.
+        old_old_fval (float): Function value for the point preceding x=x_k.
+        status (int): describing end state.
+        line_search_status (int): describing line search end state (only means
+            something if line search fails).
+    """
+    converged: bool
+    failed: bool
+    k: int
+    nfev: int
+    ngev: int
+    nhev: int
+    x_k: Tensor
+    f_k: float
+    g_k: Tensor
+    H_k: Tensor
+    old_old_fval: float
+    status: int
+    line_search_status: int
+
+
+class MinimizeBfgs(nn.Cell):
+    """minimize bfgs"""
+
+    def __init__(self, func):
+        """Initialize MinimizeBfgs."""
+        super(MinimizeBfgs, self).__init__()
+        self.func = func
+        self.line_search = LineSearch(func)
+
+    def construct(self, x0, maxiter=None, norm=mnp.inf, gtol=1e-5, line_search_maxiter=10):
+        def _my_norm(x, ord_=None):
+            if ord_ == mnp.inf:
+                res = mnp.max(mnp.abs(x))
+            else:
+                res = mnp.sqrt(mnp.sum(x ** 2))
+            return res
+
+        if maxiter is None:
+            maxiter = mnp.size(x0) * 200
+
+        d = x0.shape[0]
+        initial_H = mnp.eye(d, dtype=x0.dtype)
+        f_0 = self.func(x0)
+        g_0 = F.grad(self.func, grad_first_param=True)(x0)
+
+        state = {
+            "converged": _my_norm(g_0, ord_=mnp.inf) < gtol,
+            "failed": _BOOL_FALSE,
+            "k": _INT_ZERO,
+            "nfev": _INT_ONE,
+            "ngev": _INT_ONE,
+            "nhev": _INT_ZERO,
+            "x_k": x0,
+            "f_k": f_0,
+            "g_k": g_0,
+            "H_k": initial_H,
+            "old_old_fval": f_0 + _my_norm(g_0) / 2,
+            "status": _INT_ZERO,
+            "line_search_status": _INT_ZERO
+        }
+
+        while state["k"] < maxiter:
+            p_k = -1 * mnp.dot(state["H_k"], state["g_k"])
+            line_search_results = self.line_search(state["x_k"],
+                                                   p_k,
+                                                   old_fval=state["f_k"],
+                                                   old_old_fval=state["old_old_fval"],
+                                                   gfk=state["g_k"],
+                                                   maxiter=line_search_maxiter)
+            state["nfev"] += line_search_results["nfev"]
+            state["ngev"] += line_search_results["ngev"]
+            state["failed"] = line_search_results["failed"] or mnp.logical_not(line_search_results["done"])
+            state["line_search_status"] = line_search_results["status"]
+
+            if state["failed"]:
+                break
+
+            s_k = line_search_results["a_star"] * p_k
+            x_kp1 = state["x_k"] + s_k
+            f_kp1 = line_search_results["phi_star"]
+            g_kp1 = line_search_results["g_star"]
+            y_k = g_kp1 - state["g_k"]
+
+            state["old_old_fval"] = state["f_k"]
+            state["converged"] = _my_norm(g_kp1, ord_=norm) < gtol
+            state["x_k"] = x_kp1
+            state["f_k"] = f_kp1
+            state["g_k"] = g_kp1
+
+            if state["converged"]:
+                break
+
+            # BFGS inverse-Hessian update:
+            # H_{k+1} = (I - rho s y^T) H_k (I - rho y s^T) + rho s s^T
+            rho_k = mnp.reciprocal(mnp.dot(y_k, s_k))
+            sy_k = mnp.expand_dims(s_k, axis=1) * mnp.expand_dims(y_k, axis=0)
+            term1 = rho_k * sy_k
+            sy_k_2 = mnp.expand_dims(y_k, axis=1) * mnp.expand_dims(s_k, axis=0)
+            term2 = rho_k * sy_k_2
+            I = mnp.eye(d)
+            term3 = mnp.matmul(I - term1, state["H_k"])
+            term4 = mnp.matmul(term3, I - term2)
+            term5 = rho_k * (mnp.expand_dims(s_k, axis=1) * mnp.expand_dims(s_k, axis=0))
+            H_kp1 = term4 + term5
+            state["H_k"] = H_kp1
+
+            # next iteration
+            state["k"] = state["k"] + 1
+
+        status = mnp.where(
+            state["converged"],
+            0,  # converged
+            mnp.where(
+                state["k"] == maxiter,
+                1,  # max iters reached
+                mnp.where(
+                    state["failed"],
+                    2 + state["line_search_status"],  # ls failed (+ reason)
+                    -1,  # undefined
+                )
+            )
+        )
+        state["status"] = status
+        return state
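The `term1` through `term5` chain in `construct` is an unrolled form of the standard inverse-Hessian update, `H_{k+1} = (I - rho*s*y^T) H_k (I - rho*y*s^T) + rho*s*s^T`. A NumPy sketch with made-up values, showing that the update enforces the secant equation `H_{k+1} y_k = s_k`:

```python
import numpy as onp

rng = onp.random.default_rng(0)
d = 4
s = rng.standard_normal(d)      # step s_k = a_star * p_k
y = rng.standard_normal(d)      # gradient difference y_k = g_{k+1} - g_k
H = onp.eye(d)                  # current inverse-Hessian estimate H_k
rho = 1.0 / onp.dot(y, s)
I = onp.eye(d)

# Same algebra as term1..term5 above, written in one line.
H_new = (I - rho * onp.outer(s, y)) @ H @ (I - rho * onp.outer(y, s)) + rho * onp.outer(s, s)

# The update is built so that H_{k+1} maps y_k back to s_k.
onp.testing.assert_allclose(H_new @ y, s)
```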
+
+
+def minimize_bfgs(func, x0, maxiter=None, norm=mnp.inf, gtol=1e-5, line_search_maxiter=10):
+    """Minimize a function using BFGS.
+
+    Implements Algorithm 6.1 from Wright and Nocedal, 'Numerical Optimization',
+    1999, pp. 136-143.
+
+    Args:
+        func (Callable): function of the form f(x) where x is a flat Tensor and returns a real
+            scalar. The function should be composed of operations with vjp defined.
+        x0 (Tensor): initial guess.
+        maxiter (int, optional): maximum number of iterations.
+        norm (float): order of norm for convergence check. Default inf.
+        gtol (float): terminates minimization when |grad|_norm < gtol.
+        line_search_maxiter (int): maximum number of linesearch iterations.
+
+    Returns:
+        Optimization result.
+
+    Supported Platforms:
+        ``CPU`` ``GPU``
+    """
+    if maxiter is None:
+        maxiter = mnp.size(x0) * 200
+
+    state = MinimizeBfgs(func)(x0, maxiter, norm, gtol, line_search_maxiter)
+    # If running in graph mode, the state is a tuple.
+    if isinstance(state, tuple):
+        state = _BFGSResults(converged=_to_scalar(state[0]),
+                             failed=_to_scalar(state[1]),
+                             k=_to_scalar(state[2]),
+                             nfev=_to_scalar(state[3]),
+                             ngev=_to_scalar(state[4]),
+                             nhev=_to_scalar(state[5]),
+                             x_k=state[6],
+                             f_k=_to_scalar(state[7]),
+                             g_k=state[8],
+                             H_k=state[9],
+                             old_old_fval=_to_scalar(state[10]),
+                             status=_to_scalar(state[11]),
+                             line_search_status=_to_scalar(state[12]))
+    else:
+        state = _BFGSResults(converged=_to_scalar(state["converged"]),
+                             failed=_to_scalar(state["failed"]),
+                             k=_to_scalar(state["k"]),
+                             nfev=_to_scalar(state["nfev"]),
+                             ngev=_to_scalar(state["ngev"]),
+                             nhev=_to_scalar(state["nhev"]),
+                             x_k=state["x_k"],
+                             f_k=_to_scalar(state["f_k"]),
+                             g_k=state["g_k"],
+                             H_k=state["H_k"],
+                             old_old_fval=_to_scalar(state["old_old_fval"]),
+                             status=_to_scalar(state["status"]),
+                             line_search_status=_to_scalar(state["line_search_status"]))
+
+    return state
diff --git a/mindspore/scipy/optimize/line_search.py b/mindspore/scipy/optimize/line_search.py
new file mode 100644
index 00000000000..5171bd9db41
--- /dev/null
+++ b/mindspore/scipy/optimize/line_search.py
@@ -0,0 +1,343 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""line search"""
+from typing import NamedTuple
+
+from ... import nn
+from ... import numpy as mnp
+from ...common import dtype as mstype
+from ...common import Tensor
+from ...ops import functional as F
+
+from ..utils import _to_scalar
+from ..utils import _to_tensor, _FLOAT_ZERO, _FLOAT_ONE, _INT_ZERO, _INT_ONE, _BOOL_FALSE
+
+
+class _LineSearchResults(NamedTuple):
+    """Results of line search.
+
+    Args:
+        failed (bool): ``True`` if the strong Wolfe criteria were not satisfied.
+        nit (int): number of iterations.
+        nfev (int): number of function evaluations.
+        ngev (int): number of gradient evaluations.
+        k (int): number of iterations.
+        a_k (float): step size.
+        f_k (float): final function value.
+        g_k (Tensor): final gradient value.
+        status (int): end status.
+    """
+    failed: bool
+    nit: int
+    nfev: int
+    ngev: int
+    k: int
+    a_k: float
+    f_k: float
+    g_k: Tensor
+    status: int
+
+
+def _cubicmin(a, fa, fpa, b, fb, c, fc):
+    """Finds the minimizer for a cubic polynomial that goes through the
+    points (a,fa), (b,fb), and (c,fc) with derivative at a of fpa.
+    """
+    C = fpa
+    db = b - a
+    dc = c - a
+    denom = (db * dc) ** 2 * (db - dc)
+
+    d1 = mnp.zeros((2, 2))
+    d1[0, 0] = dc ** 2
+    d1[0, 1] = -db ** 2
+    d1[1, 0] = -dc ** 3
+    d1[1, 1] = db ** 3
+
+    d2 = mnp.zeros((2,))
+    d2[0] = fb - fa - C * db
+    d2[1] = fc - fa - C * dc
+
+    A, B = mnp.dot(d1, d2.flatten()) / denom
+
+    radical = B * B - 3. * A * C
+    xmin = a + (-B + mnp.sqrt(radical)) / (3. * A)
+    return xmin
+
+
+def _quadmin(a, fa, fpa, b, fb):
+    """Finds the minimizer for a quadratic polynomial that goes through
+    the points (a,fa), (b,fb) with derivative at a of fpa.
+    """
+    D = fa
+    C = fpa
+    db = b - a
+    B = (fb - D - C * db) / (db ** 2)
+    xmin = a - C / (2. * B)
+    return xmin
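`_cubicmin` and `_quadmin` mirror the interpolation helpers used by SciPy's line search. A hand check of the `_quadmin` closed form on `f(x) = x**2`, whose true minimizer is 0 (plain Python, illustrative only):

```python
# Interpolate through (a, f(a)) with slope f'(a), and through (b, f(b)).
a, fa, fpa = 1.0, 1.0, 2.0   # f(1) = 1, f'(1) = 2
b, fb = 3.0, 9.0             # f(3) = 9
db = b - a
B = (fb - fa - fpa * db) / db ** 2
xmin = a - fpa / (2.0 * B)
assert xmin == 0.0           # recovers the exact minimizer of x**2
```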
+ """ + C = fpa + db = b - a + dc = c - a + denom = (db * dc) ** 2 * (db - dc) + + d1 = mnp.zeros((2, 2)) + d1[0, 0] = dc ** 2 + d1[0, 1] = -db ** 2 + d1[1, 0] = -dc ** 3 + d1[1, 1] = db ** 3 + + d2 = mnp.zeros((2,)) + d2[0] = fb - fa - C * db + d2[1] = fc - fa - C * dc + + A, B = mnp.dot(d1, d2.flatten()) / denom + + radical = B * B - 3. * A * C + xmin = a + (-B + mnp.sqrt(radical)) / (3. * A) + return xmin + + +def _quadmin(a, fa, fpa, b, fb): + """Finds the minimizer for a quadratic polynomial that goes through + the points (a,fa), (b,fb) with derivative at a of fpa. + """ + D = fa + C = fpa + db = b - a + B = (fb - D - C * db) / (db ** 2) + xmin = a - C / (2. * B) + return xmin + + +def _zoom(fn, a_low, phi_low, dphi_low, a_high, phi_high, dphi_high, phi_0, g_0, dphi_0, c1, c2, is_run): + """Implementation of zoom algorithm. + Algorithm 3.6 from Wright and Nocedal, 'Numerical Optimization', 1999, pg. 59-61. + Tries cubic, quadratic, and bisection methods of zooming. + """ + state = { + "done": _BOOL_FALSE, + "failed": _BOOL_FALSE, + "j": _INT_ZERO, + "a_low": a_low, + "phi_low": phi_low, + "dphi_low": dphi_low, + "a_high": a_high, + "phi_high": phi_high, + "dphi_high": dphi_high, + "a_rec": (a_low + a_high) / 2., + "phi_rec": (phi_low + phi_high) / 2., + "a_star": _FLOAT_ONE, + "phi_star": phi_low, + "dphi_star": dphi_low, + "g_star": g_0, + "nfev": _INT_ZERO, + "ngev": _INT_ZERO, + } + + if mnp.logical_not(is_run): + return state + + delta1 = 0.2 + delta2 = 0.1 + maxiter = 10 # scipy: 10 jax: 30 + while mnp.logical_not(state["done"]) and state["j"] < maxiter: + dalpha = state["a_high"] - state["a_low"] + a = mnp.minimum(state["a_low"], state["a_high"]) + b = mnp.maximum(state["a_low"], state["a_high"]) + + cchk = delta1 * dalpha + qchk = delta2 * dalpha + + a_j_cubic = _cubicmin(state["a_low"], state["phi_low"], state["dphi_low"], state["a_high"], + state["phi_high"], state["a_rec"], state["phi_rec"]) + use_cubic = state["j"] > 0 and mnp.isfinite(a_j_cubic) and \ + mnp.logical_and(a_j_cubic > a + cchk, a_j_cubic < b - cchk) + + a_j_quad = _quadmin(state["a_low"], state["phi_low"], state["dphi_low"], state["a_high"], + state["phi_high"]) + use_quad = mnp.logical_not(use_cubic) and mnp.isfinite(a_j_quad) and \ + mnp.logical_and(a_j_quad > a + qchk, a_j_quad < b - qchk) + + a_j_bisection = (state["a_low"] + state["a_high"]) / 2.0 + use_bisection = mnp.logical_not(use_cubic) and mnp.logical_not(use_quad) + + a_j = mnp.where(use_cubic, a_j_cubic, state["a_rec"]) + a_j = mnp.where(use_quad, a_j_quad, a_j) + a_j = mnp.where(use_bisection, a_j_bisection, a_j) + + phi_j, g_j, dphi_j = fn(a_j) + state["nfev"] += 1 + state["ngev"] += 1 + + j_to_high = (phi_j > phi_0 + c1 * a_j * dphi_0) or (phi_j >= state["phi_low"]) + state["a_rec"] = mnp.where(j_to_high, state["a_high"], state["a_rec"]) + state["phi_rec"] = mnp.where(j_to_high, state["phi_high"], state["phi_rec"]) + state["a_high"] = mnp.where(j_to_high, a_j, state["a_high"]) + state["phi_high"] = mnp.where(j_to_high, phi_j, state["phi_high"]) + state["dphi_high"] = mnp.where(j_to_high, dphi_j, state["dphi_high"]) + + j_to_star = mnp.logical_not(j_to_high) and mnp.abs(dphi_j) <= -c2 * dphi_0 + state["done"] = j_to_star + state["a_star"] = mnp.where(j_to_star, a_j, state["a_star"]) + state["phi_star"] = mnp.where(j_to_star, phi_j, state["phi_star"]) + state["g_star"] = mnp.where(j_to_star, g_j, state["g_star"]) + state["dphi_star"] = mnp.where(j_to_star, dphi_j, state["dphi_star"]) + + low_to_high = mnp.logical_not(j_to_high) and 
+
+
+class LineSearch(nn.Cell):
+    """Line Search that satisfies strong Wolfe conditions."""
+
+    def __init__(self, func):
+        """Initialize LineSearch."""
+        super(LineSearch, self).__init__()
+        self.func = func
+
+    def construct(self, xk, pk, old_fval=None, old_old_fval=None, gfk=None,
+                  c1=1e-4, c2=0.9, maxiter=3):
+        def fval_and_grad(alpha):
+            xkk = xk + alpha * pk
+            fkk = self.func(xkk)
+            gkk = F.grad(self.func, grad_first_param=True)(xkk)
+            return fkk, gkk, mnp.dot(gkk, pk)
+
+        if old_fval is None or gfk is None:
+            nfev, ngev = _INT_ONE, _INT_ONE
+            phi_0, g_0, dphi_0 = fval_and_grad(_FLOAT_ZERO)
+        else:
+            nfev, ngev = _INT_ZERO, _INT_ZERO
+            phi_0, g_0 = old_fval, gfk
+            dphi_0 = mnp.dot(g_0, pk)
+
+        if old_old_fval is None:
+            start_value = _FLOAT_ONE
+        else:
+            old_phi0 = old_old_fval
+            candidate_start_value = 1.01 * 2 * (phi_0 - old_phi0) / dphi_0
+            start_value = mnp.minimum(candidate_start_value, _FLOAT_ONE)
+
+        state = {
+            "done": _BOOL_FALSE,
+            "failed": _BOOL_FALSE,
+            "i": _INT_ONE,
+            "a_i": _FLOAT_ZERO,
+            "phi_i": phi_0,
+            "dphi_i": dphi_0,
+            "nfev": nfev,
+            "ngev": ngev,
+            "a_star": _FLOAT_ZERO,
+            "phi_star": phi_0,
+            "dphi_star": dphi_0,
+            "g_star": g_0,
+        }
+
+        while mnp.logical_not(state["done"]) and state["i"] <= maxiter:
+            a_i = mnp.where(state["i"] > 1, state["a_i"] * 2.0, start_value)
+            phi_i, g_i, dphi_i = fval_and_grad(a_i)
+            state["nfev"] += 1
+            state["ngev"] += 1
+
+            # Armijo condition
+            cond1 = (phi_i > phi_0 + c1 * a_i * dphi_0) or \
+                (phi_i >= state["phi_i"] and state["i"] > 1)
+            zoom1 = _zoom(fval_and_grad, state["a_i"], state["phi_i"], state["dphi_i"],
+                          a_i, phi_i, dphi_i, phi_0, g_0, dphi_0, c1, c2, cond1)
+            state["nfev"] += zoom1["nfev"]
+            state["ngev"] += zoom1["ngev"]
+            state["done"] = cond1
+            state["failed"] = cond1 and zoom1["failed"]
+            state["a_star"] = mnp.where(cond1, zoom1["a_star"], state["a_star"])
+            state["phi_star"] = mnp.where(cond1, zoom1["phi_star"], state["phi_star"])
+            state["g_star"] = mnp.where(cond1, zoom1["g_star"], state["g_star"])
+            state["dphi_star"] = mnp.where(cond1, zoom1["dphi_star"], state["dphi_star"])
+
+            # curvature condition
+            cond2 = mnp.logical_not(cond1) and mnp.abs(dphi_i) <= -c2 * dphi_0
+            state["done"] = state["done"] or cond2
+            state["a_star"] = mnp.where(cond2, a_i, state["a_star"])
+            state["phi_star"] = mnp.where(cond2, phi_i, state["phi_star"])
+            state["g_star"] = mnp.where(cond2, g_i, state["g_star"])
+            state["dphi_star"] = mnp.where(cond2, dphi_i, state["dphi_star"])
+
+            # satisfying the strong Wolfe conditions
+            cond3 = mnp.logical_not(cond1) and mnp.logical_not(cond2) and dphi_i >= 0.
+            zoom2 = _zoom(fval_and_grad, a_i, phi_i, dphi_i, state["a_i"], state["phi_i"],
+                          state["dphi_i"], phi_0, g_0, dphi_0, c1, c2, cond3)
+            state["nfev"] += zoom2["nfev"]
+            state["ngev"] += zoom2["ngev"]
+            state["done"] = state["done"] or cond3
+            state["failed"] = state["failed"] or (cond3 and zoom2["failed"])
+            state["a_star"] = mnp.where(cond3, zoom2["a_star"], state["a_star"])
+            state["phi_star"] = mnp.where(cond3, zoom2["phi_star"], state["phi_star"])
+            state["g_star"] = mnp.where(cond3, zoom2["g_star"], state["g_star"])
+            state["dphi_star"] = mnp.where(cond3, zoom2["dphi_star"], state["dphi_star"])
+
+            # next iteration
+            state["i"] += 1
+            state["a_i"] = a_i
+            state["phi_i"] = phi_i
+            state["dphi_i"] = dphi_i
+
+        state["status"] = mnp.where(
+            state["failed"],
+            1,  # zoom failed
+            mnp.where(
+                state["i"] > maxiter,
+                3,  # maxiter reached
+                0,  # passed (should be)
+            ),
+        )
+        state["a_star"] = mnp.where(
+            _to_tensor(state["a_star"].dtype != mstype.float64)
+            and (mnp.abs(state["a_star"]) < 1e-8),
+            mnp.sign(state["a_star"]) * 1e-8,
+            state["a_star"],
+        )
+        return state
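For reference, the conditions that `done` certifies above are the strong Wolfe conditions (3.7a/3.7b in Wright and Nocedal), written here as a plain predicate (illustrative only, not part of the patch):

```python
def strong_wolfe(phi_0, dphi_0, a, phi_a, dphi_a, c1=1e-4, c2=0.9):
    """Check the strong Wolfe conditions for a candidate step size `a`."""
    sufficient_decrease = phi_a <= phi_0 + c1 * a * dphi_0   # Armijo condition
    curvature = abs(dphi_a) <= c2 * abs(dphi_0)              # strong curvature condition
    return sufficient_decrease and curvature
```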
+
+
+def line_search(f, xk, pk, old_fval=None, old_old_fval=None, gfk=None, c1=1e-4,
+                c2=0.9, maxiter=20) -> _LineSearchResults:
+    """Inexact line search that satisfies strong Wolfe conditions.
+
+    Algorithm 3.5 from Wright and Nocedal, 'Numerical Optimization', 1999, pp. 59-61.
+
+    Args:
+        f (Callable): function of the form f(x) where x is a flat Tensor and returns a real
+            scalar. The function should be composed of operations with vjp defined.
+        xk (Tensor): initial guess.
+        pk (Tensor): direction to search in. Assumes the direction is a descent direction.
+        old_fval, gfk (Tensor): initial function value and gradient at the position `xk`.
+        old_old_fval (Tensor): unused argument, only for scipy API compliance.
+        maxiter (int): maximum number of iterations to search.
+        c1, c2 (float): Wolfe criteria constants, see ref.
+
+    Returns:
+        LineSearchResults
+
+    Supported Platforms:
+        ``CPU`` ``GPU``
+    """
+    state = LineSearch(f)(xk, pk, old_fval, old_old_fval, gfk, c1, c2, maxiter)
+    # If running in graph mode, the state is a tuple, ordered by the insertion
+    # order of the state dict: done, failed, i, a_i, phi_i, dphi_i, nfev, ngev,
+    # a_star, phi_star, dphi_star, g_star, status.
+    if isinstance(state, tuple):
+        state = _LineSearchResults(failed=_to_scalar(state[1] or not state[0]),
+                                   nit=_to_scalar(state[2] - 1),
+                                   nfev=_to_scalar(state[6]),
+                                   ngev=_to_scalar(state[7]),
+                                   k=_to_scalar(state[2]),
+                                   a_k=_to_scalar(state[8]),
+                                   f_k=_to_scalar(state[9]),
+                                   g_k=state[11],
+                                   status=_to_scalar(state[12]))
+    else:
+        state = _LineSearchResults(failed=_to_scalar(state["failed"] or not state["done"]),
+                                   nit=_to_scalar(state["i"] - 1),
+                                   nfev=_to_scalar(state["nfev"]),
+                                   ngev=_to_scalar(state["ngev"]),
+                                   k=_to_scalar(state["i"]),
+                                   a_k=_to_scalar(state["a_star"]),
+                                   f_k=_to_scalar(state["phi_star"]),
+                                   g_k=state["g_star"],
+                                   status=_to_scalar(state["status"]))
+
+    return state
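A minimal usage sketch of the public wrapper (assumes a configured MindSpore context; for this quadratic the exact strong-Wolfe step is a = 0.5):

```python
import mindspore.numpy as mnp
from mindspore.scipy.optimize import line_search

def f(x):
    return mnp.dot(x, x)     # phi(a) = |xk + a*pk|^2

xk = mnp.array([1.0, -2.0])  # current point
pk = -2.0 * xk               # steepest-descent direction
res = line_search(f, xk, pk)
print(res.a_k, res.f_k)      # expect a_k near 0.5 and f_k near 0.0
```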
diff --git a/mindspore/scipy/optimize/minimize.py b/mindspore/scipy/optimize/minimize.py
new file mode 100644
index 00000000000..d5580689494
--- /dev/null
+++ b/mindspore/scipy/optimize/minimize.py
@@ -0,0 +1,132 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""minimize"""
+from typing import Optional
+from typing import NamedTuple
+
+from ...common import Tensor
+
+from ._bfgs import minimize_bfgs
+
+
+class OptimizeResults(NamedTuple):
+    """Object holding optimization results.
+
+    Args:
+        x (Tensor): final solution.
+        success (bool): ``True`` if optimization succeeded.
+        status (int): solver-specific return code. 0 means converged (nominal),
+            1=max BFGS iters reached, 3=zoom failed, 4=saddle point reached,
+            5=max line search iters reached, -1=undefined.
+        fun (float): final function value.
+        jac (Tensor): final jacobian array.
+        hess_inv (Tensor, optional): final inverse Hessian estimate.
+        nfev (int): number of function calls used.
+        njev (int): number of gradient evaluations.
+        nit (int): number of iterations of the optimization algorithm.
+    """
+    x: Tensor
+    success: bool
+    status: int
+    fun: float
+    jac: Tensor
+    hess_inv: Optional[Tensor]
+    nfev: int
+    njev: int
+    nit: int
+
+
+def minimize(func, x0, args=(), *, method, tol=None, options=None) -> OptimizeResults:
+    """Minimization of scalar function of one or more variables.
+
+    The API for this function matches SciPy with some minor deviations:
+
+    - Gradients of ``func`` are calculated automatically using MindSpore's autodiff
+      support when required.
+    - The ``method`` argument is required. You must specify a solver.
+    - Various optional arguments in the SciPy interface have not yet been
+      implemented.
+    - Optimization results may differ from SciPy due to differences in the line
+      search implementation.
+
+    It does not yet support differentiation or arguments in the form of
+    multi-dimensional Tensor, but support for both is planned.
+
+    Note:
+        On CPU, the supported dtype is float32.
+        On GPU, the supported dtype is float32.
+
+    Args:
+        func (Callable): the objective function to be minimized, ``func(x, *args) -> float``,
+            where ``x`` is a 1-D array with shape ``(n,)`` and ``args`` is a tuple
+            of the fixed parameters needed to completely specify the function.
+            ``func`` must support differentiation.
+        x0 (Tensor): initial guess. Array of real elements of size ``(n,)``, where ``n`` is
+            the number of independent variables.
+        args (Tuple): extra arguments passed to the objective function.
+        method (str): solver type. Currently only ``"BFGS"`` is supported.
+        tol (float, optional): tolerance for termination. For detailed control, use
+            solver-specific options.
+        options (Mapping[str, Any], optional): a dictionary of solver options. All methods
+            accept the following generic options:
+
+            - maxiter (int): Maximum number of iterations to perform. Depending on the
+              method each iteration may use several function evaluations.
+
+    Returns:
+        OptimizeResults, an :class:`OptimizeResults` object.
+
+    Supported Platforms:
+        ``CPU`` ``GPU``
+
+    Examples:
+        >>> import numpy as onp
+        >>> from mindspore.scipy.optimize import minimize
+        >>> from mindspore.common import Tensor
+        >>> x0 = Tensor(onp.zeros(2).astype(onp.float32))
+        >>> def func(p):
+        ...     x, y = p
+        ...     return (x ** 2 + y - 11.) ** 2 + (x + y ** 2 - 7.) ** 2
+        >>> res = minimize(func, x0, method='BFGS', options=dict(maxiter=None, gtol=1e-6))
+        >>> res.x
+        [3. 2.]
+    """
+    if options is None:
+        options = {}
+
+    if not isinstance(args, tuple):
+        msg = "args argument to mindspore.scipy.optimize.minimize must be a tuple, got {}"
+        raise TypeError(msg.format(args))
+
+    def fun_with_args(args):
+        def inner_func(x):
+            return func(x, *args)
+
+        return inner_func
+
+    if method.lower() == 'bfgs':
+        results = minimize_bfgs(fun_with_args(args), x0, **options)
+        success = results.converged and not results.failed
+        return OptimizeResults(x=results.x_k,
+                               success=success,
+                               status=results.status,
+                               fun=results.f_k,
+                               jac=results.g_k,
+                               hess_inv=results.H_k,
+                               nfev=results.nfev,
+                               njev=results.ngev,
+                               nit=results.k)
+
+    raise ValueError("Method {} not recognized".format(method))
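Note how extra `args` are handled: `fun_with_args` closes over them, so the solver always sees a unary objective. A sketch with hypothetical data:

```python
import numpy as onp
from mindspore import Tensor
from mindspore.scipy.optimize import minimize

def residual(x, target):
    # x is the optimization variable; target rides along via `args`.
    return ((x - target) ** 2).sum()

target = Tensor(onp.array([1.0, 2.0], dtype=onp.float32))
x0 = Tensor(onp.zeros(2, dtype=onp.float32))
res = minimize(residual, x0, args=(target,), method='BFGS')
print(res.x)  # should approach [1. 2.]
```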
diff --git a/mindspore/scipy/utils.py b/mindspore/scipy/utils.py
new file mode 100644
index 00000000000..3d1c59d8ba1
--- /dev/null
+++ b/mindspore/scipy/utils.py
@@ -0,0 +1,66 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""internal utility functions"""
+import numpy as onp
+
+from ..common import Tensor
+from ..common import dtype as mstype
+from .utils_const import _type_convert, _raise_type_error
+
+
+def _convert_64_to_32(tensor):
+    """Convert tensor with float64/int64 types to float32/int32."""
+    if tensor.dtype == mstype.float64:
+        return tensor.astype("float32")
+    if tensor.dtype == mstype.int64:
+        return tensor.astype("int32")
+    return tensor
+
+
+def _to_tensor(*args):
+    """Returns each input as Tensor"""
+    res = ()
+    for arg in args:
+        if isinstance(arg, (int, float, bool, list, tuple)):
+            arg = _convert_64_to_32(_type_convert(Tensor, arg))
+        elif not isinstance(arg, Tensor):
+            _raise_type_error("Expect input to be array like.")
+        res += (arg,)
+    if len(res) == 1:
+        return res[0]
+    return res
+
+
+def _to_scalar(arr):
+    """Convert a scalar Tensor or ndarray to a scalar."""
+    if isinstance(arr, (int, float, bool)):
+        return arr
+    if isinstance(arr, Tensor):
+        if arr.shape:
+            return arr
+        arr = arr.asnumpy()
+    if isinstance(arr, onp.ndarray):
+        if arr.shape:
+            return arr
+        return arr.item()
+    raise ValueError("{} is not supported.".format(type(arr)))
+
+
+_FLOAT_ONE = _to_tensor(1.0)
+_FLOAT_ZERO = _to_tensor(0.0)
+_INT_ZERO = _to_tensor(0)
+_INT_ONE = _to_tensor(1)
+_BOOL_TRUE = _to_tensor(True)
+_BOOL_FALSE = _to_tensor(False)
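Expected behavior of the two conversion helpers, as relied on by the optimize wrappers (illustrative; exact dtypes follow MindSpore's `Tensor` defaults):

```python
from mindspore.scipy.utils import _to_scalar, _to_tensor

t = _to_tensor(3.0)         # Python float -> Tensor; 64-bit dtypes are narrowed to 32-bit
print(_to_scalar(t))        # scalar (0-d) Tensor -> Python scalar: 3.0

v = _to_tensor([1.0, 2.0])  # list -> 1-D Tensor
print(_to_scalar(v))        # non-scalar input is returned unchanged
```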
diff --git a/mindspore/scipy/utils_const.py b/mindspore/scipy/utils_const.py
new file mode 100644
index 00000000000..739baa24fdc
--- /dev/null
+++ b/mindspore/scipy/utils_const.py
@@ -0,0 +1,40 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""internal graph-compatible utility functions"""
+from ..ops.primitive import constexpr
+
+
+@constexpr
+def _type_convert(new_type, obj):
+    """
+    Convert type of `obj` to `new_type`.
+    """
+    return new_type(obj)
+
+
+@constexpr
+def _raise_type_error(info, param=None):
+    """
+    Raise TypeError in both graph/pynative mode.
+
+    Args:
+        info(str): info string to display
+        param(python obj): any object that can be recognized by graph mode. If it is
+            not None, then param's type information will be extracted and displayed.
+            Default is None.
+    """
+    if param is None:
+        raise TypeError(info)
+    raise TypeError(info + f"{type(param)}")
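`@constexpr` evaluates these helpers at graph-compile time, which is what allows a plain Python `TypeError` to be raised from inside a `construct`; they also work when called eagerly, as `utils.py` above does at import time. A small sketch (illustrative only):

```python
from mindspore.scipy.utils_const import _type_convert, _raise_type_error

print(_type_convert(tuple, [3, 4]))  # list -> tuple: (3, 4)

try:
    _raise_type_error("Expect input to be array like, got ", 1.5)
except TypeError as e:
    print(e)  # Expect input to be array like, got <class 'float'>
```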
diff --git a/tests/st/scipy_st/__init__.py b/tests/st/scipy_st/__init__.py
new file mode 100644
index 00000000000..db2da901ca1
--- /dev/null
+++ b/tests/st/scipy_st/__init__.py
@@ -0,0 +1,21 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""setup for pytest in mindspore.scipy"""
+import mindspore.context as context
+
+
+# pylint: disable=unused-argument
+def setup_module(module):
+    context.set_context(mode=context.PYNATIVE_MODE)
diff --git a/tests/st/scipy_st/test_linalg.py b/tests/st/scipy_st/test_linalg.py
new file mode 100644
index 00000000000..6af6a5776f4
--- /dev/null
+++ b/tests/st/scipy_st/test_linalg.py
@@ -0,0 +1,44 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""st for linalg."""
+
+import pytest
+import numpy as onp
+import scipy as osp
+
+from mindspore import context, Tensor
+import mindspore.scipy as msp
+from .utils import match_array
+
+context.set_context(mode=context.PYNATIVE_MODE)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('args', [(), (1,), (7, -1), (3, 4, 5),
+                                  (onp.ones((3, 4), dtype=onp.float32), 5,
+                                   onp.random.randn(5, 2).astype(onp.float32))])
+def test_block_diag(args):
+    """
+    Feature: ALL TO ALL
+    Description: test cases for block_diag
+    Expectation: the result match scipy
+    """
+    tensor_args = tuple([Tensor(arg) for arg in args])
+    ms_res = msp.linalg.block_diag(*tensor_args)
+
+    scipy_res = osp.linalg.block_diag(*args)
+    match_array(ms_res.asnumpy(), scipy_res)
diff --git a/tests/st/scipy_st/test_line_search.py b/tests/st/scipy_st/test_line_search.py
new file mode 100644
index 00000000000..92c734edd7c
--- /dev/null
+++ b/tests/st/scipy_st/test_line_search.py
@@ -0,0 +1,142 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""st for line_search."""
+
+import pytest
+
+import numpy as onp
+from scipy.optimize.linesearch import line_search_wolfe2 as osp_line_search
+
+import mindspore.numpy as mnp
+from mindspore import context
+from mindspore.scipy.optimize.line_search import line_search as msp_line_search
+
+from .utils import match_array
+
+context.set_context(mode=context.GRAPH_MODE)
+
+
+def _scalar_func_1(np):
+    def f(x):
+        return -x - x ** 3 + x ** 4
+
+    def fprime(x):
+        return -1 - 3 * x ** 2 + 4 * x ** 3
+
+    return f, fprime
+
+
+def _scalar_func_2(np):
+    def f(x):
+        return np.exp(-4 * x) + x ** 2
+
+    def fprime(x):
+        return -4 * np.exp(-4 * x) + 2 * x
+
+    return f, fprime
+
+
+def _scalar_func_3(np):
+    def f(x):
+        return -np.sin(10 * x)
+
+    def fprime(x):
+        return -10 * np.cos(10 * x)
+
+    return f, fprime
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('maxiter, func, x, p', [(10, _scalar_func_1, 0., 1.),
+                                                 (10, _scalar_func_2, 0., 1.),
+                                                 (10, _scalar_func_3, 0., 1.)])
+def test_scalar_search(maxiter, func, x, p):
+    """
+    Feature: ALL TO ALL
+    Description: test cases for 1-d function
+    Expectation: the result match scipy
+    """
+    osp_f, osp_fp = func(onp)
+    osp_x, osp_p = onp.array(x), onp.array(p)
+    osp_res = osp_line_search(osp_f, osp_fp, osp_x, osp_p, maxiter=maxiter)
+
+    msp_f, _ = func(mnp)
+    msp_x, msp_p = mnp.array(x), mnp.array(p)
+    msp_res = msp_line_search(msp_f, msp_x, msp_p, maxiter=maxiter)
+
+    match_array(msp_res.a_k, osp_res[0], error=5)
+    match_array(msp_res.f_k, osp_res[3], error=5)
+
+
+def _line_func_1(np, *args):
+    def f(x):
+        return np.dot(x, x)
+
+    def fprime(x):
+        return 2 * x
+
+    return f, fprime
+
+
+def _line_func_2(np, *args):
+    def f(x):
+        A = args[0]
+        return np.dot(x, np.dot(A, x)) + 1
+
+    def fprime(x):
+        A = args[0]
+        return np.dot(A + A.T, x)
+
+    return f, fprime
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('maxiter, func, x, p',
+                         [(10, _line_func_1, [1.13689136, 0.09772497, 0.58295368, -0.39944903, 0.37005589],
+                           [-1.30652685, 1.65813068, -0.11816405, -0.6801782, 0.66638308]),
+                          (10, _line_func_1, [-0.52118931, -1.84306955, -0.477974, -0.47965581, 0.6203583],
+                           [0.69845715, 0.00377089, 0.93184837, 0.33996498, -0.01568211]),
+                          (10, _line_func_2, [0.15634897, 1.23029068, 1.20237985, -0.38732682, -0.30230275],
+                           [-1.04855297, -1.42001794, -1.70627019, 1.9507754, -0.50965218]),
+                          (10, _line_func_2, [0.42833187, 0.06651722, 0.3024719, -0.63432209, -0.36274117],
+                           [-0.67246045, -0.35955316, -0.81314628, -1.7262826, 0.17742614])])
+def test_line_search(maxiter, func, x, p):
+    """
+    Feature: ALL TO ALL
+    Description: test cases for n-d function
+    Expectation: the result match scipy
+    """
+    A = [[1.76405235, 0.40015721, 0.97873798, 2.2408932, 1.86755799],
+         [-0.97727788, 0.95008842, -0.15135721, -0.10321885, 0.4105985],
+         [0.14404357, 1.45427351, 0.76103773, 0.12167502, 0.44386323],
+         [0.33367433, 1.49407907, -0.20515826, 0.3130677, -0.85409574],
+         [-2.55298982, 0.6536186, 0.8644362, -0.74216502, 2.26975462]]
+
+    osp_x, osp_p, osp_A = onp.array(x), onp.array(p), onp.array(A)
+    osp_f, osp_fp = func(onp, osp_A)
+    osp_res = osp_line_search(osp_f, osp_fp, osp_x, osp_p, maxiter=maxiter)
+
+    msp_x, msp_p, msp_A = mnp.array(x), mnp.array(p), mnp.array(A)
+    msp_f, _ = func(mnp, msp_A)
+    msp_res = msp_line_search(msp_f, msp_x, msp_p, maxiter=maxiter)
+
+    match_array(msp_res.a_k, osp_res[0], error=5)
+    match_array(msp_res.f_k, osp_res[3], error=5)
diff --git a/tests/st/scipy_st/test_optimize.py b/tests/st/scipy_st/test_optimize.py
new file mode 100644
index 00000000000..c1ccd567699
--- /dev/null
+++ b/tests/st/scipy_st/test_optimize.py
@@ -0,0 +1,107 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""st for optimize."""
+
+import pytest
+import numpy as onp
+import scipy as osp
+
+import mindspore.numpy as mnp
+import mindspore.scipy as msp
+from mindspore import context
+from mindspore.common import Tensor
+
+from .utils import match_array
+
+context.set_context(mode=context.GRAPH_MODE)
+
+
+def rosenbrock(np):
+    def func(x):
+        return np.sum(100. * np.diff(x) ** 2 + (1. - x[:-1]) ** 2)
+
+    return func
+
+
+def himmelblau(np):
+    def func(p):
+        x, y = p
+        return (x ** 2 + y - 11.) ** 2 + (x + y ** 2 - 7.) ** 2
+
+    return func
+
+
+def matyas(np):
+    def func(p):
+        x, y = p
+        return 0.26 * (x ** 2 + y ** 2) - 0.48 * x * y
+
+    return func
+
+
+def eggholder(np):
+    def func(p):
+        x, y = p
+        return - (y + 47) * np.sin(np.sqrt(np.abs(x / 2. + y + 47.))) - x * np.sin(
+            np.sqrt(np.abs(x - (y + 47.))))
+
+    return func
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('dtype', [onp.float32])
+@pytest.mark.parametrize('func_x0', [(rosenbrock, onp.zeros(2)),
+                                     (himmelblau, onp.zeros(2)),
+                                     (himmelblau, onp.array([92, 0.001])),
+                                     (matyas, onp.ones(2)),
+                                     (eggholder, onp.ones(2) * 100.)])
+def test_bfgs(dtype, func_x0):
+    """
+    Feature: ALL TO ALL
+    Description: test cases for bfgs
+    Expectation: the result match scipy
+    """
+    func, x0 = func_x0
+    x0 = x0.astype(dtype)
+    x0_tensor = Tensor(x0)
+    ms_res = msp.optimize.minimize(func(mnp), x0_tensor, method='BFGS',
+                                   options=dict(maxiter=None, gtol=1e-6)).x
+    scipy_res = osp.optimize.minimize(func(onp), x0, method='BFGS').x
+    match_array(ms_res.asnumpy(), scipy_res, error=5)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('dtype', [onp.float32])
+def test_bfgs_fixes4594(dtype):
+    """
+    Feature: ALL TO ALL
+    Description: test cases for bfgs
+    Expectation: the result match scipy
+    """
+    n = 2
+    A = Tensor(onp.eye(n, dtype=dtype)) * 1e4
+
+    def func(x):
+        return mnp.mean((mnp.dot(A, x)) ** 2)
+
+    results = msp.optimize.minimize(func, Tensor(onp.ones(n, dtype=dtype)), method='BFGS',
+                                    options=dict(maxiter=None, gtol=1e-6)).x
+    onp.testing.assert_allclose(results.asnumpy(), onp.zeros(n, dtype=dtype), rtol=1e-6, atol=1e-6)
diff --git a/tests/st/scipy_st/utils.py b/tests/st/scipy_st/utils.py
new file mode 100644
index 00000000000..0c407c9302a
--- /dev/null
+++ b/tests/st/scipy_st/utils.py
@@ -0,0 +1,43 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""utility functions for mindspore.scipy st tests"""
+import numpy as onp
+from mindspore import Tensor
+import mindspore.numpy as mnp
+
+
+def to_tensor(obj, dtype=None):
+    if dtype is None:
+        res = Tensor(obj)
+        if res.dtype == mnp.float64:
+            res = res.astype(mnp.float32)
+        if res.dtype == mnp.int64:
+            res = res.astype(mnp.int32)
+    else:
+        res = Tensor(obj, dtype)
+    return res
+
+
+def match_array(actual, expected, error=0):
+    if isinstance(actual, int):
+        actual = onp.asarray(actual)
+
+    if isinstance(expected, (int, tuple)):
+        expected = onp.asarray(expected)
+
+    if error > 0:
+        onp.testing.assert_almost_equal(actual, expected, decimal=error)
+    else:
+        onp.testing.assert_equal(actual, expected)
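`match_array` treats `error` as the number of decimal places for `numpy.testing.assert_almost_equal`, and falls back to exact comparison when `error` is 0; the equivalent raw calls (illustrative only):

```python
import numpy as onp

# error > 0: agreement to `error` decimal places, e.g. match_array(..., error=5).
onp.testing.assert_almost_equal(1.0000004, 1.0, decimal=5)

# error == 0: exact comparison, with ints/tuples promoted to arrays first.
onp.testing.assert_equal(onp.asarray((1, 2, 3)), onp.asarray([1, 2, 3]))
```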