!10354 Add batch dot op

From: @anrui-wang
Committed by mindspore-ci-bot on Gitee, 2020-12-30 15:46:21 +08:00
commit 9e27ca929a
3 changed files with 397 additions and 1 deletion


@@ -27,7 +27,7 @@ from .multitype_ops.add_impl import hyper_add
from .multitype_ops.ones_like_impl import ones_like
from .multitype_ops.zeros_like_impl import zeros_like
from .random_ops import normal, laplace, uniform, gamma, poisson, multinomial
-from .math_ops import count_nonzero, tensor_dot
+from .math_ops import count_nonzero, tensor_dot, batch_dot
from .array_ops import repeat_elements, sequence_mask

@@ -53,5 +53,6 @@ __all__ = [
    'clip_by_global_norm',
    'count_nonzero',
    'tensor_dot',
+   'batch_dot',
    'repeat_elements',
    'sequence_mask']


@@ -312,3 +312,171 @@ def dot(x1, x2):
        mul_result = matmul_op(x1_reshape, x2_reshape)
        return reshape_op(mul_result, x1_shape[:-1] + x2_shape[:-2] + x2_shape[-1:])
    return matmul_op(x1, x2)

@constexpr
def _get_batch_size(x1_shape, x2_shape):
    """
    Get the batch sizes of the two inputs.
    """
    if len(x1_shape) < 2 or len(x2_shape) < 2:
        raise ValueError("Require both inputs with rank >= 2.")
    return x1_shape[0], x2_shape[0]

@constexpr
def _check_axes_for_batch_dot(x1_shape, x2_shape, axes):
    """
    Check whether the axes are valid and cast axes from tuple to list.
    """
    if axes is None:
        if len(x2_shape) == 2:
            axes = [len(x1_shape) - 1, len(x2_shape) - 1]
        else:
            axes = [len(x1_shape) - 1, len(x2_shape) - 2]

    if isinstance(axes, (list, tuple)):
        if 0 in axes:
            raise ValueError("Batch dimension cannot be used in axes.")
        if len(axes) != 2:
            raise ValueError("Require exactly two axes, one for each input.")
        if isinstance(axes, tuple):
            axes = list(axes)
        for sub_axes in axes:
            if isinstance(sub_axes, (list, tuple)):
                raise ValueError("Require each axis to be one of: None, int, (int, int).")
        # Convert negative axes to their positive equivalents
        if axes[0] < 0:
            axes[0] += len(x1_shape)
        if axes[1] < 0:
            axes[1] += len(x2_shape)
    elif isinstance(axes, int):
        if axes == 0:
            raise ValueError("Batch dimension cannot be used in axes.")
        if axes < 0:
            axes = [axes + len(x1_shape), axes + len(x2_shape)]
        elif axes > len(x1_shape) or axes > len(x2_shape):
            raise ValueError(
                "Axes value too high for the given input array dimensions.")
        else:
            axes = [axes, axes]
    else:
        raise ValueError(
            "Axes type must be one of: int, tuple(int), list(int).")
    return axes
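
# A minimal sketch of how the normalization above plays out for a few
# hypothetical input shapes (values follow from the rules above):
#
#   _check_axes_for_batch_dot((2, 3, 4), (2, 4, 5), None)     -> [2, 1]
#   _check_axes_for_batch_dot((2, 3, 4), (2, 4), None)        -> [2, 1]
#   _check_axes_for_batch_dot((2, 3, 4), (2, 4, 5), (-1, -2)) -> [2, 1]
#   _check_axes_for_batch_dot((2, 3, 4), (2, 5, 4), 2)        -> [2, 2]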

@constexpr
def _calc_new_shape_batchdot(shape, axes, position=0):
    """
    Calculate transpose and reshape parameters for the input transformations;
    'position' refers to whether the tensor is first or second in the op.
    """
    axis = axes[position]
    contraction_axes = tuple([axis])
    prod_contraction = int(np.prod([shape[i] for i in contraction_axes]))
    free_axes = tuple(i for i in range(1, len(shape)) if i not in contraction_axes)
    free_dims = tuple(shape[i] for i in free_axes)
    prod_free = int(np.prod(free_dims))

    transpose_perm = contraction_axes + free_axes if position else free_axes + contraction_axes
    transpose_perm = tuple([0]) + transpose_perm
    new_shape = (prod_contraction, prod_free) if position else (prod_free, prod_contraction)
    new_shape = tuple([shape[0]]) + new_shape
    return new_shape, transpose_perm, free_dims
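
# A worked example of the transformation above, for a hypothetical first
# operand of shape (B, 5, 2, 3) with contraction axis 2 (position=0):
# free_axes = (1, 3), free_dims = (5, 3), prod_free = 15, prod_contraction = 2,
# so transpose_perm = (0, 1, 3, 2) and new_shape = (B, 15, 2): the tensor is
# transposed to (B, 5, 3, 2), then flattened to (B, 15, 2) for BatchMatMul.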

@constexpr
def _check_batch_size(x1_batch_size, x2_batch_size):
    """
    Check whether the batch sizes of the two inputs are the same.
    """
    if x1_batch_size != x2_batch_size:
        raise ValueError("Require both inputs to have the same batch size.")

@constexpr
def _get_output_shape(batch_size, x1_ret, x2_ret):
    """
    Compute the output shape for batch dot.
    """
    output_shape = tuple([batch_size]) + x1_ret + x2_ret
    return output_shape

def batch_dot(x1, x2, axes=None):
    """
    Computes the batch dot product of two tensors whose first dimension is the batch dimension.

    Inputs:
        - **x1** (Tensor) - First tensor in the batch dot op, with datatype float16 or float32.
        - **x2** (Tensor) - Second tensor in the batch dot op, with datatype float16 or float32;
          its datatype must match `x1`'s.
        - **axes** (Union[int, tuple(int), list(int)]) - Single value or tuple/list of length 2,
          specifying the contraction dimensions for `x1` and `x2` respectively. If a single int
          is passed, the same axis index is used for both inputs (counted from the end when
          negative). Default: None, which contracts the last axis of `x1` with the
          second-to-last axis of `x2` (or the last axis, if `x2` is 2-D).

    Outputs:
        Tensor, the batch dot product of `x1` and `x2`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_x1 = Tensor(np.ones(shape=[2, 2, 3]), mindspore.float32)
        >>> input_x2 = Tensor(np.ones(shape=[2, 3, 2]), mindspore.float32)
        >>> axes = (-1, -2)
        >>> output = C.batch_dot(input_x1, input_x2, axes)
        >>> print(output)
        [[[3. 3.]
          [3. 3.]]
         [[3. 3.]
          [3. 3.]]]
    """
    transpose_op = P.Transpose()
    batch_matmul_op = P.BatchMatMul()
    squeeze_one_op = P.Squeeze(1)
    squeeze_minus_one_op = P.Squeeze(-1)

    # input validity checks
    x1_shape = F.shape(x1)
    x2_shape = F.shape(x2)
    x1_dim_num = len(x1_shape)
    x2_dim_num = len(x2_shape)
    x1_type = F.dtype(x1)
    x2_type = F.dtype(x2)

    x1_batch_size, x2_batch_size = _get_batch_size(x1_shape, x2_shape)

    _typecheck_input(x1_type, x2_type)
    _check_batch_size(x1_batch_size, x2_batch_size)
    axes = _check_axes_for_batch_dot(x1_shape, x2_shape, axes)

    if x1_dim_num == 2:
        x1 = F.expand_dims(x1, 1)
        axes[0] += 1
    if x2_dim_num == 2:
        x2 = F.expand_dims(x2, 2)

    x1_shape = F.shape(x1)
    x2_shape = F.shape(x2)

    x1_reshape_fwd, x1_transpose_fwd, x1_ret = _calc_new_shape_batchdot(x1_shape, axes, 0)
    x2_reshape_fwd, x2_transpose_fwd, x2_ret = _calc_new_shape_batchdot(x2_shape, axes, 1)
    output_shape = _get_output_shape(x1_batch_size, x1_ret, x2_ret)

    x1_transposed = transpose_op(x1, x1_transpose_fwd)
    x2_transposed = transpose_op(x2, x2_transpose_fwd)
    x1_reshaped = F.reshape(x1_transposed, x1_reshape_fwd)
    x2_reshaped = F.reshape(x2_transposed, x2_reshape_fwd)

    # Batch matmul op part
    mul_result = batch_matmul_op(x1_reshaped, x2_reshaped)
    final_result = F.reshape(mul_result, output_shape)

    # if an original input was expanded from 2-D to 3-D, squeeze the extra dim back out
    if x1_dim_num == 2:
        final_result = squeeze_one_op(final_result)
    elif x2_dim_num == 2:
        final_result = squeeze_minus_one_op(final_result)

    return final_result
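
# For reference, the transpose/reshape/BatchMatMul pipeline above computes the
# same thing as a per-sample np.tensordot over the contraction axes. A minimal
# NumPy sketch of that equivalence, assuming `axes` is already normalized to
# non-negative indices as in _check_axes_for_batch_dot (the helper name
# batch_dot_numpy is hypothetical, for illustration only):
def batch_dot_numpy(x1, x2, axes):
    # the batch axis is peeled off by zip(), so each per-sample
    # contraction axis shifts down by one
    return np.array([np.tensordot(a, b, (axes[0] - 1, axes[1] - 1))
                     for a, b in zip(x1, x2)])
# e.g. batch_dot_numpy(np.ones((2, 2, 3)), np.ones((2, 3, 2)), [2, 1])
# returns a (2, 2, 2) array of 3.0, matching the docstring example above.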


@@ -0,0 +1,227 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
import mindspore
from mindspore import Tensor
import mindspore.nn as nn
import mindspore.context as context
from mindspore.ops import composite as C

class NetBatchDot(nn.Cell):
    def __init__(self, axes):
        super(NetBatchDot, self).__init__()
        self.axes = axes

    def construct(self, x, y):
        return C.batch_dot(x, y, self.axes)

# NumPy reference implementation (mirrors the TensorFlow/Keras batch_dot)
def _reference_batch_dot(x, y, axes):
    if isinstance(axes, int):
        axes = [axes, axes]
    elif isinstance(axes, tuple):
        axes = list(axes)
    if axes is None:
        if y.ndim == 2:
            axes = [x.ndim - 1, y.ndim - 1]
        else:
            axes = [x.ndim - 1, y.ndim - 2]
    if axes[0] < 0:
        axes[0] += x.ndim
    if axes[1] < 0:
        axes[1] += y.ndim
    result = []
    # the batch axis is peeled off by zip(), so shift the contraction axes down by one
    axes = [axes[0] - 1, axes[1] - 1]
    for xi, yi in zip(x, y):
        result.append(np.tensordot(xi, yi, axes))
    result = np.array(result)
    if result.ndim == 1:
        result = np.expand_dims(result, -1)
    return result
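
# For two rank-2 inputs the per-sample tensordot yields a scalar, so the
# trailing expand_dims gives shape (batch, 1) -- matching batch_dot, which
# expands 2-D inputs to 3-D and squeezes one axis back out. For example:
#   _reference_batch_dot(np.ones((7, 5)), np.ones((7, 5)), None).shape  -> (7, 1)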

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_batch_dot_fp32():
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    np.random.seed(12876)

    # case 1
    shape_x1 = (3, 12, 5, 2, 3)
    shape_x2 = (3, 1, 7, 3, 2)
    axes = (-1, -2)
    x1 = np.ones(shape=shape_x1).astype(np.float32)
    x2 = np.ones(shape=shape_x2).astype(np.float32)
    x1_tensor = Tensor(x1, dtype=mindspore.float32)
    x2_tensor = Tensor(x2, dtype=mindspore.float32)

    network = NetBatchDot(axes)
    ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
    tf_result = _reference_batch_dot(x1, x2, axes)
    assert np.allclose(ms_result_np, tf_result)

    # case 2
    shape_x1 = (4, 3, 7, 5)
    shape_x2 = (4, 1, 7, 1)
    axes = 2
    x1 = np.random.random(shape_x1).astype(np.float32)
    x2 = np.random.random(shape_x2).astype(np.float32)
    x1_tensor = Tensor(x1, dtype=mindspore.float32)
    x2_tensor = Tensor(x2, dtype=mindspore.float32)

    network = NetBatchDot(axes)
    ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
    tf_result = _reference_batch_dot(x1, x2, axes)
    assert np.allclose(ms_result_np, tf_result)

    # case 3
    shape_x1 = (18, 3, 5, 7)
    shape_x2 = (18, 1, 3, 7)
    axes = -1
    x1 = np.random.random(shape_x1).astype(np.float32)
    x2 = np.random.random(shape_x2).astype(np.float32)
    x1_tensor = Tensor(x1, dtype=mindspore.float32)
    x2_tensor = Tensor(x2, dtype=mindspore.float32)

    network = NetBatchDot(axes)
    ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
    tf_result = _reference_batch_dot(x1, x2, axes)
    assert np.allclose(ms_result_np, tf_result)

    # case 4
    shape_x1 = (2, 11, 3, 9)
    shape_x2 = (2, 7, 9, 3)
    axes = None
    x1 = np.random.random(shape_x1).astype(np.float32)
    x2 = np.random.random(shape_x2).astype(np.float32)
    x1_tensor = Tensor(x1, dtype=mindspore.float32)
    x2_tensor = Tensor(x2, dtype=mindspore.float32)

    network = NetBatchDot(axes)
    ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
    tf_result = _reference_batch_dot(x1, x2, axes)
    assert np.allclose(ms_result_np, tf_result)

    # case 5
    shape_x1 = (7, 5)
    shape_x2 = (7, 5)
    axes = None
    x1 = np.random.random(shape_x1).astype(np.float32)
    x2 = np.random.random(shape_x2).astype(np.float32)
    x1_tensor = Tensor(x1, dtype=mindspore.float32)
    x2_tensor = Tensor(x2, dtype=mindspore.float32)

    network = NetBatchDot(axes)
    ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
    tf_result = _reference_batch_dot(x1, x2, axes)
    assert np.allclose(ms_result_np, tf_result)

    # case 6
    shape_x1 = (7, 3, 5)
    shape_x2 = (7, 5)
    axes = None
    x1 = np.random.random(shape_x1).astype(np.float32)
    x2 = np.random.random(shape_x2).astype(np.float32)
    x1_tensor = Tensor(x1, dtype=mindspore.float32)
    x2_tensor = Tensor(x2, dtype=mindspore.float32)

    network = NetBatchDot(axes)
    ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
    tf_result = _reference_batch_dot(x1, x2, axes)
    assert np.allclose(ms_result_np, tf_result)

    # case 7
    shape_x1 = (7, 5)
    shape_x2 = (7, 5, 3)
    axes = None
    x1 = np.random.random(shape_x1).astype(np.float32)
    x2 = np.random.random(shape_x2).astype(np.float32)
    x1_tensor = Tensor(x1, dtype=mindspore.float32)
    x2_tensor = Tensor(x2, dtype=mindspore.float32)

    network = NetBatchDot(axes)
    ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
    tf_result = _reference_batch_dot(x1, x2, axes)
    assert np.allclose(ms_result_np, tf_result)

    # case 8
    shape_x1 = (39, 6)
    shape_x2 = (39, 6)
    axes = -1
    x1 = np.random.random(shape_x1).astype(np.float32)
    x2 = np.random.random(shape_x2).astype(np.float32)
    x1_tensor = Tensor(x1, dtype=mindspore.float32)
    x2_tensor = Tensor(x2, dtype=mindspore.float32)

    network = NetBatchDot(axes)
    ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
    tf_result = _reference_batch_dot(x1, x2, axes)
    assert np.allclose(ms_result_np, tf_result)

    # case 9
    shape_x1 = (21, 2, 3)
    shape_x2 = (21, 3, 2)
    axes = (-1, -2)
    x1 = np.ones(shape=shape_x1).astype(np.float32)
    x2 = np.ones(shape=shape_x2).astype(np.float32)
    x1_tensor = Tensor(x1, dtype=mindspore.float32)
    x2_tensor = Tensor(x2, dtype=mindspore.float32)

    network = NetBatchDot(axes)
    ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
    tf_result = _reference_batch_dot(x1, x2, axes)
    assert np.allclose(ms_result_np, tf_result)

    # case 10
    shape_x1 = (4, 3, 2, 1, 7, 5)
    shape_x2 = (4, 5, 7, 1)
    axes = -2
    x1 = np.ones(shape=shape_x1).astype(np.float32)
    x2 = np.ones(shape=shape_x2).astype(np.float32)
    x1_tensor = Tensor(x1, dtype=mindspore.float32)
    x2_tensor = Tensor(x2, dtype=mindspore.float32)

    network = NetBatchDot(axes)
    ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
    tf_result = _reference_batch_dot(x1, x2, axes)
    assert np.allclose(ms_result_np, tf_result)
    # case 11: same as case 10, but with float16 inputs
    shape_x1 = (4, 3, 2, 1, 7, 5)
    shape_x2 = (4, 5, 7, 1)
    axes = -2
    x1 = np.ones(shape=shape_x1).astype(np.float16)
    x2 = np.ones(shape=shape_x2).astype(np.float16)
    x1_tensor = Tensor(x1, dtype=mindspore.float16)
    x2_tensor = Tensor(x2, dtype=mindspore.float16)

    network = NetBatchDot(axes)
    ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
    tf_result = _reference_batch_dot(x1, x2, axes)
    assert np.allclose(ms_result_np, tf_result)