forked from mindspore-Ecosystem/mindspore
add missing test
This commit is contained in:
parent
722eafcac6
commit
d89101b95f
|
@ -576,7 +576,7 @@ Status Concatenate(const TensorRow &input, TensorRow *output, int8_t axis, std::
|
|||
CHECK_FAIL_RETURN_UNEXPECTED(input[0]->shape().Rank() == 1, "Only 1D tensors supported");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(axis == 0 || axis == -1, "Only concatenation along the last dimension supported");
|
||||
|
||||
Tensor::HandleNeg(axis, input[0]->shape().Rank());
|
||||
axis = Tensor::HandleNeg(axis, input[0]->shape().Rank());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(axis == 0, "Only axis=0 is supported");
|
||||
|
||||
std::shared_ptr<Tensor> out;
|
||||
|
|
|
@ -166,8 +166,8 @@ class PadEnd(cde.PadEndOp):
|
|||
Args:
|
||||
pad_shape (list of `int`): list on integers representing the shape needed. Dimensions that set to `None` will
|
||||
not be padded (i.e., original dim will be used). Shorter dimensions will truncate the values.
|
||||
pad_value (str, bytes, int, float, or bool, optional): value used to pad. Default to 0 or empty string in case
|
||||
of Tensors of strings.
|
||||
pad_value (python types (str, bytes, int, float, or bool), optional): value used to pad. Default to 0 or empty
|
||||
string in case of Tensors of strings.
|
||||
Examples:
|
||||
>>> # Data before
|
||||
>>> # | col |
|
||||
|
|
|
@ -233,10 +233,10 @@ def check_mask_op(method):
|
|||
if operator is None:
|
||||
raise ValueError("operator is not provided.")
|
||||
|
||||
from .c_transforms import Relational
|
||||
if constant is None:
|
||||
raise ValueError("constant is not provided.")
|
||||
|
||||
from .c_transforms import Relational
|
||||
if not isinstance(operator, Relational):
|
||||
raise TypeError("operator is not a Relational operator enum.")
|
||||
|
||||
|
@ -270,14 +270,17 @@ def check_pad_end(method):
|
|||
raise ValueError("pad_shape is not provided.")
|
||||
|
||||
if pad_value is not None and not isinstance(pad_value, (str, float, bool, int, bytes)):
|
||||
raise TypeError("pad_value must be either a primitive python str, float, bool, bytes or int")
|
||||
raise TypeError("pad_value must be either a primitive python str, float, bool, int or bytes.")
|
||||
|
||||
if not isinstance(pad_shape, list):
|
||||
raise TypeError("pad_shape must be a list")
|
||||
|
||||
for dim in pad_shape:
|
||||
if dim is not None:
|
||||
check_pos_int64(dim)
|
||||
if isinstance(dim, int):
|
||||
check_pos_int64(dim)
|
||||
else:
|
||||
raise TypeError("a value in the list is not an integer.")
|
||||
|
||||
kwargs["pad_shape"] = pad_shape
|
||||
kwargs["pad_value"] = pad_value
|
||||
|
|
|
@ -147,6 +147,21 @@ def test_concatenate_op_wrong_axis():
|
|||
assert "only 1D concatenation supported." in repr(error_info.value)
|
||||
|
||||
|
||||
def test_concatenate_op_negative_axis():
|
||||
def gen():
|
||||
yield (np.array([5., 6., 7., 8.], dtype=np.float),)
|
||||
|
||||
prepend_tensor = np.array([1.4, 2., 3., 4., 4.5], dtype=np.float)
|
||||
append_tensor = np.array([9., 10.3, 11., 12.], dtype=np.float)
|
||||
data = ds.GeneratorDataset(gen, column_names=["col"])
|
||||
concatenate_op = data_trans.Concatenate(-1, prepend_tensor, append_tensor)
|
||||
data = data.map(input_columns=["col"], operations=concatenate_op)
|
||||
expected = np.array([1.4, 2., 3., 4., 4.5, 5., 6., 7., 8., 9., 10.3,
|
||||
11., 12.])
|
||||
for data_row in data:
|
||||
np.testing.assert_array_equal(data_row[0], expected)
|
||||
|
||||
|
||||
def test_concatenate_op_incorrect_input_dim():
|
||||
def gen():
|
||||
yield (np.array(["ss", "ad"], dtype='S'),)
|
||||
|
@ -166,10 +181,11 @@ if __name__ == "__main__":
|
|||
test_concatenate_op_all()
|
||||
test_concatenate_op_none()
|
||||
test_concatenate_op_string()
|
||||
test_concatenate_op_multi_input_string()
|
||||
test_concatenate_op_multi_input_numeric()
|
||||
test_concatenate_op_type_mismatch()
|
||||
test_concatenate_op_type_mismatch2()
|
||||
test_concatenate_op_incorrect_dim()
|
||||
test_concatenate_op_incorrect_input_dim()
|
||||
test_concatenate_op_multi_input_numeric()
|
||||
test_concatenate_op_multi_input_string()
|
||||
test_concatenate_op_negative_axis()
|
||||
test_concatenate_op_wrong_axis()
|
||||
test_concatenate_op_incorrect_input_dim()
|
||||
|
|
|
@ -22,6 +22,8 @@ import mindspore.dataset as ds
|
|||
import mindspore.dataset.transforms.c_transforms as ops
|
||||
|
||||
|
||||
# Extensive testing of PadEnd is already done in batch with Pad test cases
|
||||
|
||||
def pad_compare(array, pad_shape, pad_value, res):
|
||||
data = ds.NumpySlicesDataset([array])
|
||||
if pad_value is not None:
|
||||
|
@ -32,8 +34,6 @@ def pad_compare(array, pad_shape, pad_value, res):
|
|||
np.testing.assert_array_equal(res, d[0])
|
||||
|
||||
|
||||
# Extensive testing of PadEnd is already done in batch with Pad test cases
|
||||
|
||||
def test_pad_end_basics():
|
||||
pad_compare([1, 2], [3], -1, [1, 2, -1])
|
||||
pad_compare([1, 2, 3], [3], -1, [1, 2, 3])
|
||||
|
@ -57,6 +57,10 @@ def test_pad_end_exceptions():
|
|||
pad_compare([b"1", b"2", b"3", b"4", b"5"], [2], 1, [])
|
||||
assert "Source and pad_value tensors are not of the same type." in str(info.value)
|
||||
|
||||
with pytest.raises(TypeError) as info:
|
||||
pad_compare([3, 4, 5], ["2"], 1, [])
|
||||
assert "a value in the list is not an integer." in str(info.value)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_pad_end_basics()
|
Loading…
Reference in New Issue