forked from mindspore-Ecosystem/mindspore
commit
9e27ca929a
|
@ -27,7 +27,7 @@ from .multitype_ops.add_impl import hyper_add
|
||||||
from .multitype_ops.ones_like_impl import ones_like
|
from .multitype_ops.ones_like_impl import ones_like
|
||||||
from .multitype_ops.zeros_like_impl import zeros_like
|
from .multitype_ops.zeros_like_impl import zeros_like
|
||||||
from .random_ops import normal, laplace, uniform, gamma, poisson, multinomial
|
from .random_ops import normal, laplace, uniform, gamma, poisson, multinomial
|
||||||
from .math_ops import count_nonzero, tensor_dot
|
from .math_ops import count_nonzero, tensor_dot, batch_dot
|
||||||
from .array_ops import repeat_elements, sequence_mask
|
from .array_ops import repeat_elements, sequence_mask
|
||||||
|
|
||||||
|
|
||||||
|
@ -53,5 +53,6 @@ __all__ = [
|
||||||
'clip_by_global_norm',
|
'clip_by_global_norm',
|
||||||
'count_nonzero',
|
'count_nonzero',
|
||||||
'tensor_dot',
|
'tensor_dot',
|
||||||
|
'batch_dot',
|
||||||
'repeat_elements',
|
'repeat_elements',
|
||||||
'sequence_mask']
|
'sequence_mask']
|
||||||
|
|
|
@ -312,3 +312,171 @@ def dot(x1, x2):
|
||||||
mul_result = matmul_op(x1_reshape, x2_reshape)
|
mul_result = matmul_op(x1_reshape, x2_reshape)
|
||||||
return reshape_op(mul_result, x1_shape[:-1] + x2_shape[:-2] + x2_shape[-1:])
|
return reshape_op(mul_result, x1_shape[:-1] + x2_shape[:-2] + x2_shape[-1:])
|
||||||
return matmul_op(x1, x2)
|
return matmul_op(x1, x2)
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def _get_batch_size(x1_shape, x2_shape):
|
||||||
|
"""
|
||||||
|
Get batch sizes from two inputs
|
||||||
|
"""
|
||||||
|
if len(x1_shape) < 2 or len(x2_shape) < 2:
|
||||||
|
raise ValueError("Require both inputs with rank >= 2.")
|
||||||
|
return x1_shape[0], x2_shape[0]
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def _check_axes_for_batch_dot(x1_shape, x2_shape, axes):
|
||||||
|
"""
|
||||||
|
Check whether axes are valid and cast axes from tuple to list
|
||||||
|
"""
|
||||||
|
if axes is None:
|
||||||
|
if len(x2_shape) == 2:
|
||||||
|
axes = [len(x1_shape) - 1, len(x2_shape) - 1]
|
||||||
|
else:
|
||||||
|
axes = [len(x1_shape) - 1, len(x2_shape) - 2]
|
||||||
|
|
||||||
|
if isinstance(axes, (list, tuple)):
|
||||||
|
if 0 in axes:
|
||||||
|
raise ValueError("Batch dim cannot be used as in axes.")
|
||||||
|
if len(axes) != 2:
|
||||||
|
raise ValueError("Require two axes inputs, given less")
|
||||||
|
if isinstance(axes, tuple):
|
||||||
|
axes = list(axes)
|
||||||
|
for sub_axes in axes:
|
||||||
|
if isinstance(sub_axes, (list, tuple)):
|
||||||
|
raise ValueError("Require dimension to be in any of those: None, int, (int, int).")
|
||||||
|
# Reverse if axis < 0
|
||||||
|
if axes[0] < 0:
|
||||||
|
axes[0] += len(x1_shape)
|
||||||
|
if axes[1] < 0:
|
||||||
|
axes[1] += len(x2_shape)
|
||||||
|
elif isinstance(axes, int):
|
||||||
|
if axes == 0:
|
||||||
|
raise ValueError("Batch dim cannot be used as in axes.")
|
||||||
|
if axes < 0:
|
||||||
|
axes = [axes + len(x1_shape), axes + len(x2_shape)]
|
||||||
|
elif axes > len(x1_shape) or axes > len(x2_shape):
|
||||||
|
raise ValueError(
|
||||||
|
"Axes value too high for given input arrays dimensions.")
|
||||||
|
else:
|
||||||
|
axes = [axes, axes]
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Axes type must be one of those: int, tuple(int), list(int).")
|
||||||
|
return axes
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def _calc_new_shape_batchdot(shape, axes, position=0):
|
||||||
|
"""
|
||||||
|
Calculate transpose and reshape parameters for input transformations,
|
||||||
|
'position' refers to whether tensor is first or second in the op.
|
||||||
|
"""
|
||||||
|
axis = axes[position]
|
||||||
|
contraction_axes = tuple([axis])
|
||||||
|
prod_contraction = int(np.prod([shape[i] for i in contraction_axes]))
|
||||||
|
free_axes = tuple(i for i in range(1, len(shape)) if i not in contraction_axes)
|
||||||
|
free_dims = tuple(shape[i] for i in free_axes)
|
||||||
|
prod_free = int(np.prod(free_dims))
|
||||||
|
|
||||||
|
transpose_perm = contraction_axes + free_axes if position else free_axes + contraction_axes
|
||||||
|
transpose_perm = tuple([0]) + transpose_perm
|
||||||
|
new_shape = (prod_contraction, prod_free) if position else (prod_free, prod_contraction)
|
||||||
|
new_shape = tuple([shape[0]]) + new_shape
|
||||||
|
return new_shape, transpose_perm, free_dims
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def _check_batch_size(x1_batch_size, x2_batch_size):
|
||||||
|
"""
|
||||||
|
Check whether batch size of two inputs are the same
|
||||||
|
"""
|
||||||
|
if x1_batch_size != x2_batch_size:
|
||||||
|
raise ValueError("Require both inputs with the same batch sizes.")
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def _get_output_shape(batch_size, x1_ret, x2_ret):
|
||||||
|
"""
|
||||||
|
Compute output shape for batch dot
|
||||||
|
"""
|
||||||
|
output_shape = tuple([batch_size]) + x1_ret + x2_ret
|
||||||
|
return output_shape
|
||||||
|
|
||||||
|
def batch_dot(x1, x2, axes=None):
|
||||||
|
"""
|
||||||
|
Computation of batch dot product between samples in two tensors containing batch dims.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **x1** (Tensor) - First tensor in Batch Dot op with datatype float16 or float32
|
||||||
|
- **x2** (Tensor) - Second tensor in Batch Dot op with datatype float16 or float32. x2's datatype should
|
||||||
|
be same as x1's.
|
||||||
|
- **axes** (Union[int, tuple(int), list(int)]) - Single value or tuple/list of length 2 with dimensions
|
||||||
|
specified for `a` and `b` each. If single value `N` passed, automatically picks up last N dims from
|
||||||
|
`a` input shape and last N dims from `b` input shape in order as axes for each respectively.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor, batch dot product of x1 and x2.
|
||||||
|
|
||||||
|
Supported Platforms:
|
||||||
|
``Ascend`` ``GPU`` ``CPU``
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> input_x1 = Tensor(np.ones(shape=[2, 2, 3]), mindspore.float32)
|
||||||
|
>>> input_x2 = Tensor(np.ones(shape=[2, 3, 2]), mindspore.float32)
|
||||||
|
>>> axes = (-1, -2)
|
||||||
|
>>> output = C.batch_dot(input_x1, input_x2, axes)
|
||||||
|
>>> print(output)
|
||||||
|
[[[3. 3.]
|
||||||
|
[3. 3.]]
|
||||||
|
[[3. 3.]
|
||||||
|
[3. 3.]]]
|
||||||
|
"""
|
||||||
|
|
||||||
|
transpose_op = P.Transpose()
|
||||||
|
batch_matmul_op = P.BatchMatMul()
|
||||||
|
squeeze_one_op = P.Squeeze(1)
|
||||||
|
squeeze_minus_one_op = P.Squeeze(-1)
|
||||||
|
# input validity checks
|
||||||
|
x1_shape = F.shape(x1)
|
||||||
|
x2_shape = F.shape(x2)
|
||||||
|
x1_dim_num = len(x1_shape)
|
||||||
|
x2_dim_num = len(x2_shape)
|
||||||
|
x1_type = F.dtype(x1)
|
||||||
|
x2_type = F.dtype(x2)
|
||||||
|
|
||||||
|
x1_batch_size, x2_batch_size = _get_batch_size(x1_shape, x2_shape)
|
||||||
|
|
||||||
|
_typecheck_input(x1_type, x2_type)
|
||||||
|
_check_batch_size(x1_batch_size, x2_batch_size)
|
||||||
|
axes = _check_axes_for_batch_dot(x1_shape, x2_shape, axes)
|
||||||
|
|
||||||
|
if x1_dim_num == 2:
|
||||||
|
x1 = F.expand_dims(x1, 1)
|
||||||
|
axes[0] += 1
|
||||||
|
if x2_dim_num == 2:
|
||||||
|
x2 = F.expand_dims(x2, 2)
|
||||||
|
|
||||||
|
x1_shape = F.shape(x1)
|
||||||
|
x2_shape = F.shape(x2)
|
||||||
|
|
||||||
|
x1_reshape_fwd, x1_transpose_fwd, x1_ret = _calc_new_shape_batchdot(x1_shape, axes, 0)
|
||||||
|
x2_reshape_fwd, x2_transpose_fwd, x2_ret = _calc_new_shape_batchdot(x2_shape, axes, 1)
|
||||||
|
output_shape = _get_output_shape(x1_batch_size, x1_ret, x2_ret)
|
||||||
|
|
||||||
|
x1_transposed = transpose_op(x1, x1_transpose_fwd)
|
||||||
|
x2_transposed = transpose_op(x2, x2_transpose_fwd)
|
||||||
|
x1_reshaped = F.reshape(x1_transposed, x1_reshape_fwd)
|
||||||
|
x2_reshaped = F.reshape(x2_transposed, x2_reshape_fwd)
|
||||||
|
|
||||||
|
# Batch matmal op part
|
||||||
|
mul_result = batch_matmul_op(x1_reshaped, x2_reshaped)
|
||||||
|
|
||||||
|
final_result = F.reshape(mul_result, output_shape)
|
||||||
|
|
||||||
|
# if the original dims are expanded, restore them from 3 to 2
|
||||||
|
if x1_dim_num == 2:
|
||||||
|
final_result = squeeze_one_op(final_result)
|
||||||
|
elif x2_dim_num == 2:
|
||||||
|
final_result = squeeze_minus_one_op(final_result)
|
||||||
|
|
||||||
|
return final_result
|
||||||
|
|
|
@ -0,0 +1,227 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import mindspore
|
||||||
|
from mindspore import Tensor
|
||||||
|
import mindspore.nn as nn
|
||||||
|
import mindspore.context as context
|
||||||
|
from mindspore.ops import composite as C
|
||||||
|
|
||||||
|
|
||||||
|
class NetBatchDot(nn.Cell):
|
||||||
|
def __init__(self, axes):
|
||||||
|
super(NetBatchDot, self).__init__()
|
||||||
|
self.axes = axes
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
return C.batch_dot(x, y, self.axes)
|
||||||
|
|
||||||
|
|
||||||
|
# Implementation with numpy in tensorflow
|
||||||
|
def _reference_batch_dot(x, y, axes):
|
||||||
|
if isinstance(axes, int):
|
||||||
|
axes = [axes, axes]
|
||||||
|
elif isinstance(axes, tuple):
|
||||||
|
axes = list(axes)
|
||||||
|
if axes is None:
|
||||||
|
if y.ndim == 2:
|
||||||
|
axes = [x.ndim - 1, y.ndim - 1]
|
||||||
|
else:
|
||||||
|
axes = [x.ndim - 1, y.ndim - 2]
|
||||||
|
if axes[0] < 0:
|
||||||
|
axes[0] += x.ndim
|
||||||
|
if axes[1] < 0:
|
||||||
|
axes[1] += y.ndim
|
||||||
|
result = []
|
||||||
|
axes = [axes[0] - 1, axes[1] - 1]
|
||||||
|
for xi, yi in zip(x, y):
|
||||||
|
result.append(np.tensordot(xi, yi, axes))
|
||||||
|
result = np.array(result)
|
||||||
|
if result.ndim == 1:
|
||||||
|
result = np.expand_dims(result, -1)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_batch_dot_fp32():
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||||
|
np.random.seed(12876)
|
||||||
|
|
||||||
|
# case 1
|
||||||
|
shape_x1 = (3, 12, 5, 2, 3)
|
||||||
|
shape_x2 = (3, 1, 7, 3, 2)
|
||||||
|
axes = (-1, -2)
|
||||||
|
x1 = np.ones(shape=shape_x1).astype(np.float32)
|
||||||
|
x2 = np.ones(shape=shape_x2).astype(np.float32)
|
||||||
|
x1_tensor = Tensor(x1, dtype=mindspore.float32)
|
||||||
|
x2_tensor = Tensor(x2, dtype=mindspore.float32)
|
||||||
|
|
||||||
|
network = NetBatchDot(axes)
|
||||||
|
ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
|
||||||
|
tf_result = _reference_batch_dot(x1, x2, axes)
|
||||||
|
|
||||||
|
assert np.allclose(ms_result_np, tf_result)
|
||||||
|
|
||||||
|
# case 2
|
||||||
|
shape_x1 = (4, 3, 7, 5)
|
||||||
|
shape_x2 = (4, 1, 7, 1)
|
||||||
|
axes = 2
|
||||||
|
x1 = np.random.random(shape_x1).astype(np.float32)
|
||||||
|
x2 = np.random.random(shape_x2).astype(np.float32)
|
||||||
|
x1_tensor = Tensor(x1, dtype=mindspore.float32)
|
||||||
|
x2_tensor = Tensor(x2, dtype=mindspore.float32)
|
||||||
|
|
||||||
|
network = NetBatchDot(axes)
|
||||||
|
ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
|
||||||
|
tf_result = _reference_batch_dot(x1, x2, axes)
|
||||||
|
|
||||||
|
assert np.allclose(ms_result_np, tf_result)
|
||||||
|
|
||||||
|
# case 3
|
||||||
|
shape_x1 = (18, 3, 5, 7)
|
||||||
|
shape_x2 = (18, 1, 3, 7)
|
||||||
|
axes = -1
|
||||||
|
x1 = np.random.random(shape_x1).astype(np.float32)
|
||||||
|
x2 = np.random.random(shape_x2).astype(np.float32)
|
||||||
|
x1_tensor = Tensor(x1, dtype=mindspore.float32)
|
||||||
|
x2_tensor = Tensor(x2, dtype=mindspore.float32)
|
||||||
|
|
||||||
|
network = NetBatchDot(axes)
|
||||||
|
ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
|
||||||
|
tf_result = _reference_batch_dot(x1, x2, axes)
|
||||||
|
|
||||||
|
assert np.allclose(ms_result_np, tf_result)
|
||||||
|
|
||||||
|
# case 4
|
||||||
|
shape_x1 = (2, 11, 3, 9)
|
||||||
|
shape_x2 = (2, 7, 9, 3)
|
||||||
|
axes = None
|
||||||
|
x1 = np.random.random(shape_x1).astype(np.float32)
|
||||||
|
x2 = np.random.random(shape_x2).astype(np.float32)
|
||||||
|
x1_tensor = Tensor(x1, dtype=mindspore.float32)
|
||||||
|
x2_tensor = Tensor(x2, dtype=mindspore.float32)
|
||||||
|
|
||||||
|
network = NetBatchDot(axes)
|
||||||
|
ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
|
||||||
|
tf_result = _reference_batch_dot(x1, x2, axes)
|
||||||
|
|
||||||
|
assert np.allclose(ms_result_np, tf_result)
|
||||||
|
|
||||||
|
# case 5
|
||||||
|
shape_x1 = (7, 5)
|
||||||
|
shape_x2 = (7, 5)
|
||||||
|
axes = None
|
||||||
|
x1 = np.random.random(shape_x1).astype(np.float32)
|
||||||
|
x2 = np.random.random(shape_x2).astype(np.float32)
|
||||||
|
x1_tensor = Tensor(x1, dtype=mindspore.float32)
|
||||||
|
x2_tensor = Tensor(x2, dtype=mindspore.float32)
|
||||||
|
|
||||||
|
network = NetBatchDot(axes)
|
||||||
|
ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
|
||||||
|
tf_result = _reference_batch_dot(x1, x2, axes)
|
||||||
|
|
||||||
|
assert np.allclose(ms_result_np, tf_result)
|
||||||
|
|
||||||
|
# case 6
|
||||||
|
shape_x1 = (7, 3, 5)
|
||||||
|
shape_x2 = (7, 5)
|
||||||
|
axes = None
|
||||||
|
x1 = np.random.random(shape_x1).astype(np.float32)
|
||||||
|
x2 = np.random.random(shape_x2).astype(np.float32)
|
||||||
|
x1_tensor = Tensor(x1, dtype=mindspore.float32)
|
||||||
|
x2_tensor = Tensor(x2, dtype=mindspore.float32)
|
||||||
|
|
||||||
|
network = NetBatchDot(axes)
|
||||||
|
ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
|
||||||
|
tf_result = _reference_batch_dot(x1, x2, axes)
|
||||||
|
|
||||||
|
assert np.allclose(ms_result_np, tf_result)
|
||||||
|
|
||||||
|
# case 7
|
||||||
|
shape_x1 = (7, 5)
|
||||||
|
shape_x2 = (7, 5, 3)
|
||||||
|
axes = None
|
||||||
|
x1 = np.random.random(shape_x1).astype(np.float32)
|
||||||
|
x2 = np.random.random(shape_x2).astype(np.float32)
|
||||||
|
x1_tensor = Tensor(x1, dtype=mindspore.float32)
|
||||||
|
x2_tensor = Tensor(x2, dtype=mindspore.float32)
|
||||||
|
|
||||||
|
network = NetBatchDot(axes)
|
||||||
|
ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
|
||||||
|
tf_result = _reference_batch_dot(x1, x2, axes)
|
||||||
|
|
||||||
|
assert np.allclose(ms_result_np, tf_result)
|
||||||
|
|
||||||
|
# case 8
|
||||||
|
shape_x1 = (39, 6)
|
||||||
|
shape_x2 = (39, 6)
|
||||||
|
axes = -1
|
||||||
|
x1 = np.random.random(shape_x1).astype(np.float32)
|
||||||
|
x2 = np.random.random(shape_x2).astype(np.float32)
|
||||||
|
x1_tensor = Tensor(x1, dtype=mindspore.float32)
|
||||||
|
x2_tensor = Tensor(x2, dtype=mindspore.float32)
|
||||||
|
|
||||||
|
network = NetBatchDot(axes)
|
||||||
|
ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
|
||||||
|
tf_result = _reference_batch_dot(x1, x2, axes)
|
||||||
|
|
||||||
|
assert np.allclose(ms_result_np, tf_result)
|
||||||
|
|
||||||
|
# case 9
|
||||||
|
shape_x1 = (21, 2, 3)
|
||||||
|
shape_x2 = (21, 3, 2)
|
||||||
|
axes = (-1, -2)
|
||||||
|
x1 = np.ones(shape=shape_x1).astype(np.float32)
|
||||||
|
x2 = np.ones(shape=shape_x2).astype(np.float32)
|
||||||
|
x1_tensor = Tensor(x1, dtype=mindspore.float32)
|
||||||
|
x2_tensor = Tensor(x2, dtype=mindspore.float32)
|
||||||
|
|
||||||
|
network = NetBatchDot(axes)
|
||||||
|
ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
|
||||||
|
tf_result = _reference_batch_dot(x1, x2, axes)
|
||||||
|
assert np.allclose(ms_result_np, tf_result)
|
||||||
|
|
||||||
|
# case 10
|
||||||
|
shape_x1 = (4, 3, 2, 1, 7, 5)
|
||||||
|
shape_x2 = (4, 5, 7, 1)
|
||||||
|
axes = -2
|
||||||
|
x1 = np.ones(shape=shape_x1).astype(np.float32)
|
||||||
|
x2 = np.ones(shape=shape_x2).astype(np.float32)
|
||||||
|
x1_tensor = Tensor(x1, dtype=mindspore.float32)
|
||||||
|
x2_tensor = Tensor(x2, dtype=mindspore.float32)
|
||||||
|
|
||||||
|
network = NetBatchDot(axes)
|
||||||
|
ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
|
||||||
|
tf_result = _reference_batch_dot(x1, x2, axes)
|
||||||
|
assert np.allclose(ms_result_np, tf_result)
|
||||||
|
|
||||||
|
# case 10
|
||||||
|
shape_x1 = (4, 3, 2, 1, 7, 5)
|
||||||
|
shape_x2 = (4, 5, 7, 1)
|
||||||
|
axes = -2
|
||||||
|
x1 = np.ones(shape=shape_x1).astype(np.float16)
|
||||||
|
x2 = np.ones(shape=shape_x2).astype(np.float16)
|
||||||
|
x1_tensor = Tensor(x1, dtype=mindspore.float16)
|
||||||
|
x2_tensor = Tensor(x2, dtype=mindspore.float16)
|
||||||
|
|
||||||
|
network = NetBatchDot(axes)
|
||||||
|
ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
|
||||||
|
tf_result = _reference_batch_dot(x1, x2, axes)
|
||||||
|
assert np.allclose(ms_result_np, tf_result)
|
Loading…
Reference in New Issue