forked from mindspore-Ecosystem/mindspore
!44491 fix: numpy slices dataset with diff shape datas
Merge pull request !44491 from guozhijian/fix_slice_dataset_with_diff_shape
This commit is contained in:
commit
281a8a8df5
|
@ -53,9 +53,10 @@ Status ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_ro
|
|||
CHECK_FAIL_RETURN_UNEXPECTED(num_files != 0, "The size of dataset_files must be greater than 0.");
|
||||
int64_t avg_rows_per_file = num_rows / num_files;
|
||||
|
||||
*shuffle_size = shuffle_max;
|
||||
if (avg_rows_per_file != 0 && *shuffle_size > (avg_rows_per_file * average_files_multiplier)) {
|
||||
*shuffle_size = avg_rows_per_file * average_files_multiplier;
|
||||
if (avg_rows_per_file != 0) {
|
||||
*shuffle_size = std::min(avg_rows_per_file * average_files_multiplier, shuffle_max);
|
||||
} else {
|
||||
*shuffle_size = shuffle_max;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
@ -776,10 +776,8 @@ class _NumpySlicesDataset:
|
|||
data = self.process_dict(data)
|
||||
|
||||
if isinstance(data, tuple):
|
||||
self.data = ()
|
||||
self.data = data
|
||||
data_len = len(data)
|
||||
for i in range(data_len):
|
||||
self.data = self.data + (np.array(data[i]),)
|
||||
else:
|
||||
self.data = (np.array(data),)
|
||||
|
||||
|
@ -798,7 +796,7 @@ class _NumpySlicesDataset:
|
|||
self.column_list.append("column_" + str(i))
|
||||
|
||||
def __getitem__(self, index):
|
||||
data_row = [d[index, ...] for d in self.data]
|
||||
data_row = [d[index] for d in self.data]
|
||||
data_res = tuple(data_row)
|
||||
return data_res
|
||||
|
||||
|
|
|
@ -373,6 +373,50 @@ def test_numpy_slice_empty_output_shape():
|
|||
assert dataset.output_shapes() == []
|
||||
|
||||
|
||||
def test_numpy_slice_with_diff_shape():
|
||||
"""
|
||||
Feature: NumpySlicesDataset
|
||||
Description: Check if NumpySlicesDataset produces diff shape data
|
||||
Expectation: The dataset is processed success
|
||||
"""
|
||||
data1 = np.array([1], dtype=np.uint8)
|
||||
data2 = np.array([5, 6], dtype=np.uint8)
|
||||
data3 = np.array([9, 10, 11], dtype=np.uint8)
|
||||
data4 = np.array([13, 14, 15, 16], dtype=np.uint8)
|
||||
|
||||
data = [data1, data2, data3, data4]
|
||||
|
||||
label = [1, 2, 3, 4]
|
||||
|
||||
dataset = de.NumpySlicesDataset((data, label), ["data", "label"], num_shards=4, shard_id=2, shuffle=False)
|
||||
|
||||
for item in dataset.create_dict_iterator(output_numpy=True):
|
||||
assert (item["data"] == data3).all()
|
||||
|
||||
|
||||
def test_numpy_slice_with_diff_shape_dict():
|
||||
"""
|
||||
Feature: NumpySlicesDataset
|
||||
Description: Check if NumpySlicesDataset produces diff shape data by dict
|
||||
Expectation: The dataset is processed success
|
||||
"""
|
||||
data1 = np.array([1], dtype=np.uint8)
|
||||
data2 = np.array([5, 6], dtype=np.uint8)
|
||||
data3 = np.array([9, 10, 11], dtype=np.uint8)
|
||||
data4 = np.array([13, 14, 15, 16], dtype=np.uint8)
|
||||
|
||||
data = [data1, data2, data3, data4]
|
||||
|
||||
label = [1, 2, 3, 4]
|
||||
|
||||
dict_data = {"data": data, "label": label}
|
||||
|
||||
dataset = de.NumpySlicesDataset(dict_data, num_shards=4, shard_id=2, shuffle=False)
|
||||
|
||||
for item in dataset.create_dict_iterator(output_numpy=True):
|
||||
assert (item["data"] == data3).all()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_numpy_slices_list_1()
|
||||
test_numpy_slices_list_2()
|
||||
|
@ -394,3 +438,5 @@ if __name__ == "__main__":
|
|||
test_numpy_slices_invalid_empty_column_names()
|
||||
test_numpy_slices_invalid_empty_data_column()
|
||||
test_numpy_slice_empty_output_shape()
|
||||
test_numpy_slice_with_diff_shape()
|
||||
test_numpy_slice_with_diff_shape_dict()
|
||||
|
|
Loading…
Reference in New Issue