move matmul from numpy to ops

This commit is contained in:
huangmengxi 2021-03-01 17:21:05 +08:00
parent e5aedcca47
commit 3a319a97e9
5 changed files with 196 additions and 67 deletions

View File

@ -31,8 +31,8 @@ from .array_ops import ravel, expand_dims
from .utils_const import _infer_out_shape, _check_axis_valid, _get_device, \
_check_shape_aligned, _raise_type_error, _check_same_type, _check_is_float, \
_raise_value_error, _check_matmul_shapes, _promote, _check_axis_type, _canonicalize_axis, \
_max, _is_shape_empty, _check_is_int, _expanded_shape, _check_axis_in_range
_raise_value_error, _promote, _check_axis_type, _canonicalize_axis, \
_is_shape_empty, _check_is_int, _expanded_shape, _check_axis_in_range
from .utils import _is_scalar, _expand, _broadcast_to, _broadcast_to_shape, _get_size, \
_check_input_tensor
@ -1285,44 +1285,7 @@ def matmul(x1, x2, dtype=None):
[ 550. 620. 690. 760. 830.]
[ 670. 756. 842. 928. 1014.]]]
"""
# performs type promotion
dtype1 = F.dtype(x1)
dtype2 = F.dtype(x2)
dtype_out = _promote(dtype1, dtype2)
if not _check_same_type(dtype1, dtype_out):
x1 = F.cast(x1, dtype_out)
if not _check_same_type(dtype2, dtype_out):
x2 = F.cast(x2, dtype_out)
ndim1_orig, ndim2_orig = F.rank(x1), F.rank(x2)
shape1_orig, shape2_orig = F.shape(x1), F.shape(x2)
_check_matmul_shapes(shape1_orig, shape2_orig)
ndim_aligned = _max(ndim1_orig, ndim2_orig)
transpose_b = ndim2_orig == 1
shape_backbone = _infer_out_shape(
shape1_orig[:-2], shape2_orig[:-2])
# infers the shape of the output
shape_out = shape_backbone + _infer_shape_rem(shape1_orig, shape2_orig,
ndim1_orig, ndim2_orig, transpose_b)
x1 = _expand(x1, _max(ndim_aligned, 2))
x2 = _expand(x2, _max(ndim_aligned, 2))
shape1_aligned, shape2_aligned = F.shape(x1), F.shape(x2)
if ndim_aligned <= 2:
res = P.MatMul(False, transpose_b)(x1, x2)
else:
# broadcasts x1.shape[:-2] with x2.shape[:-2]
shape_aligned = shape_backbone + _infer_shape_rem(shape1_aligned, shape2_aligned,
ndim_aligned, ndim_aligned,
transpose_b)
x1 = _broadcast_to(x1, shape1_aligned[:-2], shape_aligned[:-2], ndim_aligned)
x2 = _broadcast_to(x2, shape2_aligned[:-2], shape_aligned[:-2], ndim_aligned)
res = P.BatchMatMul(False, transpose_b)(x1, x2)
if dtype is not None and not _check_same_type(dtype_out, dtype):
res = F.cast(res, dtype)
return F.reshape(res, shape_out)
return C.matmul(x1, x2, dtype=dtype)
def square(x, out=None, where=True, dtype=None):
@ -2256,20 +2219,6 @@ def _shape_reduced(shape, axes):
return tuple(shape_out)
def _infer_shape_rem(shape1, shape2, ndim1, ndim2, transpose_b):
"""Infers the shape of the last two dimensions after performing matmul."""
shape_rem = ()
if ndim1 >= 2:
shape_rem += (shape1[-2],)
if transpose_b:
if ndim2 >= 2:
shape_rem += (shape2[-2],)
else:
if ndim1 >= 1:
shape_rem += (shape2[-1],)
return shape_rem
def _reduce(a, reduce_fn, cmp_fn, axis=None, keepdims=False, initial=None, where=True):
"""Applies comparison based on cmp_fn and reduction based on reduce_fn"""
_check_input_tensor(a)

View File

@ -278,17 +278,6 @@ def _check_is_int(dtype):
return isinstance(dtype, typing.Int)
@constexpr
def _check_matmul_shapes(shape1, shape2):
"""Checks shape1 and shape2 are valid shapes to perform matmul"""
ndim1, ndim2 = len(shape1), len(shape2)
if ndim1 < 1 or ndim2 < 1:
raise ValueError('input operands must have at least 1 dimension')
if ndim2 >= 2 and shape1[-1] != shape2[-2]:
raise ValueError(f'mismatch in core dimension of input operands (size '
f'{shape1[-1]} is different from {shape2[-2]})')
@constexpr
def _check_axis_type(axis, type_int=True, type_tuple=True, type_list=True):
"""Check axis argument type."""

View File

@ -27,7 +27,7 @@ from .multitype_ops.add_impl import hyper_add
from .multitype_ops.ones_like_impl import ones_like
from .multitype_ops.zeros_like_impl import zeros_like
from .random_ops import normal, laplace, uniform, gamma, poisson, multinomial
from .math_ops import count_nonzero, tensor_dot, dot, batch_dot
from .math_ops import count_nonzero, tensor_dot, dot, batch_dot, matmul
from .array_ops import repeat_elements, sequence_mask
@ -56,4 +56,5 @@ __all__ = [
'dot',
'batch_dot',
'repeat_elements',
'sequence_mask']
'sequence_mask',
'matmul']

View File

@ -13,6 +13,8 @@
# limitations under the License.
# ============================================================================
"""math Operations."""
from itertools import zip_longest
from collections import deque
import numpy as np
from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
from mindspore.common import dtype as mstype
@ -486,3 +488,147 @@ def batch_dot(x1, x2, axes=None):
final_result = squeeze_minus_one_op(final_result)
return final_result
@constexpr
def _check_same_type(dtype1, dtype2):
return dtype1 == dtype2
@constexpr
def _max(*args):
"""Returns the maximum value."""
return max(*args)
@constexpr
def _min(*args):
"""Returns the minimum value."""
return min(*args)
@constexpr
def _infer_shape_rem(shape1, shape2, ndim1, ndim2, transpose_b):
"""Infers the shape of the last two dimensions after performing matmul."""
shape_rem = []
if ndim1 >= 2:
shape_rem.append(shape1[-2])
if transpose_b:
if ndim2 >= 2:
shape_rem.append(shape2[-2])
else:
if ndim1 >= 1:
shape_rem.append(shape2[-1])
return tuple(shape_rem)
@constexpr
def _check_matmul_shapes(shape1, shape2):
"""Checks shape1 and shape2 are valid to perform matmul, and returns output shape after broadcasting."""
ndim1, ndim2 = len(shape1), len(shape2)
if ndim1 < 1 or ndim2 < 1:
raise ValueError('input operands must have at least 1 dimension')
if ndim2 >= 2 and shape1[-1] != shape2[-2]:
raise ValueError(f'mismatch in core dimension of input operands (size '
f'{shape1[-1]} is different from {shape2[-2]})')
shape_out = deque()
for items in zip_longest(reversed(shape1[:-2]), reversed(shape2[:-2]), fillvalue=1):
max_size = max(items)
if any(item not in (1, max_size) for item in items):
raise ValueError(f'operands could not be broadcast together with shapes {shape1} {shape2}')
shape_out.appendleft(max_size)
return tuple(shape_out)
@constexpr
def _tile_size(shape, out_shape, ndim):
"""Returns tile_size such that shape*tile_size = out_shape"""
size = [1]*ndim
for idx, (i, j) in enumerate(zip(shape, out_shape)):
if i != j:
size[idx] = j
return tuple(size)
@constexpr
def _check_need_broadcast(shape1, shape2):
"""Returns True if broadcast is necessary for batchmatmul."""
return shape1[:-2] != shape2[:-2]
def _expand(x, ndim):
"""Expand x to ndim from axis, which can be 0 or -1."""
while F.rank(x) < ndim:
x = F.expand_dims(x, 0)
return x
def _broadcast_to(x, shape_cur, shape_to, ndim_to):
"""Broadcasts x from shape_cur to shape_to."""
size = _tile_size(shape_cur, shape_to, ndim_to)
return F.tile(x, size)
def matmul(x1, x2, dtype=None):
"""
Returns the matrix product of two arrays.
Note:
Numpy arguments `out`, `casting`, `order`, `subok`, `signature`, and `extobj` are
not supported.
On GPU, the supported dtypes are np.float16 and np.float32.
On CPU, the supported dtypes are np.float16 and np.float32.
Args:
x1 (Tensor): Input tensor, scalar not allowed.
x2 (Tensor): Input tensor, scalar not allowed.
dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
output Tensor.
Returns:
Tensor or scalar, the matrix product of the inputs. This is a scalar only
when both `x1`, `x2` are 1-d vectors.
Raises:
ValueError: If the last dimension of `x1` is not the same size as the
second-to-last dimension of `x2`, or if a scalar value is passed in.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> x1 = np.arange(2*3*4).reshape(2, 3, 4).astype('float32')
>>> x2 = np.arange(4*5).reshape(4, 5).astype('float32')
>>> output = np.matmul(x1, x2)
>>> print(output)
[[[ 70. 76. 82. 88. 94.]
[ 190. 212. 234. 256. 278.]
[ 310. 348. 386. 424. 462.]]
[[ 430. 484. 538. 592. 646.]
[ 550. 620. 690. 760. 830.]
[ 670. 756. 842. 928. 1014.]]]
"""
# performs type promotion
dtype1 = F.dtype(x1)
dtype2 = F.dtype(x2)
if not _check_same_type(dtype1, dtype2):
x1 = x1.astype(mstype.float32)
x2 = x2.astype(mstype.float32)
ndim1_orig, ndim2_orig = F.rank(x1), F.rank(x2)
shape1_orig, shape2_orig = F.shape(x1), F.shape(x2)
transpose_b = ndim2_orig == 1
shape_backbone = _check_matmul_shapes(shape1_orig, shape2_orig)
# infers the shape of the output
shape_out = shape_backbone + _infer_shape_rem(shape1_orig, shape2_orig,
ndim1_orig, ndim2_orig, transpose_b)
x1 = _expand(x1, 2)
x2 = _expand(x2, 2)
if F.rank(x2) == 2:
if F.rank(x1) > 2:
x1 = F.reshape(x1, (-1, shape1_orig[-1]))
res = P.MatMul(False, transpose_b)(x1, x2)
else:
# broadcasts x1.shape[:-2] with x2.shape[:-2]
ndim_aligned = _max(ndim1_orig, ndim2_orig)
x1 = _expand(x1, ndim_aligned)
x2 = _expand(x2, ndim_aligned)
shape1_aligned, shape2_aligned = F.shape(x1), F.shape(x2)
x1 = _broadcast_to(x1, shape1_aligned[:-2], shape_backbone, ndim_aligned)
x2 = _broadcast_to(x2, shape2_aligned[:-2], shape_backbone, ndim_aligned)
res = P.BatchMatMul(False, transpose_b)(x1, x2)
if dtype is not None:
res = res.astype(dtype)
return F.reshape(res, shape_out)

View File

@ -20,6 +20,7 @@ import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops.operations import _inner_ops as inner
class MatMulNet(nn.Cell):
@ -43,6 +44,15 @@ class MatMul_d(nn.Cell):
return self.matmul(x, y)
class MatMulComposite(nn.Cell):
def __init__(self):
super(MatMulComposite, self).__init__()
self.matmul = C.matmul
def construct(self, x, y):
return self.matmul(x, y)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@ -77,3 +87,37 @@ def test_matmul_float64():
output = net(Tensor(x), Tensor(y))
expect = np.matmul(x, y)
np.testing.assert_array_almost_equal(output.asnumpy(), expect)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_matmul_composite():
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
net = MatMulComposite()
scalars = [np.random.randn(1).astype(np.float32), np.random.randn(1).astype(np.float32),
np.random.randn(1, 1).astype(np.float32),
np.random.randn(1, 1, 1).astype(np.float32)]
for x in scalars:
for y in scalars:
output = net(Tensor(x), Tensor(y))
expect = np.matmul(x, y)
np.testing.assert_array_almost_equal(output.asnumpy(), expect)
broadcastables = [
np.random.randn(3).astype(np.float32), np.random.randn(3).astype(np.float32),
np.random.randn(6).astype(np.float32), np.random.randn(6, 4).astype(np.float32),
np.random.randn(5, 2).astype(np.float32), np.random.randn(2).astype(np.float32),
np.random.randn(2, 9).astype(np.float32), np.random.randn(9, 8).astype(np.float32),
np.random.randn(6).astype(np.float32), np.random.randn(2, 6, 5).astype(np.float32),
np.random.randn(9, 2, 7).astype(np.float32), np.random.randn(7).astype(np.float32),
np.random.randn(5, 2, 4).astype(np.float32), np.random.randn(6, 1, 4, 9).astype(np.float32),
np.random.randn(7, 1, 5, 3, 2).astype(np.float32), np.random.randn(8, 1, 6, 1, 2, 9).astype(np.float32)
]
for i in range(8):
x = broadcastables[2*i]
y = broadcastables[2*i + 1]
output = net(Tensor(x), Tensor(y))
expect = np.matmul(x, y)
np.testing.assert_array_almost_equal(output.asnumpy(), expect)