forked from mindspore-Ecosystem/mindspore
Cleanup work for Concate, Mask, Slice, PadEnd and TruncatePair
This commit is contained in:
parent
bc4b1c2460
commit
674415f7be
|
@ -403,7 +403,7 @@ def check_to_number(method):
|
|||
if not isinstance(data_type, typing.Type):
|
||||
raise TypeError("data_type is not a MindSpore data type.")
|
||||
|
||||
if not data_type in mstype.number_type:
|
||||
if data_type not in mstype.number_type:
|
||||
raise TypeError("data_type is not numeric data type.")
|
||||
|
||||
kwargs["data_type"] = data_type
|
||||
|
|
|
@ -79,12 +79,13 @@ class Slice(cde.SliceOp):
|
|||
(Currently only rank 1 Tensors are supported)
|
||||
|
||||
Args:
|
||||
*slices: Maximum n number of objects to slice a tensor of rank n.
|
||||
One object in slices can be one of:
|
||||
*slices(Variable length argument list): Maximum `n` number of arguments to slice a tensor of rank `n`.
|
||||
One object in slices can be one of:
|
||||
1. int: slice this index only. Negative index is supported.
|
||||
2. slice object: slice the generated indices from the slice object. Similar to `start:stop:step`.
|
||||
3. None: slice the whole dimension. Similar to `:` in python indexing.
|
||||
4. Ellipses ...: slice all dimensions between the two slices.
|
||||
|
||||
Examples:
|
||||
>>> # Data before
|
||||
>>> # | col |
|
||||
|
@ -134,11 +135,13 @@ class Mask(cde.MaskOp):
|
|||
"""
|
||||
Mask content of the input tensor with the given predicate.
|
||||
Any element of the tensor that matches the predicate will be evaluated to True, otherwise False.
|
||||
|
||||
Args:
|
||||
operator (Relational): One of the relational operator EQ, NE LT, GT, LE or GE
|
||||
constant (python types (str, int, float, or bool): constant to be compared to.
|
||||
Constant will be casted to the type of the input tensor
|
||||
dtype (optional, mindspore.dtype): type of the generated mask. Default to bool
|
||||
|
||||
Examples:
|
||||
>>> # Data before
|
||||
>>> # | col1 |
|
||||
|
@ -163,11 +166,13 @@ class Mask(cde.MaskOp):
|
|||
class PadEnd(cde.PadEndOp):
|
||||
"""
|
||||
Pad input tensor according to `pad_shape`, need to have same rank.
|
||||
|
||||
Args:
|
||||
pad_shape (list of `int`): list on integers representing the shape needed. Dimensions that set to `None` will
|
||||
not be padded (i.e., original dim will be used). Shorter dimensions will truncate the values.
|
||||
pad_value (python types (str, bytes, int, float, or bool), optional): value used to pad. Default to 0 or empty
|
||||
string in case of Tensors of strings.
|
||||
|
||||
Examples:
|
||||
>>> # Data before
|
||||
>>> # | col |
|
||||
|
@ -201,21 +206,25 @@ class Concatenate(cde.ConcatenateOp):
|
|||
|
||||
@check_concat_type
|
||||
def __init__(self, axis=0, prepend=None, append=None):
|
||||
# add some validations here later
|
||||
if prepend is not None:
|
||||
prepend = cde.Tensor(np.array(prepend))
|
||||
if append is not None:
|
||||
append = cde.Tensor(np.array(append))
|
||||
super().__init__(axis, prepend, append)
|
||||
|
||||
|
||||
class Duplicate(cde.DuplicateOp):
|
||||
"""
|
||||
Duplicate the input tensor to a new output tensor. The input tensor is carried over to the output list.
|
||||
Examples:
|
||||
|
||||
Examples:
|
||||
>>> # Data before
|
||||
>>> # | x |
|
||||
>>> # +---------+
|
||||
>>> # | [1,2,3] |
|
||||
>>> # +---------+
|
||||
>>> data = data.map(input_columns=["x"], operations=Duplicate(),
|
||||
>>> output_columns=["x", "y"], output_order=["x", "y"])
|
||||
>>> output_columns=["x", "y"], columns_order=["x", "y"])
|
||||
>>> # Data after
|
||||
>>> # | x | y |
|
||||
>>> # +---------+---------+
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
from functools import wraps
|
||||
import numpy as np
|
||||
|
||||
import mindspore._c_dataengine as cde
|
||||
from mindspore._c_expression import typing
|
||||
|
||||
# POS_INT_MIN is used to limit values from starting from 0
|
||||
|
@ -243,12 +242,13 @@ def check_mask_op(method):
|
|||
if not isinstance(constant, (str, float, bool, int, bytes)):
|
||||
raise TypeError("constant must be either a primitive python str, float, bool, bytes or int")
|
||||
|
||||
if not isinstance(dtype, typing.Type):
|
||||
raise TypeError("dtype is not a MindSpore data type.")
|
||||
if dtype is not None:
|
||||
if not isinstance(dtype, typing.Type):
|
||||
raise TypeError("dtype is not a MindSpore data type.")
|
||||
kwargs["dtype"] = dtype
|
||||
|
||||
kwargs["operator"] = operator
|
||||
kwargs["constant"] = constant
|
||||
kwargs["dtype"] = dtype
|
||||
|
||||
return method(self, **kwargs)
|
||||
|
||||
|
@ -269,8 +269,10 @@ def check_pad_end(method):
|
|||
if pad_shape is None:
|
||||
raise ValueError("pad_shape is not provided.")
|
||||
|
||||
if pad_value is not None and not isinstance(pad_value, (str, float, bool, int, bytes)):
|
||||
raise TypeError("pad_value must be either a primitive python str, float, bool, int or bytes.")
|
||||
if pad_value is not None:
|
||||
if not isinstance(pad_value, (str, float, bool, int, bytes)):
|
||||
raise TypeError("pad_value must be either a primitive python str, float, bool, int or bytes")
|
||||
kwargs["pad_value"] = pad_value
|
||||
|
||||
if not isinstance(pad_shape, list):
|
||||
raise TypeError("pad_shape must be a list")
|
||||
|
@ -283,7 +285,6 @@ def check_pad_end(method):
|
|||
raise TypeError("a value in the list is not an integer.")
|
||||
|
||||
kwargs["pad_shape"] = pad_shape
|
||||
kwargs["pad_value"] = pad_value
|
||||
|
||||
return method(self, **kwargs)
|
||||
|
||||
|
@ -303,30 +304,22 @@ def check_concat_type(method):
|
|||
if "axis" in kwargs:
|
||||
axis = kwargs.get("axis")
|
||||
|
||||
if not isinstance(axis, (type(None), int)):
|
||||
raise TypeError("axis type is not valid, must be None or an integer.")
|
||||
if axis is not None:
|
||||
if not isinstance(axis, int):
|
||||
raise TypeError("axis type is not valid, must be an integer.")
|
||||
if axis not in (0, -1):
|
||||
raise ValueError("only 1D concatenation supported.")
|
||||
kwargs["axis"] = axis
|
||||
|
||||
if isinstance(axis, type(None)):
|
||||
axis = 0
|
||||
if prepend is not None:
|
||||
if not isinstance(prepend, (type(None), np.ndarray)):
|
||||
raise ValueError("prepend type is not valid, must be None for no prepend tensor or a numpy array.")
|
||||
kwargs["prepend"] = prepend
|
||||
|
||||
if axis not in (None, 0, -1):
|
||||
raise ValueError("only 1D concatenation supported.")
|
||||
|
||||
if not isinstance(prepend, (type(None), np.ndarray)):
|
||||
raise ValueError("prepend type is not valid, must be None for no prepend tensor or a numpy array.")
|
||||
|
||||
if not isinstance(append, (type(None), np.ndarray)):
|
||||
raise ValueError("append type is not valid, must be None for no append tensor or a numpy array.")
|
||||
|
||||
if isinstance(prepend, np.ndarray):
|
||||
prepend = cde.Tensor(prepend)
|
||||
|
||||
if isinstance(append, np.ndarray):
|
||||
append = cde.Tensor(append)
|
||||
|
||||
kwargs["axis"] = axis
|
||||
kwargs["prepend"] = prepend
|
||||
kwargs["append"] = append
|
||||
if append is not None:
|
||||
if not isinstance(append, (type(None), np.ndarray)):
|
||||
raise ValueError("append type is not valid, must be None for no append tensor or a numpy array.")
|
||||
kwargs["append"] = append
|
||||
|
||||
return method(self, **kwargs)
|
||||
|
||||
|
|
|
@ -62,7 +62,7 @@ def mask_compare(array, op, constant, dtype=mstype.bool_):
|
|||
np.testing.assert_array_equal(array, d[0])
|
||||
|
||||
|
||||
def test_int_comparison():
|
||||
def test_mask_int_comparison():
|
||||
for k in mstype_to_np_type:
|
||||
if k == mstype.string:
|
||||
continue
|
||||
|
@ -74,7 +74,7 @@ def test_int_comparison():
|
|||
mask_compare([1, 2, 3, 4, 5], ops.Relational.GE, 3, k)
|
||||
|
||||
|
||||
def test_float_comparison():
|
||||
def test_mask_float_comparison():
|
||||
for k in mstype_to_np_type:
|
||||
if k == mstype.string:
|
||||
continue
|
||||
|
@ -86,7 +86,7 @@ def test_float_comparison():
|
|||
mask_compare([1.5, 2.5, 3., 4.5, 5.5], ops.Relational.GE, 3, k)
|
||||
|
||||
|
||||
def test_float_comparison2():
|
||||
def test_mask_float_comparison2():
|
||||
for k in mstype_to_np_type:
|
||||
if k == mstype.string:
|
||||
continue
|
||||
|
@ -98,7 +98,7 @@ def test_float_comparison2():
|
|||
mask_compare([1, 2, 3, 4, 5], ops.Relational.GE, 3.5, k)
|
||||
|
||||
|
||||
def test_string_comparison():
|
||||
def test_mask_string_comparison():
|
||||
for k in mstype_to_np_type:
|
||||
if k == mstype.string:
|
||||
continue
|
||||
|
@ -125,8 +125,8 @@ def test_mask_exceptions_str():
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_int_comparison()
|
||||
test_float_comparison()
|
||||
test_float_comparison2()
|
||||
test_string_comparison()
|
||||
test_mask_int_comparison()
|
||||
test_mask_float_comparison()
|
||||
test_mask_float_comparison2()
|
||||
test_mask_string_comparison()
|
||||
test_mask_exceptions_str()
|
||||
|
|
Loading…
Reference in New Issue