Add input check for axes that are float type or out of bounds

w00535372 2021-03-27 14:26:39 +08:00
parent 1892a629c8
commit 3fc2c16fea
1 changed file with 10 additions and 8 deletions


@@ -365,14 +365,16 @@ def _check_axes_for_batch_dot(x1_shape, x2_shape, axes):
            raise ValueError("Require two axes inputs, given less")
        if isinstance(axes, tuple):
            axes = list(axes)
        for sub_axes in axes:
            if isinstance(sub_axes, (list, tuple)):
                raise ValueError("Require dimension to be in any of those: None, int, (int, int).")
        validator.check_value_type('axes[0]', axes[0], [int], 'batch_dot')
        validator.check_value_type('axes[1]', axes[1], [int], 'batch_dot')
        # Reverse if axis < 0
        if axes[0] < 0:
            axes[0] += len(x1_shape)
        if axes[1] < 0:
            axes[1] += len(x2_shape)
        if axes[0] > len(x1_shape) or axes[1] > len(x2_shape):
            raise ValueError(
                "Axes value too high for given input arrays dimensions.")
    elif isinstance(axes, int):
        if axes == 0:
            raise ValueError("Batch dim cannot be used as in axes.")
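As a rough standalone sketch of what the new checks are meant to catch, the snippet below mirrors the validation outside of MindSpore. The helper name _demo_check_axes and the plain isinstance test are illustrative only; the actual code above relies on MindSpore's validator.check_value_type.

def _demo_check_axes(x1_shape, x2_shape, axes):
    # Illustrative stand-in for the validation above, not MindSpore code.
    axes = list(axes)
    for axis in axes:
        if not isinstance(axis, int):
            # Mirrors the type check: float (or any non-int) axes are rejected.
            raise TypeError("axes must be int, got {}".format(type(axis)))
    if axes[0] < 0:
        axes[0] += len(x1_shape)   # support negative indexing
    if axes[1] < 0:
        axes[1] += len(x2_shape)
    if axes[0] > len(x1_shape) or axes[1] > len(x2_shape):
        raise ValueError("Axes value too high for given input arrays dimensions.")
    return axes

print(_demo_check_axes((2, 3, 4), (2, 4, 5), (2, 1)))    # [2, 1]
# _demo_check_axes((2, 3, 4), (2, 4, 5), (2.0, 1))       # raises TypeError
# _demo_check_axes((2, 3, 4), (2, 4, 5), (7, 1))         # raises ValueError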
@@ -429,6 +431,9 @@ def batch_dot(x1, x2, axes=None):
"""
Computation of batch dot product between samples in two tensors containing batch dims.
.. math::
output = x1[batch, :] * x2[batch, :]
Inputs:
- **x1** (Tensor) - First tensor in Batch Dot op with datatype float32
- **x2** (Tensor) - Second tensor in Batch Dot op with datatype float32. x2's datatype should
@@ -439,13 +444,10 @@ def batch_dot(x1, x2, axes=None):
    Outputs:
        Tensor, batch dot product of x1 and x2. The shape of the output for input shapes (batch, d1, axes, d2) and
        (batch, d3, axes, d4) is (batch, d1, d2, d3, d4)
    Raises:
        TypeError: If shapes of x1 and x2 are not the same.
        TypeError: If type of x1 and x2 are not the same.
        ValueError: If rank of x1 or x2 is less than 2.
        ValueError: If batch dim is used in axes.
        ValueError: If dtype of x1 or x2 is not float32.
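As a rough illustration of the documented output shape (using NumPy's einsum rather than the MindSpore API, and assuming axes=(2, 2) so that dimension 2 of each input is the contracted one):

import numpy as np

batch, d1, d2, d3, d4, a = 2, 3, 4, 5, 6, 7
x1 = np.ones((batch, d1, a, d2), dtype=np.float32)   # (batch, d1, axes, d2)
x2 = np.ones((batch, d3, a, d4), dtype=np.float32)   # (batch, d3, axes, d4)
out = np.einsum('bimj,bkml->bijkl', x1, x2)          # contract the shared 'axes' dim per batch
print(out.shape)                                     # (2, 3, 4, 5, 6) == (batch, d1, d2, d3, d4)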