forked from mindspore-Ecosystem/mindspore
!13989 Fix float16 support problem and add raises for Batch Dot ops
From: @anrui-wang Reviewed-by: @wuxuejian,@c_34 Signed-off-by: @wuxuejian
This commit is contained in: commit 06e8c677ee
@@ -336,6 +336,17 @@ def _get_batch_size(x1_shape, x2_shape):
     return x1_shape[0], x2_shape[0]
 
 
+@constexpr
+def _typecheck_input_batch_dot(x1_type, x2_type):
+    """
+    Check input tensor types to be valid and confirm they are the same type for batch dot ops.
+    """
+    const_utils.check_type_valid(x1_type, [mstype.float32], 'x1')
+    const_utils.check_type_valid(x2_type, [mstype.float32], 'x2')
+    if x1_type != x2_type:
+        raise TypeError(f'Both Inputs must be the same Type. x1 is \'{x1_type}\' and x2 is \'{x2_type}\' ')
+
+
 @constexpr
 def _check_axes_for_batch_dot(x1_shape, x2_shape, axes):
     """
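The helper added above limits batch dot inputs to float32 and requires both inputs to share a dtype. A minimal stand-alone sketch of that intent, written in plain NumPy rather than MindSpore's @constexpr machinery (sketch_typecheck_batch_dot is a hypothetical name, not part of this commit):

    import numpy as np

    def sketch_typecheck_batch_dot(x1_dtype, x2_dtype):
        """Hypothetical stand-in mirroring the two checks in _typecheck_input_batch_dot."""
        valid = (np.float32,)
        if x1_dtype not in valid or x2_dtype not in valid:
            raise TypeError(f"batch dot only accepts float32 inputs, got {x1_dtype} and {x2_dtype}")
        if x1_dtype != x2_dtype:
            raise TypeError(f"Both inputs must be the same type. x1 is '{x1_dtype}' and x2 is '{x2_dtype}'")

    sketch_typecheck_batch_dot(np.float32, np.float32)      # accepted
    try:
        sketch_typecheck_batch_dot(np.float16, np.float32)  # float16 is rejected after this change
    except TypeError as err:
        print(err)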
@@ -419,15 +430,29 @@ def batch_dot(x1, x2, axes=None):
     Computation of batch dot product between samples in two tensors containing batch dims.
 
     Inputs:
-        - **x1** (Tensor) - First tensor in Batch Dot op with datatype float16 or float32
-        - **x2** (Tensor) - Second tensor in Batch Dot op with datatype float16 or float32. x2's datatype should
+        - **x1** (Tensor) - First tensor in Batch Dot op with datatype float32
+        - **x2** (Tensor) - Second tensor in Batch Dot op with datatype float32. x2's datatype should
           be same as x1's.
         - **axes** (Union[int, tuple(int), list(int)]) - Single value or tuple/list of length 2 with dimensions
          specified for `a` and `b` each. If single value `N` passed, automatically picks up last N dims from
          `a` input shape and last N dims from `b` input shape in order as axes for each respectively.
 
     Outputs:
-        Tensor, batch dot product of x1 and x2.
+        Tensor, batch dot product of x1 and x2. The Shape of output for input shapes (batch, d1, axes, d2) and
+        (batch, d3, axes, d4) is (batch, d1, d2, d3, d4)
+
+    .. math::
+        output = x1[batch, :] * x2[batch, :]
+
+    Raises:
+        TypeError: If shapes of x1 and x2 are not the same.
+        ValueError: If rank of x1 or x2 less than 2.
+        ValueError: If batch dim used in axes.
+        ValueError: If dtype of x1 or x2 is not float32.
+        ValueError: If len(axes) less than 2.
+        ValueError: If axes is not one of those: None, int, (int, int).
+        ValueError: If axes value is too high for dimensions of input arrays.
+        ValueError: If batch size of x1 and x2 are not the same.
 
     Supported Platforms:
         ``Ascend`` ``GPU`` ``CPU``
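The output-shape note added to the docstring can be checked with a small NumPy sketch. This assumes keras-style batch-dot semantics and is not code from this repository; reference_batch_dot below is a hypothetical stand-in for the helper used by the tests:

    import numpy as np

    def reference_batch_dot(x1, x2, axes):
        """Contract the given (non-batch) axes per sample, then stack over the batch axis."""
        ax1, ax2 = axes
        per_sample = [np.tensordot(a, b, axes=(ax1 - 1, ax2 - 1)) for a, b in zip(x1, x2)]
        return np.stack(per_sample)

    batch, d1, d2, d3, d4, k = 4, 2, 3, 5, 6, 7
    x1 = np.ones((batch, d1, k, d2), dtype=np.float32)   # contracted axis of x1 is 2
    x2 = np.ones((batch, d3, k, d4), dtype=np.float32)   # contracted axis of x2 is 2
    # Documented rule: (batch, d1, axes, d2) x (batch, d3, axes, d4) -> (batch, d1, d2, d3, d4)
    print(reference_batch_dot(x1, x2, (2, 2)).shape)     # (4, 2, 3, 5, 6)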
@@ -458,7 +483,7 @@ def batch_dot(x1, x2, axes=None):
 
     x1_batch_size, x2_batch_size = _get_batch_size(x1_shape, x2_shape)
 
-    _typecheck_input(x1_type, x2_type)
+    _typecheck_input_batch_dot(x1_type, x2_type)
     _check_batch_size(x1_batch_size, x2_batch_size)
     axes = _check_axes_for_batch_dot(x1_shape, x2_shape, axes)
 
@@ -211,17 +211,3 @@ def test_batch_dot_fp32():
     ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
     tf_result = _reference_batch_dot(x1, x2, axes)
     assert np.allclose(ms_result_np, tf_result)
-
-    # case 10
-    shape_x1 = (4, 3, 2, 1, 7, 5)
-    shape_x2 = (4, 5, 7, 1)
-    axes = -2
-    x1 = np.ones(shape=shape_x1).astype(np.float16)
-    x2 = np.ones(shape=shape_x2).astype(np.float16)
-    x1_tensor = Tensor(x1, dtype=mindspore.float16)
-    x2_tensor = Tensor(x2, dtype=mindspore.float16)
-
-    network = NetBatchDot(axes)
-    ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
-    tf_result = _reference_batch_dot(x1, x2, axes)
-    assert np.allclose(ms_result_np, tf_result)
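With the float16 case deleted from test_batch_dot_fp32 above, a natural follow-up test (hypothetical, not part of this commit) would assert that the same float16 inputs are now rejected. It assumes batch_dot is reachable as mindspore.ops.composite.batch_dot, the module this commit modifies, and that graph compilation may wrap the TypeError:

    import numpy as np
    import pytest
    import mindspore
    import mindspore.ops.composite as C
    from mindspore import Tensor

    def test_batch_dot_fp16_rejected():
        # Same inputs as the removed case 10: float16 tensors with axes=-2.
        x1 = Tensor(np.ones((4, 3, 2, 1, 7, 5)).astype(np.float16))
        x2 = Tensor(np.ones((4, 5, 7, 1)).astype(np.float16))
        # The new @constexpr check raises TypeError; graph compilation may wrap the error,
        # so RuntimeError is accepted as well.
        with pytest.raises((TypeError, RuntimeError)):
            C.batch_dot(x1, x2, -2)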