add cpu fp64 op registrations to fix gmres bugs on the cpu backend

z00512249 2021-11-17 15:05:28 +08:00
parent 36517de69c
commit fe9442761c
15 changed files with 184 additions and 189 deletions
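For context: mindspore.scipy's GMRES is assembled from the primitive ops touched in this diff (Add, Sub, Mul, Div, RealDiv, Pow, Concat, Reshape), and on float64 inputs the CPU backend previously had no matching kernel registrations, so kernel selection failed. The registrations below add the missing Float64 variants. A minimal sketch of the call this unblocks (an illustrative setup assuming a CPU-only environment, not taken from the diff):

    import numpy as onp
    from mindspore import Tensor, context
    import mindspore.scipy as msp

    context.set_context(device_target="CPU")
    # float64 inputs used to fail on CPU: the elementwise kernels behind
    # gmres had no Float64 registration before this commit.
    a = onp.random.rand(5, 5).astype(onp.float64) + 5 * onp.eye(5)
    b = onp.random.rand(5).astype(onp.float64)
    x, _ = msp.sparse.linalg.gmres(Tensor(a), Tensor(b))
    print(x.asnumpy())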

View File

@@ -64,18 +64,6 @@ class ArithmeticCPUKernel : public CPUKernel {
std::vector<size_t> output_element_num_;
};
MS_REG_CPU_KERNEL_T(
Add, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ArithmeticCPUKernel, int32_t);
MS_REG_CPU_KERNEL_T(
Add, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticCPUKernel, float);
MS_REG_CPU_KERNEL_T(
Add, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
ArithmeticCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(
Add, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ArithmeticCPUKernel, double);
MS_REG_CPU_KERNEL_T(
Sub, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ArithmeticCPUKernel, int32_t);

View File

@@ -40,7 +40,8 @@ class ConcatCPUKernel : public CPUKernel {
int axis_{0};
};
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCPUKernel, float);
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCPUKernel, float)
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCPUKernel, double)
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCPUKernel, int8_t)
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCPUKernel, int16_t)
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCPUKernel, int32_t)

View File

@@ -58,7 +58,11 @@ MS_REG_CPU_KERNEL(Reshape, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutpu
MS_REG_CPU_KERNEL(
Reshape,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
ReshapeCPUKernel);
ReshapeCPUKernel)
MS_REG_CPU_KERNEL(
Reshape,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
ReshapeCPUKernel)
MS_REG_CPU_KERNEL(
Reshape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ReshapeCPUKernel);

View File

@@ -44,6 +44,7 @@ class TensorAddCPUKernel : public CPUKernel {
MS_REG_CPU_KERNEL_T(Add, KernelAttr(), TensorAddCPUKernel, int32_t);
MS_REG_CPU_KERNEL_T(Add, KernelAttr(), TensorAddCPUKernel, uint32_t);
MS_REG_CPU_KERNEL_T(Add, KernelAttr(), TensorAddCPUKernel, float);
MS_REG_CPU_KERNEL_T(Add, KernelAttr(), TensorAddCPUKernel, double);
} // namespace kernel
} // namespace mindspore

View File

@@ -23,6 +23,7 @@ add_op_info = CpuRegOp("Add") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
.get_op_info()
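Each dtype_format row declares one supported (input, input, output) signature for the CPU kernel, and the new F64_Default row is what lets a float64 Add match at kernel-selection time. A hedged sanity check, assuming a CPU device target (this snippet is illustrative, not part of the commit):

    import numpy as onp
    from mindspore import Tensor, context, ops

    context.set_context(device_target="CPU")
    x = Tensor(onp.array([1.0, 2.0], dtype=onp.float64))
    y = Tensor(onp.array([0.5, 0.5], dtype=onp.float64))
    # [1.5 2.5]; before this commit no Float64 CPU kernel matched.
    print(ops.Add()(x, y))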

View File

@@ -20,6 +20,7 @@ concat_op_info = CpuRegOp("Concat") \
.input(0, "x", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \

View File

@@ -21,6 +21,7 @@ div_op_info = CpuRegOp("Div") \
.output(0, "output", "required") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
.get_op_info()

View File

@@ -20,7 +20,9 @@ mul_op_info = CpuRegOp("Mul") \
.input(1, "y", "required") \
.output(0, "output", "required") \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
.get_op_info()

View File

@@ -21,6 +21,7 @@ pow_op_info = CpuRegOp("Pow") \
.output(0, "output", "required") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
.get_op_info()

View File

@@ -20,9 +20,9 @@ real_div_op_info = CpuRegOp("RealDiv") \
.input(1, "y", "required") \
.output(0, "output", "required") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
.get_op_info()

View File

@@ -21,6 +21,7 @@ sub_op_info = CpuRegOp("Sub") \
.output(0, "output", "required") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
.get_op_info()

View File

@@ -1,22 +1,122 @@
import scipy.sparse.linalg
import scipy.linalg
from mindspore import Tensor, context
from mindspore.scipy.sparse import gmres
import numpy as onp
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""st for scipy.sparse.linalg."""
import pytest
import numpy as onp
import scipy as osp
from scipy.sparse.linalg import cg as osp_cg
import mindspore.scipy as msp
from mindspore import context
from mindspore.common import Tensor
from mindspore.scipy.sparse.linalg import cg as msp_cg
from tests.st.scipy_st.utils import create_sym_pos_matrix, create_full_rank_matrix
onp.random.seed(0)
def gmres_compare_with_scipy(A, b, x):
    gmres_x, _ = gmres(Tensor(A), Tensor(b), Tensor(
    gmres_x, _ = msp.sparse.linalg.gmres(Tensor(A), Tensor(b), Tensor(
        x), tol=1e-07, atol=0, solve_method='incremental')
    scipy_x, _ = scipy.sparse.linalg.gmres(A, b, x, tol=1e-07, atol=0)
    scipy_x, _ = osp.sparse.linalg.gmres(A, b, x, tol=1e-07, atol=0)
    onp.testing.assert_almost_equal(scipy_x, gmres_x.asnumpy(), decimal=5)

def _fetch_preconditioner(preconditioner, A):
    """
    Returns one of various preconditioning matrices depending on the identifier
    'preconditioner' and the input matrix A whose inverse it supposedly
    approximates.
    """
    if preconditioner == 'identity':
        M = onp.eye(A.shape[0], dtype=A.dtype)
    elif preconditioner == 'random':
        random_matrix = create_sym_pos_matrix(A.shape, A.dtype)
        M = onp.linalg.inv(random_matrix)
    elif preconditioner == 'exact':
        M = onp.linalg.inv(A)
    else:
        M = None
    return M

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('dtype_tol', [(onp.float32, 1e-5), (onp.float64, 1e-12)])
@pytest.mark.parametrize('shape', [(4, 4), (7, 7)])
@pytest.mark.parametrize('preconditioner', [None, 'identity', 'exact', 'random'])
@pytest.mark.parametrize('maxiter', [1, 3])
def test_cg_against_scipy(dtype_tol, shape, preconditioner, maxiter):
    """
    Feature: ALL TO ALL
    Description: test cases for cg
    Expectation: the result matches scipy
    """
    dtype, tol = dtype_tol
    A = create_sym_pos_matrix(shape, dtype)
    b = onp.random.random(shape[:1]).astype(dtype)
    M = _fetch_preconditioner(preconditioner, A)
    osp_res = osp_cg(A, b, M=M, maxiter=maxiter, atol=tol, tol=tol)[0]

    A = Tensor(A)
    b = Tensor(b)
    M = Tensor(M) if M is not None else M

    # using PYNATIVE MODE
    context.set_context(mode=context.PYNATIVE_MODE)
    msp_res_dyn = msp_cg(A, b, M=M, maxiter=maxiter, atol=tol, tol=tol)[0]

    # using GRAPH MODE
    context.set_context(mode=context.GRAPH_MODE)
    msp_res_sta = msp_cg(A, b, M=M, maxiter=maxiter, atol=tol, tol=tol)[0]

    kw = {"atol": tol, "rtol": tol}
    onp.testing.assert_allclose(osp_res, msp_res_dyn.asnumpy(), **kw)
    onp.testing.assert_allclose(osp_res, msp_res_sta.asnumpy(), **kw)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('dtype', [onp.float32, onp.float64])
@pytest.mark.parametrize('shape', [(2, 2)])
def test_cg_against_numpy(dtype, shape):
    """
    Feature: ALL TO ALL
    Description: test cases for cg
    Expectation: the result matches numpy
    """
    A = create_sym_pos_matrix(shape, dtype)
    b = onp.random.random(shape[:1]).astype(dtype)
    expected = onp.linalg.solve(A, b)

    # using PYNATIVE MODE
    context.set_context(mode=context.PYNATIVE_MODE)
    actual_dyn, _ = msp_cg(Tensor(A), Tensor(b))

    # using GRAPH MODE
    context.set_context(mode=context.GRAPH_MODE)
    actual_sta, _ = msp_cg(Tensor(A), Tensor(b))

    kw = {"atol": 1e-5, "rtol": 1e-5}
    onp.testing.assert_allclose(expected, actual_dyn.asnumpy(), **kw)
    onp.testing.assert_allclose(expected, actual_sta.asnumpy(), **kw)

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@@ -87,3 +187,59 @@ def test_gmres_incremental_against_scipy_gpu_graph(n, dtype):
    A = onp.random.rand(n, n).astype(dtype)
    b = onp.random.rand(n).astype(dtype)
    gmres_compare_with_scipy(A, b, onp.zeros_like(b).astype(dtype))

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('n', [4, 5, 6])
@pytest.mark.parametrize('dtype', [onp.float64])
@pytest.mark.parametrize('preconditioner', [None, 'identity', 'exact', 'random'])
@pytest.mark.parametrize('maxiter', [1, 2])
def test_pynative_batched_gmres_against_scipy(n, dtype, preconditioner, maxiter):
    """
    Feature: ALL TO ALL
    Description: test cases for gmres
    Expectation: the result matches scipy
    """
    shape = (n, n)
    a = create_full_rank_matrix(shape, dtype)
    b = onp.random.rand(n).astype(dtype=dtype)
    M = _fetch_preconditioner(preconditioner, a)
    tensor_a = Tensor(a)
    tensor_b = Tensor(b)
    M = Tensor(M) if M is not None else M

    osp_x, _ = osp.sparse.linalg.gmres(a, b, maxiter=maxiter, atol=1e-6)
    msp_x, _ = msp.sparse.linalg.gmres(tensor_a, tensor_b, maxiter=maxiter, M=M, atol=1e-6,
                                       solve_method='batched')
    assert onp.allclose(msp_x.asnumpy(), osp_x)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('n', [5, 6])
@pytest.mark.parametrize('dtype', [onp.float64])
@pytest.mark.parametrize('preconditioner', [None, 'identity', 'exact', 'random'])
@pytest.mark.parametrize('maxiter', [1, 2])
def test_graph_batched_gmres_against_scipy(n, dtype, preconditioner, maxiter):
    """
    Feature: ALL TO ALL
    Description: test cases for gmres
    Expectation: the result matches scipy
    """
    context.set_context(mode=context.GRAPH_MODE)
    shape = (n, n)
    a = create_full_rank_matrix(shape, dtype)
    b = onp.random.rand(n).astype(dtype=dtype)
    tensor_a = Tensor(a)
    tensor_b = Tensor(b)
    M = _fetch_preconditioner(preconditioner, a)
    M = Tensor(M) if M is not None else M

    osp_x, _ = osp.sparse.linalg.gmres(a, b, maxiter=maxiter, atol=0.0)
    msp_x, _ = msp.sparse.linalg.gmres(tensor_a, tensor_b, maxiter=maxiter, M=M, atol=0.0, solve_method='batched')
    assert onp.allclose(msp_x.asnumpy(), osp_x)
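The two solve_method values exercised above correspond to different GMRES variants: 'incremental' builds up a QR decomposition of the Krylov basis iteration by iteration, while 'batched' re-solves the least-squares problem over the whole basis each iteration. That characterization mirrors the JAX-style API this module follows and is an assumption, not something stated in the diff. A small sketch comparing the two on a well-conditioned float64 system:

    import numpy as onp
    from mindspore import Tensor
    import mindspore.scipy as msp

    # Diagonally dominant matrix so both methods converge quickly (illustrative setup).
    a = onp.random.rand(6, 6).astype(onp.float64) + 6 * onp.eye(6)
    b = onp.random.rand(6).astype(onp.float64)
    x_inc, _ = msp.sparse.linalg.gmres(Tensor(a), Tensor(b), solve_method='incremental')
    x_bat, _ = msp.sparse.linalg.gmres(Tensor(a), Tensor(b), solve_method='batched')
    onp.testing.assert_allclose(x_inc.asnumpy(), x_bat.asnumpy(), atol=1e-5)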

View File

@@ -22,7 +22,7 @@ import scipy as osp
import mindspore.scipy as msp
from mindspore import context, Tensor
import mindspore.numpy as mnp
from .utils import match_array, create_full_rank_matrix, create_sym_pos_matrix
from tests.st.scipy_st.utils import match_array, create_full_rank_matrix, create_sym_pos_matrix
onp.random.seed(0)
context.set_context(mode=context.PYNATIVE_MODE)

View File

@@ -24,8 +24,7 @@ import mindspore.scipy as msp
from mindspore import context
from mindspore.common import Tensor
from mindspore.scipy.optimize.line_search import line_search as msp_line_search
from .utils import match_array
from tests.st.scipy_st.utils import match_array
def rosenbrock(np):

View File

@@ -1,161 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""st for scipy.sparse.linalg."""
import pytest
import numpy as onp
import scipy as osp
from scipy.sparse.linalg import cg as osp_cg
import mindspore.scipy as msp
from mindspore import context
from mindspore.common import Tensor
from mindspore.scipy.sparse.linalg import cg as msp_cg
from .utils import create_sym_pos_matrix, create_full_rank_matrix
onp.random.seed(0)
def _fetch_preconditioner(preconditioner, A):
    """
    Returns one of various preconditioning matrices depending on the identifier
    `preconditioner' and the input matrix A whose inverse it supposedly
    approximates.
    """
    if preconditioner == 'identity':
        M = onp.eye(A.shape[0], dtype=A.dtype)
    elif preconditioner == 'random':
        random_metrix = create_sym_pos_matrix(A.shape, A.dtype)
        M = onp.linalg.inv(random_metrix)
    elif preconditioner == 'exact':
        M = onp.linalg.inv(A)
    else:
        M = None
    return M

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('dtype_tol', [(onp.float32, 1e-5), (onp.float64, 1e-12)])
@pytest.mark.parametrize('shape', [(4, 4), (7, 7)])
@pytest.mark.parametrize('preconditioner', [None, 'identity', 'exact', 'random'])
@pytest.mark.parametrize('maxiter', [1, 3])
def test_cg_against_scipy(dtype_tol, shape, preconditioner, maxiter):
    """
    Feature: ALL TO ALL
    Description: test cases for cg
    Expectation: the result match scipy
    """
    dtype, tol = dtype_tol
    A = create_sym_pos_matrix(shape, dtype)
    b = onp.random.random(shape[:1]).astype(dtype)
    M = _fetch_preconditioner(preconditioner, A)
    osp_res = osp_cg(A, b, M=M, maxiter=maxiter, atol=tol, tol=tol)[0]

    A = Tensor(A)
    b = Tensor(b)
    M = Tensor(M) if M is not None else M

    # using PYNATIVE MODE
    context.set_context(mode=context.PYNATIVE_MODE)
    msp_res_dyn = msp_cg(A, b, M=M, maxiter=maxiter, atol=tol, tol=tol)[0]

    # using GRAPH MODE
    context.set_context(mode=context.GRAPH_MODE)
    msp_res_sta = msp_cg(A, b, M=M, maxiter=maxiter, atol=tol, tol=tol)[0]

    kw = {"atol": tol, "rtol": tol}
    onp.testing.assert_allclose(osp_res, msp_res_dyn.asnumpy(), **kw)
    onp.testing.assert_allclose(osp_res, msp_res_sta.asnumpy(), **kw)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('dtype', [onp.float32, onp.float64])
@pytest.mark.parametrize('shape', [(2, 2)])
def test_cg_against_numpy(dtype, shape):
    """
    Feature: ALL TO ALL
    Description: test cases for cg
    Expectation: the result match numpy
    """
    A = create_sym_pos_matrix(shape, dtype)
    b = onp.random.random(shape[:1]).astype(dtype)
    expected = onp.linalg.solve(A, b)

    # using PYNATIVE MODE
    context.set_context(mode=context.PYNATIVE_MODE)
    actual_dyn, _ = msp_cg(Tensor(A), Tensor(b))

    # using GRAPH MODE
    context.set_context(mode=context.GRAPH_MODE)
    actual_sta, _ = msp_cg(Tensor(A), Tensor(b))

    kw = {"atol": 1e-5, "rtol": 1e-5}
    onp.testing.assert_allclose(expected, actual_dyn.asnumpy(), **kw)
    onp.testing.assert_allclose(expected, actual_sta.asnumpy(), **kw)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('n', [4, 5, 6])
@pytest.mark.parametrize('dtype', [onp.float64])
@pytest.mark.parametrize('preconditioner', [None, 'identity', 'exact', 'random'])
@pytest.mark.parametrize('maxiter', [1, 2])
def test_pynative_batched_gmres_against_scipy(n, dtype, preconditioner, maxiter):
    """
    Feature: ALL TO ALL
    Description: test cases for gmres
    Expectation: the result match scipy
    """
    shape = (n, n)
    a = create_full_rank_matrix(shape, dtype)
    b = onp.random.rand(n).astype(dtype=dtype)
    M = _fetch_preconditioner(preconditioner, a)
    tensor_a = Tensor(a)
    tensor_b = Tensor(b)
    M = Tensor(M) if M is not None else M

    osp_x, _ = osp.sparse.linalg.gmres(a, b, maxiter=maxiter, atol=1e-6)
    msp_x, _ = msp.sparse.linalg.gmres(tensor_a, tensor_b, maxiter=maxiter, M=M, atol=1e-6,
                                       solve_method='batched')
    assert onp.allclose(msp_x.asnumpy(), osp_x)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('n', [5, 6])
@pytest.mark.parametrize('dtype', [onp.float64])
def test_graph_batched_gmres_against_scipy(n, dtype):
    """
    Feature: ALL TO ALL
    Description: test cases for gmres
    Expectation: the result match scipy
    """
    context.set_context(mode=context.GRAPH_MODE)
    shape = (n, n)
    a = create_full_rank_matrix(shape, dtype)
    b = onp.random.rand(n).astype(dtype=dtype)
    tensor_a = Tensor(a)
    tensor_b = Tensor(b)

    osp_x, _ = osp.sparse.linalg.gmres(a, b, atol=0.0)
    msp_x, _ = msp.sparse.linalg.gmres(tensor_a, tensor_b, atol=0.0, solve_method='batched')
    assert onp.allclose(msp_x.asnumpy(), osp_x)