!13989 Fix float16 support problem and add raises for Batch Dot ops

From: @anrui-wang
Reviewed-by: @wuxuejian, @c_34
Signed-off-by: @wuxuejian
Commit 06e8c677ee by mindspore-ci-bot, 2021-03-25 14:23:39 +08:00, committed by Gitee
2 changed files with 29 additions and 18 deletions


@@ -336,6 +336,17 @@ def _get_batch_size(x1_shape, x2_shape):
     return x1_shape[0], x2_shape[0]
 
 
+@constexpr
+def _typecheck_input_batch_dot(x1_type, x2_type):
+    """
+    Check that input tensor types are valid and the same for batch dot ops.
+    """
+    const_utils.check_type_valid(x1_type, [mstype.float32], 'x1')
+    const_utils.check_type_valid(x2_type, [mstype.float32], 'x2')
+    if x1_type != x2_type:
+        raise TypeError(f"Both inputs must be the same type. x1 is '{x1_type}' and x2 is '{x2_type}'.")
 @constexpr
 def _check_axes_for_batch_dot(x1_shape, x2_shape, axes):
     """
@@ -419,15 +430,29 @@ def batch_dot(x1, x2, axes=None):
     Computation of batch dot product between samples in two tensors containing batch dims.
 
     Inputs:
-        - **x1** (Tensor) - First tensor in Batch Dot op with datatype float16 or float32
-        - **x2** (Tensor) - Second tensor in Batch Dot op with datatype float16 or float32. x2's datatype should
+        - **x1** (Tensor) - First tensor in Batch Dot op with datatype float32
+        - **x2** (Tensor) - Second tensor in Batch Dot op with datatype float32. x2's datatype should
           be the same as x1's.
         - **axes** (Union[int, tuple(int), list(int)]) - Single value or tuple/list of length 2 with dimensions
           specified for `a` and `b` each. If a single value `N` is passed, the last `N` dims of the `a` input
          shape and the last `N` dims of the `b` input shape are picked up in order as axes for each respectively.
 
     Outputs:
-        Tensor, batch dot product of x1 and x2.
+        Tensor, batch dot product of x1 and x2. The shape of the output for input shapes (batch, d1, axes, d2) and
+        (batch, d3, axes, d4) is (batch, d1, d2, d3, d4).
+
+    .. math::
+        output = x1[batch, :] * x2[batch, :]
+
+    Raises:
+        TypeError: If types of x1 and x2 are not the same.
+        ValueError: If rank of x1 or x2 is less than 2.
+        ValueError: If the batch dim is used in axes.
+        ValueError: If dtype of x1 or x2 is not float32.
+        ValueError: If len(axes) is less than 2.
+        ValueError: If axes is not one of: None, int, (int, int).
+        ValueError: If the axes value is too high for the dimensions of the input arrays.
+        ValueError: If batch sizes of x1 and x2 are not the same.
 
     Supported Platforms:
         ``Ascend`` ``GPU`` ``CPU``
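
To make the per-batch math concrete: for the default axes on 3-D inputs, batch dot reduces to one independent matmul per batch element. A quick NumPy equivalence check (the default-axes convention assumed here, last dim of x1 against second-to-last dim of x2, follows the Keras-style rule this op mirrors):

    import numpy as np

    x1 = np.ones((2, 3, 4), dtype=np.float32)
    x2 = np.ones((2, 4, 5), dtype=np.float32)
    # Contract the last dim of x1 with the second-to-last dim of x2,
    # independently for each batch element.
    out = np.einsum('bij,bjk->bik', x1, x2)
    print(out.shape)  # (2, 3, 5)
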
@@ -458,7 +483,7 @@ def batch_dot(x1, x2, axes=None):
     x1_batch_size, x2_batch_size = _get_batch_size(x1_shape, x2_shape)
-    _typecheck_input(x1_type, x2_type)
+    _typecheck_input_batch_dot(x1_type, x2_type)
     _check_batch_size(x1_batch_size, x2_batch_size)
     axes = _check_axes_for_batch_dot(x1_shape, x2_shape, axes)


@@ -211,17 +211,3 @@ def test_batch_dot_fp32():
     ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
     tf_result = _reference_batch_dot(x1, x2, axes)
     assert np.allclose(ms_result_np, tf_result)
-
-    # case 10
-    shape_x1 = (4, 3, 2, 1, 7, 5)
-    shape_x2 = (4, 5, 7, 1)
-    axes = -2
-    x1 = np.ones(shape=shape_x1).astype(np.float16)
-    x2 = np.ones(shape=shape_x2).astype(np.float16)
-    x1_tensor = Tensor(x1, dtype=mindspore.float16)
-    x2_tensor = Tensor(x2, dtype=mindspore.float16)
-    network = NetBatchDot(axes)
-    ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
-    tf_result = _reference_batch_dot(x1, x2, axes)
-    assert np.allclose(ms_result_np, tf_result)
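
The deleted case exercised float16 end to end; under the new type check it would raise a TypeError instead of reaching the reference comparison. The body of the test's `_reference_batch_dot` helper is not shown in this diff; a hypothetical NumPy stand-in with the same single-int-axes semantics (the real helper presumably wraps a framework reference such as Keras batch_dot) could look like:

    import numpy as np

    def reference_batch_dot(x1, x2, axes):
        # Hypothetical stand-in: contract one axis of each per-sample
        # tensor, batching over dim 0. A single int is used as the
        # contraction axis for both inputs.
        if isinstance(axes, int):
            axes = (axes, axes)
        # Normalize against the full shapes, then shift by one because the
        # per-sample tensors no longer carry the batch dimension.
        ax1 = axes[0] % x1.ndim - 1
        ax2 = axes[1] % x2.ndim - 1
        return np.stack([np.tensordot(a, b, axes=(ax1, ax2))
                         for a, b in zip(x1, x2)])

    # For the removed case 10: shapes (4, 3, 2, 1, 7, 5) and (4, 5, 7, 1)
    # with axes=-2 contract the size-7 dims, giving (4, 3, 2, 1, 5, 5, 1).
    x1 = np.ones((4, 3, 2, 1, 7, 5), np.float32)
    x2 = np.ones((4, 5, 7, 1), np.float32)
    print(reference_batch_dot(x1, x2, -2).shape)
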