!20702 Fix NumpySlicesDataset not accept NumPy as data
Merge pull request !20702 from xiaotianci/fix_numpy_slices
This commit is contained in:
commit
bcb49f10b4
|
@ -1260,7 +1260,7 @@ def check_numpyslicesdataset(method):
|
|||
|
||||
data = param_dict.get("data")
|
||||
column_names = param_dict.get("column_names")
|
||||
if not data:
|
||||
if data is None or len(data) == 0: # pylint: disable=len-as-condition
|
||||
raise ValueError("Argument data cannot be empty")
|
||||
type_check(data, (list, tuple, dict, np.ndarray), "data")
|
||||
if isinstance(data, tuple):
|
||||
|
|
|
@ -51,6 +51,16 @@ def test_numpy_slices_list_3():
|
|||
assert np.equal(data[0].asnumpy(), np_data[i]).all()
|
||||
|
||||
|
||||
def test_numpy_slices_numpy():
|
||||
logger.info("Test NumPy structure data.")
|
||||
|
||||
np_data = np.array([[[1, 1], [2, 2]], [[3, 3], [4, 4]]])
|
||||
ds = de.NumpySlicesDataset(np_data, column_names=["col1"], shuffle=False)
|
||||
|
||||
for i, data in enumerate(ds):
|
||||
assert np.equal(data[0].asnumpy(), np_data[i]).all()
|
||||
|
||||
|
||||
def test_numpy_slices_list_append():
|
||||
logger.info("Test reading data of image list.")
|
||||
|
||||
|
|
Loading…
Reference in New Issue