diff --git a/mindspore/ops/composite/math_ops.py b/mindspore/ops/composite/math_ops.py
index e6ec21b9bf9..04c60644ac9 100644
--- a/mindspore/ops/composite/math_ops.py
+++ b/mindspore/ops/composite/math_ops.py
@@ -336,6 +336,17 @@ def _get_batch_size(x1_shape, x2_shape):
     return x1_shape[0], x2_shape[0]
 
 
+@constexpr
+def _typecheck_input_batch_dot(x1_type, x2_type):
+    """
+    Check input tensor types to be valid and confirm they are the same type for batch dot ops.
+    """
+    const_utils.check_type_valid(x1_type, [mstype.float32], 'x1')
+    const_utils.check_type_valid(x2_type, [mstype.float32], 'x2')
+    if x1_type != x2_type:
+        raise TypeError(f'Both Inputs must be the same Type. x1 is \'{x1_type}\' and x2 is \'{x2_type}\' ')
+
+
 @constexpr
 def _check_axes_for_batch_dot(x1_shape, x2_shape, axes):
     """
@@ -419,15 +430,29 @@ def batch_dot(x1, x2, axes=None):
     Computation of batch dot product between samples in two tensors containing batch dims.
 
     Inputs:
-        - **x1** (Tensor) - First tensor in Batch Dot op with datatype float16 or float32
-        - **x2** (Tensor) - Second tensor in Batch Dot op with datatype float16 or float32. x2's datatype should
+        - **x1** (Tensor) - First tensor in Batch Dot op with datatype float32
+        - **x2** (Tensor) - Second tensor in Batch Dot op with datatype float32. x2's datatype should
           be same as x1's.
         - **axes** (Union[int, tuple(int), list(int)]) - Single value or tuple/list of length 2 with dimensions
           specified for `a` and `b` each. If single value `N` passed, automatically picks up last N dims from
           `a` input shape and last N dims from `b` input shape in order as axes for each respectively.
 
     Outputs:
-        Tensor, batch dot product of x1 and x2.
+        Tensor, batch dot product of x1 and x2. The shape of the output for input shapes (batch, d1, axes, d2) and
+        (batch, d3, axes, d4) is (batch, d1, d2, d3, d4).
+
+    .. math::
+        output = x1[batch, :] * x2[batch, :]
+
+    Raises:
+        TypeError: If type of x1 and x2 are not the same.
+        ValueError: If rank of x1 or x2 is less than 2.
+        ValueError: If batch dim used in axes.
+        TypeError: If dtype of x1 or x2 is not float32.
+        ValueError: If len(axes) is less than 2.
+        ValueError: If axes is not one of those: None, int, (int, int).
+        ValueError: If axes value is too high for dimensions of input arrays.
+        ValueError: If batch size of x1 and x2 are not the same.
 
     Supported Platforms:
         ``Ascend`` ``GPU`` ``CPU``
@@ -458,7 +483,7 @@ def batch_dot(x1, x2, axes=None):
 
     x1_batch_size, x2_batch_size = _get_batch_size(x1_shape, x2_shape)
 
-    _typecheck_input(x1_type, x2_type)
+    _typecheck_input_batch_dot(x1_type, x2_type)
     _check_batch_size(x1_batch_size, x2_batch_size)
     axes = _check_axes_for_batch_dot(x1_shape, x2_shape, axes)
 
diff --git a/tests/st/ops/cpu/test_batchdot_op.py b/tests/st/ops/cpu/test_batchdot_op.py
index e7817f2a682..83fec0305b8 100644
--- a/tests/st/ops/cpu/test_batchdot_op.py
+++ b/tests/st/ops/cpu/test_batchdot_op.py
@@ -211,17 +211,3 @@ def test_batch_dot_fp32():
     ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
     tf_result = _reference_batch_dot(x1, x2, axes)
     assert np.allclose(ms_result_np, tf_result)
-
-    # case 10
-    shape_x1 = (4, 3, 2, 1, 7, 5)
-    shape_x2 = (4, 5, 7, 1)
-    axes = -2
-    x1 = np.ones(shape=shape_x1).astype(np.float16)
-    x2 = np.ones(shape=shape_x2).astype(np.float16)
-    x1_tensor = Tensor(x1, dtype=mindspore.float16)
-    x2_tensor = Tensor(x2, dtype=mindspore.float16)
-
-    network = NetBatchDot(axes)
-    ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
-    tf_result = _reference_batch_dot(x1, x2, axes)
-    assert np.allclose(ms_result_np, tf_result)