@constexpr
def _get_batch_size(x1_shape, x2_shape):
    """
    Get batch sizes from two inputs.

    Args:
        x1_shape (tuple): Shape of the first input.
        x2_shape (tuple): Shape of the second input.

    Returns:
        tuple, the leading dimension of `x1_shape` and of `x2_shape`.

    Raises:
        ValueError: If either shape has fewer than 2 dimensions.
    """
    if len(x1_shape) < 2 or len(x2_shape) < 2:
        raise ValueError("Require both inputs with rank >= 2.")
    return x1_shape[0], x2_shape[0]


@constexpr
def _check_axes_for_batch_dot(x1_shape, x2_shape, axes):
    """
    Check whether axes are valid and normalize them to a list of two non-negative ints.

    Args:
        x1_shape (tuple): Shape of the first input.
        x2_shape (tuple): Shape of the second input.
        axes (Union[None, int, tuple, list]): Axes to contract. None selects the last
            dimension of x1 and the second to last dimension of x2 (the last one when
            x2 is 2D). A single int is used for both inputs.

    Returns:
        list, a newly created `[x1_axis, x2_axis]`; the argument is never modified.

    Raises:
        ValueError: If `axes` has an unsupported type or length, refers to the batch
            dimension, or lies outside the dimensions of the corresponding input.
    """
    x1_rank = len(x1_shape)
    x2_rank = len(x2_shape)

    if axes is None:
        axes = [x1_rank - 1, x2_rank - 1 if x2_rank == 2 else x2_rank - 2]
    elif isinstance(axes, (list, tuple)):
        if 0 in axes:
            raise ValueError("Batch dim cannot be used as in axes.")
        if len(axes) != 2:
            raise ValueError("Require two axes inputs, given less")
        for sub_axes in axes:
            if isinstance(sub_axes, (list, tuple)):
                raise ValueError("Require dimension to be in any of those: None, int, (int, int).")
        # Always copy, so that a list handed in by the caller is not changed in place.
        axes = list(axes)
    elif isinstance(axes, int):
        if axes == 0:
            raise ValueError("Batch dim cannot be used as in axes.")
        axes = [axes, axes]
    else:
        raise ValueError(
            "Axes type must be one of those: int, tuple(int), list(int).")

    # Range checks shared by every input form. Valid axes are 1 .. rank - 1:
    # index 0 is the batch dimension and index `rank` is already out of range.
    for index, rank in enumerate((x1_rank, x2_rank)):
        if axes[index] < 0:
            axes[index] += rank
        if axes[index] >= rank:
            raise ValueError(
                "Axes value too high for given input arrays dimensions.")
        if axes[index] == 0:
            # A negative axis equal to -rank also points at the batch dimension.
            raise ValueError("Batch dim cannot be used as in axes.")
        if axes[index] < 0:
            raise ValueError(
                "Axes value too low for given input arrays dimensions.")
    return axes


@constexpr
def _calc_new_shape_batchdot(shape, axes, position=0):
    """
    Calculate transpose and reshape parameters for input transformations,
    'position' refers to whether tensor is first or second in the op.

    Args:
        shape (tuple): Shape of the input, batch dimension first.
        axes (list): Normalized `[x1_axis, x2_axis]`.
        position (int): 0 for the first input, 1 for the second input.

    Returns:
        tuple of (new_shape, transpose_perm, free_dims). `new_shape` is the 3D shape
        fed to the batched matmul, `transpose_perm` moves the contracted axis last
        (first input) or right after the batch axis (second input), and `free_dims`
        are the sizes of the non-batch, non-contracted dimensions in their order.
    """
    axis = axes[position]
    contraction_axes = (axis,)
    prod_contraction = int(shape[axis])
    free_axes = tuple(i for i in range(1, len(shape)) if i != axis)
    free_dims = tuple(shape[i] for i in free_axes)
    # np.prod of an empty tuple is 1, so an input without free dims still works.
    prod_free = int(np.prod(free_dims))

    if position:
        transpose_perm = (0,) + contraction_axes + free_axes
        new_shape = (shape[0], prod_contraction, prod_free)
    else:
        transpose_perm = (0,) + free_axes + contraction_axes
        new_shape = (shape[0], prod_free, prod_contraction)
    return new_shape, transpose_perm, free_dims


@constexpr
def _check_batch_size(x1_batch_size, x2_batch_size):
    """
    Check whether batch size of two inputs are the same.

    Raises:
        ValueError: If the two batch sizes differ.
    """
    if x1_batch_size != x2_batch_size:
        raise ValueError("Require both inputs with the same batch sizes.")


@constexpr
def _get_output_shape(batch_size, x1_ret, x2_ret):
    """
    Compute output shape for batch dot: the batch size followed by the free
    dimensions of the first input and then those of the second input.
    """
    return (batch_size,) + x1_ret + x2_ret


def batch_dot(x1, x2, axes=None):
    """
    Computation of batch dot product between samples in two tensors containing batch dims.

    Inputs:
        - **x1** (Tensor) - First tensor in Batch Dot op with datatype float16 or float32
        - **x2** (Tensor) - Second tensor in Batch Dot op with datatype float16 or float32. x2's datatype should
          be same as x1's.
        - **axes** (Union[int, tuple(int), list(int)]) - Single value or tuple/list of length 2 giving the
          dimension of `x1` and the dimension of `x2` to contract. A single value is used for both inputs.
          Negative values count from the end. Dimension 0 (the batch dimension) is not allowed.
          Default: None, which selects the last dimension of `x1` and the second to last dimension of `x2`
          (the last dimension when `x2` is 2D).

    Outputs:
        Tensor, batch dot product of x1 and x2. Its shape is the batch size followed by the remaining
        dimensions of `x1` and then the remaining dimensions of `x2`. If both inputs are 2D the result
        has shape (batch_size, 1).

    Raises:
        ValueError: If an input has fewer than 2 dimensions, the batch sizes differ or `axes` is invalid.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_x1 = Tensor(np.ones(shape=[2, 2, 3]), mindspore.float32)
        >>> input_x2 = Tensor(np.ones(shape=[2, 3, 2]), mindspore.float32)
        >>> axes = (-1, -2)
        >>> output = C.batch_dot(input_x1, input_x2, axes)
        >>> print(output)
        [[[3. 3.]
          [3. 3.]]
         [[3. 3.]
          [3. 3.]]]
    """

    transpose_op = P.Transpose()
    batch_matmul_op = P.BatchMatMul()
    squeeze_one_op = P.Squeeze(1)
    squeeze_minus_one_op = P.Squeeze(-1)
    # input validity checks
    x1_shape = F.shape(x1)
    x2_shape = F.shape(x2)
    x1_dim_num = len(x1_shape)
    x2_dim_num = len(x2_shape)
    x1_type = F.dtype(x1)
    x2_type = F.dtype(x2)

    x1_batch_size, x2_batch_size = _get_batch_size(x1_shape, x2_shape)

    _typecheck_input(x1_type, x2_type)
    _check_batch_size(x1_batch_size, x2_batch_size)
    axes = _check_axes_for_batch_dot(x1_shape, x2_shape, axes)

    # 2D inputs are lifted to 3D so that every input has a free dimension:
    # x1 (B, N) -> (B, 1, N), which shifts its contracted axis by one;
    # x2 (B, N) -> (B, N, 1), which leaves its contracted axis in place.
    if x1_dim_num == 2:
        x1 = F.expand_dims(x1, 1)
        axes[0] += 1
    if x2_dim_num == 2:
        x2 = F.expand_dims(x2, 2)

    x1_shape = F.shape(x1)
    x2_shape = F.shape(x2)

    x1_reshape_fwd, x1_transpose_fwd, x1_ret = _calc_new_shape_batchdot(x1_shape, axes, 0)
    x2_reshape_fwd, x2_transpose_fwd, x2_ret = _calc_new_shape_batchdot(x2_shape, axes, 1)
    output_shape = _get_output_shape(x1_batch_size, x1_ret, x2_ret)

    x1_transposed = transpose_op(x1, x1_transpose_fwd)
    x2_transposed = transpose_op(x2, x2_transpose_fwd)
    x1_reshaped = F.reshape(x1_transposed, x1_reshape_fwd)
    x2_reshaped = F.reshape(x2_transposed, x2_reshape_fwd)

    # Batch matmul op part
    mul_result = batch_matmul_op(x1_reshaped, x2_reshaped)

    final_result = F.reshape(mul_result, output_shape)

    # if the original dims are expanded, restore them from 3 to 2
    # (when both inputs were 2D only one singleton dim is removed, giving (B, 1))
    if x1_dim_num == 2:
        final_result = squeeze_one_op(final_result)
    elif x2_dim_num == 2:
        final_result = squeeze_minus_one_op(final_result)

    return final_result
import pytest
import numpy as np

import mindspore
from mindspore import Tensor
import mindspore.nn as nn
import mindspore.context as context
from mindspore.ops import composite as C


class NetBatchDot(nn.Cell):
    """Cell wrapping C.batch_dot with the axes fixed at construction time."""

    def __init__(self, axes):
        super(NetBatchDot, self).__init__()
        self.axes = axes

    def construct(self, x, y):
        return C.batch_dot(x, y, self.axes)


# Implementation with numpy in tensorflow
def _reference_batch_dot(x, y, axes):
    """
    NumPy reference for batch dot: contract one axis of `x` with one axis of `y`
    sample by sample along the leading (batch) dimension.

    `axes` may be None, an int used for both inputs, or a pair of ints. The
    argument itself is never modified.
    """
    if axes is None:
        axes = [x.ndim - 1, y.ndim - 1 if y.ndim == 2 else y.ndim - 2]
    elif isinstance(axes, int):
        axes = [axes, axes]
    else:
        # Copy so that a list passed by the caller is not changed in place.
        axes = list(axes)
    if axes[0] < 0:
        axes[0] += x.ndim
    if axes[1] < 0:
        axes[1] += y.ndim
    # Each sample has lost the batch dimension, so its axes shift down by one.
    sample_axes = [axes[0] - 1, axes[1] - 1]
    result = np.array([np.tensordot(xi, yi, sample_axes) for xi, yi in zip(x, y)])
    if result.ndim == 1:
        # Scalar per sample: report it as a column of shape (batch, 1).
        result = np.expand_dims(result, -1)
    return result


# (x1 shape, x2 shape, axes, numpy dtype, mindspore dtype, use random data)
# Kept in the original order so the seeded random stream is consumed identically.
_BATCH_DOT_CASES = (
    ((3, 12, 5, 2, 3), (3, 1, 7, 3, 2), (-1, -2), np.float32, mindspore.float32, False),
    ((4, 3, 7, 5), (4, 1, 7, 1), 2, np.float32, mindspore.float32, True),
    ((18, 3, 5, 7), (18, 1, 3, 7), -1, np.float32, mindspore.float32, True),
    ((2, 11, 3, 9), (2, 7, 9, 3), None, np.float32, mindspore.float32, True),
    ((7, 5), (7, 5), None, np.float32, mindspore.float32, True),
    ((7, 3, 5), (7, 5), None, np.float32, mindspore.float32, True),
    ((7, 5), (7, 5, 3), None, np.float32, mindspore.float32, True),
    ((39, 6), (39, 6), -1, np.float32, mindspore.float32, True),
    ((21, 2, 3), (21, 3, 2), (-1, -2), np.float32, mindspore.float32, False),
    ((4, 3, 2, 1, 7, 5), (4, 5, 7, 1), -2, np.float32, mindspore.float32, False),
    # Same shapes as the previous case, run in half precision.
    ((4, 3, 2, 1, 7, 5), (4, 5, 7, 1), -2, np.float16, mindspore.float16, False),
)


def _make_input(shape, np_dtype, use_random):
    """Build an input array of the given shape: random values or all ones."""
    if use_random:
        return np.random.random(shape).astype(np_dtype)
    return np.ones(shape=shape).astype(np_dtype)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_batch_dot_fp32():
    """Compare C.batch_dot in graph mode on CPU against the NumPy reference."""
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    np.random.seed(12876)

    for index, case in enumerate(_BATCH_DOT_CASES, 1):
        shape_x1, shape_x2, axes, np_dtype, ms_dtype, use_random = case
        x1 = _make_input(shape_x1, np_dtype, use_random)
        x2 = _make_input(shape_x2, np_dtype, use_random)
        x1_tensor = Tensor(x1, dtype=ms_dtype)
        x2_tensor = Tensor(x2, dtype=ms_dtype)

        network = NetBatchDot(axes)
        ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
        tf_result = _reference_batch_dot(x1, x2, axes)

        # np.allclose broadcasts, so a wrong output shape could otherwise pass
        # unnoticed on the all-ones cases.
        assert ms_result_np.shape == tf_result.shape, "case %d: shape mismatch" % index
        assert np.allclose(ms_result_np, tf_result), "case %d: value mismatch" % index