forked from mindspore-Ecosystem/mindspore

move matmul from numpy to ops

parent e5aedcca47
commit 3a319a97e9
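For orientation, a minimal usage sketch of the relocated op (not part of the patch; assumes a configured MindSpore backend and float32 inputs):

import numpy as np
from mindspore import Tensor
from mindspore.ops import composite as C

# 3-D x 2-D product: the leading batch dimension of x is kept, core dims are (3, 4) x (4, 5)
x = Tensor(np.arange(2 * 3 * 4).reshape(2, 3, 4).astype(np.float32))
y = Tensor(np.arange(4 * 5).reshape(4, 5).astype(np.float32))
out = C.matmul(x, y)
print(out.shape)  # expected (2, 3, 5)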
@@ -31,8 +31,8 @@ from .array_ops import ravel, expand_dims
 from .utils_const import _infer_out_shape, _check_axis_valid, _get_device, \
     _check_shape_aligned, _raise_type_error, _check_same_type, _check_is_float, \
-    _raise_value_error, _check_matmul_shapes, _promote, _check_axis_type, _canonicalize_axis, \
-    _max, _is_shape_empty, _check_is_int, _expanded_shape, _check_axis_in_range
+    _raise_value_error, _promote, _check_axis_type, _canonicalize_axis, \
+    _is_shape_empty, _check_is_int, _expanded_shape, _check_axis_in_range
 from .utils import _is_scalar, _expand, _broadcast_to, _broadcast_to_shape, _get_size, \
     _check_input_tensor
@@ -1285,44 +1285,7 @@ def matmul(x1, x2, dtype=None):
         [ 550. 620. 690. 760. 830.]
         [ 670. 756. 842. 928. 1014.]]]
     """
-    # performs type promotion
-    dtype1 = F.dtype(x1)
-    dtype2 = F.dtype(x2)
-    dtype_out = _promote(dtype1, dtype2)
-    if not _check_same_type(dtype1, dtype_out):
-        x1 = F.cast(x1, dtype_out)
-    if not _check_same_type(dtype2, dtype_out):
-        x2 = F.cast(x2, dtype_out)
-
-    ndim1_orig, ndim2_orig = F.rank(x1), F.rank(x2)
-    shape1_orig, shape2_orig = F.shape(x1), F.shape(x2)
-    _check_matmul_shapes(shape1_orig, shape2_orig)
-    ndim_aligned = _max(ndim1_orig, ndim2_orig)
-    transpose_b = ndim2_orig == 1
-    shape_backbone = _infer_out_shape(
-        shape1_orig[:-2], shape2_orig[:-2])
-    # infers the shape of the output
-    shape_out = shape_backbone + _infer_shape_rem(shape1_orig, shape2_orig,
-                                                  ndim1_orig, ndim2_orig, transpose_b)
-
-    x1 = _expand(x1, _max(ndim_aligned, 2))
-    x2 = _expand(x2, _max(ndim_aligned, 2))
-    shape1_aligned, shape2_aligned = F.shape(x1), F.shape(x2)
-
-    if ndim_aligned <= 2:
-        res = P.MatMul(False, transpose_b)(x1, x2)
-    else:
-        # broadcasts x1.shape[:-2] with x2.shape[:-2]
-        shape_aligned = shape_backbone + _infer_shape_rem(shape1_aligned, shape2_aligned,
-                                                          ndim_aligned, ndim_aligned,
-                                                          transpose_b)
-        x1 = _broadcast_to(x1, shape1_aligned[:-2], shape_aligned[:-2], ndim_aligned)
-        x2 = _broadcast_to(x2, shape2_aligned[:-2], shape_aligned[:-2], ndim_aligned)
-        res = P.BatchMatMul(False, transpose_b)(x1, x2)
-
-    if dtype is not None and not _check_same_type(dtype_out, dtype):
-        res = F.cast(res, dtype)
-    return F.reshape(res, shape_out)
+    return C.matmul(x1, x2, dtype=dtype)


 def square(x, out=None, where=True, dtype=None):
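Since the numpy front end now just forwards to the composite implementation, the two entry points should agree numerically. A small cross-check sketch (not from the patch; assumes a supported backend):

import numpy as onp
import mindspore.numpy as mnp
from mindspore import Tensor
from mindspore.ops import composite as C

x1 = Tensor(onp.arange(2 * 3 * 4).reshape(2, 3, 4).astype('float32'))
x2 = Tensor(onp.arange(4 * 5).reshape(4, 5).astype('float32'))
out_np = mnp.matmul(x1, x2)   # now simply returns C.matmul(x1, x2)
out_c = C.matmul(x1, x2)
onp.testing.assert_array_almost_equal(out_np.asnumpy(), out_c.asnumpy())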
@@ -2256,20 +2219,6 @@ def _shape_reduced(shape, axes):
     return tuple(shape_out)


-def _infer_shape_rem(shape1, shape2, ndim1, ndim2, transpose_b):
-    """Infers the shape of the last two dimensions after performing matmul."""
-    shape_rem = ()
-    if ndim1 >= 2:
-        shape_rem += (shape1[-2],)
-    if transpose_b:
-        if ndim2 >= 2:
-            shape_rem += (shape2[-2],)
-    else:
-        if ndim1 >= 1:
-            shape_rem += (shape2[-1],)
-    return shape_rem
-
-
 def _reduce(a, reduce_fn, cmp_fn, axis=None, keepdims=False, initial=None, where=True):
     """Applies comparison based on cmp_fn and reduction based on reduce_fn"""
     _check_input_tensor(a)
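The rule this removed helper encodes (it is re-added on the ops side below) can be checked in plain Python; infer_shape_rem here is a local stand-in, not the patched function:

def infer_shape_rem(shape1, shape2, ndim1, ndim2, transpose_b):
    # mirror of the helper above: keep the left operand's row dim and the right
    # operand's column dim; 1-D operands contribute nothing to the output shape
    shape_rem = ()
    if ndim1 >= 2:
        shape_rem += (shape1[-2],)
    if transpose_b:
        if ndim2 >= 2:
            shape_rem += (shape2[-2],)
    else:
        if ndim1 >= 1:
            shape_rem += (shape2[-1],)
    return shape_rem

assert infer_shape_rem((3, 4), (4, 5), 2, 2, False) == (3, 5)  # 2-D x 2-D
assert infer_shape_rem((3, 4), (4,), 2, 1, True) == (3,)       # 1-D right operand
assert infer_shape_rem((4,), (4, 5), 1, 2, False) == (5,)      # 1-D left operand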
@@ -278,17 +278,6 @@ def _check_is_int(dtype):
     return isinstance(dtype, typing.Int)


-@constexpr
-def _check_matmul_shapes(shape1, shape2):
-    """Checks shape1 and shape2 are valid shapes to perform matmul"""
-    ndim1, ndim2 = len(shape1), len(shape2)
-    if ndim1 < 1 or ndim2 < 1:
-        raise ValueError('input operands must have at least 1 dimension')
-    if ndim2 >= 2 and shape1[-1] != shape2[-2]:
-        raise ValueError(f'mismatch in core dimension of input operands (size '
-                         f'{shape1[-1]} is different from {shape2[-2]})')
-
-
 @constexpr
 def _check_axis_type(axis, type_int=True, type_tuple=True, type_list=True):
     """Check axis argument type."""
@@ -27,7 +27,7 @@ from .multitype_ops.add_impl import hyper_add
 from .multitype_ops.ones_like_impl import ones_like
 from .multitype_ops.zeros_like_impl import zeros_like
 from .random_ops import normal, laplace, uniform, gamma, poisson, multinomial
-from .math_ops import count_nonzero, tensor_dot, dot, batch_dot
+from .math_ops import count_nonzero, tensor_dot, dot, batch_dot, matmul
 from .array_ops import repeat_elements, sequence_mask
@@ -56,4 +56,5 @@ __all__ = [
    'dot',
    'batch_dot',
    'repeat_elements',
-   'sequence_mask']
+   'sequence_mask',
+   'matmul']
@@ -13,6 +13,8 @@
 # limitations under the License.
 # ============================================================================
 """math Operations."""
+from itertools import zip_longest
+from collections import deque
 import numpy as np
 from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
 from mindspore.common import dtype as mstype
@@ -486,3 +488,147 @@ def batch_dot(x1, x2, axes=None):
         final_result = squeeze_minus_one_op(final_result)

     return final_result
+
+
+@constexpr
+def _check_same_type(dtype1, dtype2):
+    return dtype1 == dtype2
+
+
+@constexpr
+def _max(*args):
+    """Returns the maximum value."""
+    return max(*args)
+
+
+@constexpr
+def _min(*args):
+    """Returns the minimum value."""
+    return min(*args)
+
+
+@constexpr
+def _infer_shape_rem(shape1, shape2, ndim1, ndim2, transpose_b):
+    """Infers the shape of the last two dimensions after performing matmul."""
+    shape_rem = []
+    if ndim1 >= 2:
+        shape_rem.append(shape1[-2])
+    if transpose_b:
+        if ndim2 >= 2:
+            shape_rem.append(shape2[-2])
+    else:
+        if ndim1 >= 1:
+            shape_rem.append(shape2[-1])
+    return tuple(shape_rem)
+
+
+@constexpr
+def _check_matmul_shapes(shape1, shape2):
+    """Checks shape1 and shape2 are valid to perform matmul, and returns output shape after broadcasting."""
+    ndim1, ndim2 = len(shape1), len(shape2)
+    if ndim1 < 1 or ndim2 < 1:
+        raise ValueError('input operands must have at least 1 dimension')
+    if ndim2 >= 2 and shape1[-1] != shape2[-2]:
+        raise ValueError(f'mismatch in core dimension of input operands (size '
+                         f'{shape1[-1]} is different from {shape2[-2]})')
+    shape_out = deque()
+    for items in zip_longest(reversed(shape1[:-2]), reversed(shape2[:-2]), fillvalue=1):
+        max_size = max(items)
+        if any(item not in (1, max_size) for item in items):
+            raise ValueError(f'operands could not be broadcast together with shapes {shape1} {shape2}')
+        shape_out.appendleft(max_size)
+    return tuple(shape_out)
+
+
+@constexpr
+def _tile_size(shape, out_shape, ndim):
+    """Returns tile_size such that shape*tile_size = out_shape"""
+    size = [1]*ndim
+    for idx, (i, j) in enumerate(zip(shape, out_shape)):
+        if i != j:
+            size[idx] = j
+    return tuple(size)
+
+
+@constexpr
+def _check_need_broadcast(shape1, shape2):
+    """Returns True if broadcast is necessary for batchmatmul."""
+    return shape1[:-2] != shape2[:-2]
+
+
+def _expand(x, ndim):
+    """Expand x to ndim from axis, which can be 0 or -1."""
+    while F.rank(x) < ndim:
+        x = F.expand_dims(x, 0)
+    return x
+
+
+def _broadcast_to(x, shape_cur, shape_to, ndim_to):
+    """Broadcasts x from shape_cur to shape_to."""
+    size = _tile_size(shape_cur, shape_to, ndim_to)
+    return F.tile(x, size)
+
+
+def matmul(x1, x2, dtype=None):
+    """
+    Returns the matrix product of two arrays.
+
+    Note:
+        Numpy arguments `out`, `casting`, `order`, `subok`, `signature`, and `extobj` are
+        not supported.
+        On GPU, the supported dtypes are np.float16 and np.float32.
+        On CPU, the supported dtypes are np.float16 and np.float32.
+
+    Args:
+        x1 (Tensor): Input tensor, scalar not allowed.
+        x2 (Tensor): Input tensor, scalar not allowed.
+        dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
+            output Tensor.
+
+    Returns:
+        Tensor or scalar, the matrix product of the inputs. This is a scalar only
+        when both `x1`, `x2` are 1-d vectors.
+
+    Raises:
+        ValueError: If the last dimension of `x1` is not the same size as the
+            second-to-last dimension of `x2`, or if a scalar value is passed in.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> x1 = np.arange(2*3*4).reshape(2, 3, 4).astype('float32')
+        >>> x2 = np.arange(4*5).reshape(4, 5).astype('float32')
+        >>> output = np.matmul(x1, x2)
+        >>> print(output)
+        [[[ 70. 76. 82. 88. 94.]
+        [ 190. 212. 234. 256. 278.]
+        [ 310. 348. 386. 424. 462.]]
+        [[ 430. 484. 538. 592. 646.]
+        [ 550. 620. 690. 760. 830.]
+        [ 670. 756. 842. 928. 1014.]]]
+    """
+    # performs type promotion
+    dtype1 = F.dtype(x1)
+    dtype2 = F.dtype(x2)
+    if not _check_same_type(dtype1, dtype2):
+        x1 = x1.astype(mstype.float32)
+        x2 = x2.astype(mstype.float32)
+
+    ndim1_orig, ndim2_orig = F.rank(x1), F.rank(x2)
+    shape1_orig, shape2_orig = F.shape(x1), F.shape(x2)
+    transpose_b = ndim2_orig == 1
+    shape_backbone = _check_matmul_shapes(shape1_orig, shape2_orig)
+    # infers the shape of the output
+    shape_out = shape_backbone + _infer_shape_rem(shape1_orig, shape2_orig,
+                                                  ndim1_orig, ndim2_orig, transpose_b)
+
+    x1 = _expand(x1, 2)
+    x2 = _expand(x2, 2)
+    if F.rank(x2) == 2:
+        if F.rank(x1) > 2:
+            x1 = F.reshape(x1, (-1, shape1_orig[-1]))
+        res = P.MatMul(False, transpose_b)(x1, x2)
+    else:
+        # broadcasts x1.shape[:-2] with x2.shape[:-2]
+        ndim_aligned = _max(ndim1_orig, ndim2_orig)
+        x1 = _expand(x1, ndim_aligned)
+        x2 = _expand(x2, ndim_aligned)
+        shape1_aligned, shape2_aligned = F.shape(x1), F.shape(x2)
+        x1 = _broadcast_to(x1, shape1_aligned[:-2], shape_backbone, ndim_aligned)
+        x2 = _broadcast_to(x2, shape2_aligned[:-2], shape_backbone, ndim_aligned)
+        res = P.BatchMatMul(False, transpose_b)(x1, x2)
+
+    if dtype is not None:
+        res = res.astype(dtype)
+    return F.reshape(res, shape_out)
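To see how the added helpers compose, the output shape is the broadcast of the batch ("backbone") dimensions computed by _check_matmul_shapes plus the core shape from _infer_shape_rem. A plain-Python sketch of that composition; broadcast_backbone is an illustrative stand-in, not the patched helper:

from itertools import zip_longest
import numpy as np

def broadcast_backbone(shape1, shape2):
    # same broadcasting rule as the added _check_matmul_shapes, minus the core-dim check
    out = []
    for a, b in zip_longest(reversed(shape1[:-2]), reversed(shape2[:-2]), fillvalue=1):
        m = max(a, b)
        if a not in (1, m) or b not in (1, m):
            raise ValueError('operands could not be broadcast together')
        out.append(m)
    return tuple(reversed(out))

shape1, shape2 = (2, 1, 3, 4), (1, 5, 4, 6)
backbone = broadcast_backbone(shape1, shape2)  # (2, 5)
core = (shape1[-2], shape2[-1])                # (3, 6); transpose_b is False here
out_shape = backbone + core                    # (2, 5, 3, 6)
assert np.matmul(np.ones(shape1), np.ones(shape2)).shape == out_shape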
@@ -20,6 +20,7 @@ import mindspore.context as context
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore.ops import operations as P
+from mindspore.ops import composite as C
 from mindspore.ops.operations import _inner_ops as inner


 class MatMulNet(nn.Cell):
@@ -43,6 +44,15 @@ class MatMul_d(nn.Cell):
         return self.matmul(x, y)


+class MatMulComposite(nn.Cell):
+    def __init__(self):
+        super(MatMulComposite, self).__init__()
+        self.matmul = C.matmul
+
+    def construct(self, x, y):
+        return self.matmul(x, y)
+
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
@@ -77,3 +87,37 @@ def test_matmul_float64():
     output = net(Tensor(x), Tensor(y))
     expect = np.matmul(x, y)
     np.testing.assert_array_almost_equal(output.asnumpy(), expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_matmul_composite():
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+    net = MatMulComposite()
+
+    scalars = [np.random.randn(1).astype(np.float32), np.random.randn(1).astype(np.float32),
+               np.random.randn(1, 1).astype(np.float32),
+               np.random.randn(1, 1, 1).astype(np.float32)]
+    for x in scalars:
+        for y in scalars:
+            output = net(Tensor(x), Tensor(y))
+            expect = np.matmul(x, y)
+            np.testing.assert_array_almost_equal(output.asnumpy(), expect)
+
+    broadcastables = [
+        np.random.randn(3).astype(np.float32), np.random.randn(3).astype(np.float32),
+        np.random.randn(6).astype(np.float32), np.random.randn(6, 4).astype(np.float32),
+        np.random.randn(5, 2).astype(np.float32), np.random.randn(2).astype(np.float32),
+        np.random.randn(2, 9).astype(np.float32), np.random.randn(9, 8).astype(np.float32),
+        np.random.randn(6).astype(np.float32), np.random.randn(2, 6, 5).astype(np.float32),
+        np.random.randn(9, 2, 7).astype(np.float32), np.random.randn(7).astype(np.float32),
+        np.random.randn(5, 2, 4).astype(np.float32), np.random.randn(6, 1, 4, 9).astype(np.float32),
+        np.random.randn(7, 1, 5, 3, 2).astype(np.float32), np.random.randn(8, 1, 6, 1, 2, 9).astype(np.float32)
+    ]
+    for i in range(8):
+        x = broadcastables[2*i]
+        y = broadcastables[2*i + 1]
+        output = net(Tensor(x), Tensor(y))
+        expect = np.matmul(x, y)
+        np.testing.assert_array_almost_equal(output.asnumpy(), expect)