!44491 fix: numpy slices dataset with diff shape datas

Merge pull request !44491 from guozhijian/fix_slice_dataset_with_diff_shape
This commit is contained in:
i-robot 2022-10-26 06:41:42 +00:00 committed by Gitee
commit 281a8a8df5
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 52 additions and 7 deletions

View File

@ -53,9 +53,10 @@ Status ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_ro
CHECK_FAIL_RETURN_UNEXPECTED(num_files != 0, "The size of dataset_files must be greater than 0."); CHECK_FAIL_RETURN_UNEXPECTED(num_files != 0, "The size of dataset_files must be greater than 0.");
int64_t avg_rows_per_file = num_rows / num_files; int64_t avg_rows_per_file = num_rows / num_files;
*shuffle_size = shuffle_max; if (avg_rows_per_file != 0) {
if (avg_rows_per_file != 0 && *shuffle_size > (avg_rows_per_file * average_files_multiplier)) { *shuffle_size = std::min(avg_rows_per_file * average_files_multiplier, shuffle_max);
*shuffle_size = avg_rows_per_file * average_files_multiplier; } else {
*shuffle_size = shuffle_max;
} }
return Status::OK(); return Status::OK();

View File

@ -776,10 +776,8 @@ class _NumpySlicesDataset:
data = self.process_dict(data) data = self.process_dict(data)
if isinstance(data, tuple): if isinstance(data, tuple):
self.data = () self.data = data
data_len = len(data) data_len = len(data)
for i in range(data_len):
self.data = self.data + (np.array(data[i]),)
else: else:
self.data = (np.array(data),) self.data = (np.array(data),)
@ -798,7 +796,7 @@ class _NumpySlicesDataset:
self.column_list.append("column_" + str(i)) self.column_list.append("column_" + str(i))
def __getitem__(self, index): def __getitem__(self, index):
data_row = [d[index, ...] for d in self.data] data_row = [d[index] for d in self.data]
data_res = tuple(data_row) data_res = tuple(data_row)
return data_res return data_res

View File

@ -373,6 +373,50 @@ def test_numpy_slice_empty_output_shape():
assert dataset.output_shapes() == [] assert dataset.output_shapes() == []
def test_numpy_slice_with_diff_shape():
"""
Feature: NumpySlicesDataset
Description: Check if NumpySlicesDataset produces diff shape data
Expectation: The dataset is processed success
"""
data1 = np.array([1], dtype=np.uint8)
data2 = np.array([5, 6], dtype=np.uint8)
data3 = np.array([9, 10, 11], dtype=np.uint8)
data4 = np.array([13, 14, 15, 16], dtype=np.uint8)
data = [data1, data2, data3, data4]
label = [1, 2, 3, 4]
dataset = de.NumpySlicesDataset((data, label), ["data", "label"], num_shards=4, shard_id=2, shuffle=False)
for item in dataset.create_dict_iterator(output_numpy=True):
assert (item["data"] == data3).all()
def test_numpy_slice_with_diff_shape_dict():
"""
Feature: NumpySlicesDataset
Description: Check if NumpySlicesDataset produces diff shape data by dict
Expectation: The dataset is processed success
"""
data1 = np.array([1], dtype=np.uint8)
data2 = np.array([5, 6], dtype=np.uint8)
data3 = np.array([9, 10, 11], dtype=np.uint8)
data4 = np.array([13, 14, 15, 16], dtype=np.uint8)
data = [data1, data2, data3, data4]
label = [1, 2, 3, 4]
dict_data = {"data": data, "label": label}
dataset = de.NumpySlicesDataset(dict_data, num_shards=4, shard_id=2, shuffle=False)
for item in dataset.create_dict_iterator(output_numpy=True):
assert (item["data"] == data3).all()
if __name__ == "__main__": if __name__ == "__main__":
test_numpy_slices_list_1() test_numpy_slices_list_1()
test_numpy_slices_list_2() test_numpy_slices_list_2()
@ -394,3 +438,5 @@ if __name__ == "__main__":
test_numpy_slices_invalid_empty_column_names() test_numpy_slices_invalid_empty_column_names()
test_numpy_slices_invalid_empty_data_column() test_numpy_slices_invalid_empty_data_column()
test_numpy_slice_empty_output_shape() test_numpy_slice_empty_output_shape()
test_numpy_slice_with_diff_shape()
test_numpy_slice_with_diff_shape_dict()