forked from mindspore-Ecosystem/mindspore
!44491 fix: numpy slices dataset with diff shape datas
Merge pull request !44491 from guozhijian/fix_slice_dataset_with_diff_shape
This commit is contained in:
commit
281a8a8df5
|
@ -53,9 +53,10 @@ Status ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_ro
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(num_files != 0, "The size of dataset_files must be greater than 0.");
|
CHECK_FAIL_RETURN_UNEXPECTED(num_files != 0, "The size of dataset_files must be greater than 0.");
|
||||||
int64_t avg_rows_per_file = num_rows / num_files;
|
int64_t avg_rows_per_file = num_rows / num_files;
|
||||||
|
|
||||||
|
if (avg_rows_per_file != 0) {
|
||||||
|
*shuffle_size = std::min(avg_rows_per_file * average_files_multiplier, shuffle_max);
|
||||||
|
} else {
|
||||||
*shuffle_size = shuffle_max;
|
*shuffle_size = shuffle_max;
|
||||||
if (avg_rows_per_file != 0 && *shuffle_size > (avg_rows_per_file * average_files_multiplier)) {
|
|
||||||
*shuffle_size = avg_rows_per_file * average_files_multiplier;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
|
|
@ -776,10 +776,8 @@ class _NumpySlicesDataset:
|
||||||
data = self.process_dict(data)
|
data = self.process_dict(data)
|
||||||
|
|
||||||
if isinstance(data, tuple):
|
if isinstance(data, tuple):
|
||||||
self.data = ()
|
self.data = data
|
||||||
data_len = len(data)
|
data_len = len(data)
|
||||||
for i in range(data_len):
|
|
||||||
self.data = self.data + (np.array(data[i]),)
|
|
||||||
else:
|
else:
|
||||||
self.data = (np.array(data),)
|
self.data = (np.array(data),)
|
||||||
|
|
||||||
|
@ -798,7 +796,7 @@ class _NumpySlicesDataset:
|
||||||
self.column_list.append("column_" + str(i))
|
self.column_list.append("column_" + str(i))
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
data_row = [d[index, ...] for d in self.data]
|
data_row = [d[index] for d in self.data]
|
||||||
data_res = tuple(data_row)
|
data_res = tuple(data_row)
|
||||||
return data_res
|
return data_res
|
||||||
|
|
||||||
|
|
|
@ -373,6 +373,50 @@ def test_numpy_slice_empty_output_shape():
|
||||||
assert dataset.output_shapes() == []
|
assert dataset.output_shapes() == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_numpy_slice_with_diff_shape():
|
||||||
|
"""
|
||||||
|
Feature: NumpySlicesDataset
|
||||||
|
Description: Check if NumpySlicesDataset produces diff shape data
|
||||||
|
Expectation: The dataset is processed success
|
||||||
|
"""
|
||||||
|
data1 = np.array([1], dtype=np.uint8)
|
||||||
|
data2 = np.array([5, 6], dtype=np.uint8)
|
||||||
|
data3 = np.array([9, 10, 11], dtype=np.uint8)
|
||||||
|
data4 = np.array([13, 14, 15, 16], dtype=np.uint8)
|
||||||
|
|
||||||
|
data = [data1, data2, data3, data4]
|
||||||
|
|
||||||
|
label = [1, 2, 3, 4]
|
||||||
|
|
||||||
|
dataset = de.NumpySlicesDataset((data, label), ["data", "label"], num_shards=4, shard_id=2, shuffle=False)
|
||||||
|
|
||||||
|
for item in dataset.create_dict_iterator(output_numpy=True):
|
||||||
|
assert (item["data"] == data3).all()
|
||||||
|
|
||||||
|
|
||||||
|
def test_numpy_slice_with_diff_shape_dict():
|
||||||
|
"""
|
||||||
|
Feature: NumpySlicesDataset
|
||||||
|
Description: Check if NumpySlicesDataset produces diff shape data by dict
|
||||||
|
Expectation: The dataset is processed success
|
||||||
|
"""
|
||||||
|
data1 = np.array([1], dtype=np.uint8)
|
||||||
|
data2 = np.array([5, 6], dtype=np.uint8)
|
||||||
|
data3 = np.array([9, 10, 11], dtype=np.uint8)
|
||||||
|
data4 = np.array([13, 14, 15, 16], dtype=np.uint8)
|
||||||
|
|
||||||
|
data = [data1, data2, data3, data4]
|
||||||
|
|
||||||
|
label = [1, 2, 3, 4]
|
||||||
|
|
||||||
|
dict_data = {"data": data, "label": label}
|
||||||
|
|
||||||
|
dataset = de.NumpySlicesDataset(dict_data, num_shards=4, shard_id=2, shuffle=False)
|
||||||
|
|
||||||
|
for item in dataset.create_dict_iterator(output_numpy=True):
|
||||||
|
assert (item["data"] == data3).all()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_numpy_slices_list_1()
|
test_numpy_slices_list_1()
|
||||||
test_numpy_slices_list_2()
|
test_numpy_slices_list_2()
|
||||||
|
@ -394,3 +438,5 @@ if __name__ == "__main__":
|
||||||
test_numpy_slices_invalid_empty_column_names()
|
test_numpy_slices_invalid_empty_column_names()
|
||||||
test_numpy_slices_invalid_empty_data_column()
|
test_numpy_slices_invalid_empty_data_column()
|
||||||
test_numpy_slice_empty_output_shape()
|
test_numpy_slice_empty_output_shape()
|
||||||
|
test_numpy_slice_with_diff_shape()
|
||||||
|
test_numpy_slice_with_diff_shape_dict()
|
||||||
|
|
Loading…
Reference in New Issue