forked from mindspore-Ecosystem/mindspore
Add input check for axes which is float type or out of bound
This commit is contained in:
parent
1892a629c8
commit
3fc2c16fea
|
@ -365,14 +365,16 @@ def _check_axes_for_batch_dot(x1_shape, x2_shape, axes):
|
|||
raise ValueError("Require two axes inputs, given less")
|
||||
if isinstance(axes, tuple):
|
||||
axes = list(axes)
|
||||
for sub_axes in axes:
|
||||
if isinstance(sub_axes, (list, tuple)):
|
||||
raise ValueError("Require dimension to be in any of those: None, int, (int, int).")
|
||||
validator.check_value_type('axes[0]', axes[0], [int], 'batch_dot')
|
||||
validator.check_value_type('axes[1]', axes[1], [int], 'batch_dot')
|
||||
# Reverse if axis < 0
|
||||
if axes[0] < 0:
|
||||
axes[0] += len(x1_shape)
|
||||
if axes[1] < 0:
|
||||
axes[1] += len(x2_shape)
|
||||
if axes[0] > len(x1_shape) or axes[1] > len(x2_shape):
|
||||
raise ValueError(
|
||||
"Axes value too high for given input arrays dimensions.")
|
||||
elif isinstance(axes, int):
|
||||
if axes == 0:
|
||||
raise ValueError("Batch dim cannot be used as in axes.")
|
||||
|
@ -429,6 +431,9 @@ def batch_dot(x1, x2, axes=None):
|
|||
"""
|
||||
Computation of batch dot product between samples in two tensors containing batch dims.
|
||||
|
||||
.. math::
|
||||
output = x1[batch, :] * x2[batch, :]
|
||||
|
||||
Inputs:
|
||||
- **x1** (Tensor) - First tensor in Batch Dot op with datatype float32
|
||||
- **x2** (Tensor) - Second tensor in Batch Dot op with datatype float32. x2's datatype should
|
||||
|
@ -439,13 +444,10 @@ def batch_dot(x1, x2, axes=None):
|
|||
|
||||
Outputs:
|
||||
Tensor, batch dot product of x1 and x2. The Shape of output for input shapes (batch, d1, axes, d2) and
|
||||
(batch, d3, axes, d4) is (batch, d1, d2, d3, d4)
|
||||
|
||||
.. math::
|
||||
output = x1[batch, :] * x2[batch, :]
|
||||
(batch, d3, axes, d4) is (batch, d1, d2, d3, d4)
|
||||
|
||||
Raises:
|
||||
TypeError: If shapes of x1 and x2 are not the same.
|
||||
TypeError: If type of x1 and x2 are not the same.
|
||||
ValueError: If rank of x1 or x2 less than 2.
|
||||
ValueError: If batch dim used in axes.
|
||||
ValueError: If dtype of x1 or x2 is not float32.
|
||||
|
|
Loading…
Reference in New Issue