forked from OSSInnovation/mindspore
!5929 BucketBatchByLength column issue
Merge pull request !5929 from MahdiRahmaniHanzaki/bucket_batch_by_length_fix
This commit is contained in:
commit
a778868a5a
|
@ -155,8 +155,17 @@ Status BucketBatchByLengthOp::ObtainElementLength(int32_t *out_element_length, T
|
|||
// call pyfunc here if given pyfunc, otherwise return 0th dimension of shape of
|
||||
// the single column specified in length_dependent_columns_
|
||||
if (element_length_function_) {
|
||||
TensorRow output;
|
||||
RETURN_IF_NOT_OK(element_length_function_->Compute(element, &output));
|
||||
TensorRow input, output;
|
||||
size_t number_of_arguments = length_dependent_columns_.size();
|
||||
for (size_t i = 0; i < number_of_arguments; i++) {
|
||||
auto map_item = column_name_id_map_.find(length_dependent_columns_[i]);
|
||||
if (map_item == column_name_id_map_.end()) {
|
||||
RETURN_STATUS_UNEXPECTED("BucketBatchByLength: Couldn't find the specified column in the dataset");
|
||||
}
|
||||
int32_t column_index = map_item->second;
|
||||
input.push_back(element[column_index]);
|
||||
}
|
||||
RETURN_IF_NOT_OK(element_length_function_->Compute(input, &output));
|
||||
RETURN_IF_NOT_OK(output.at(0)->GetItemAt(out_element_length, {0}));
|
||||
if (*out_element_length < 0) {
|
||||
RETURN_STATUS_UNEXPECTED("BucketBatchByLength: element_length_function returned negative integer");
|
||||
|
|
|
@ -36,6 +36,11 @@ def generate_2_columns(n):
|
|||
yield (np.array([i]), np.array([j for j in range(i + 1)]))
|
||||
|
||||
|
||||
def generate_3_columns(n):
|
||||
for i in range(n):
|
||||
yield (np.array([i]), np.array([i + 1]), np.array([j for j in range(i + 1)]))
|
||||
|
||||
|
||||
def test_bucket_batch_invalid_input():
|
||||
dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"])
|
||||
|
||||
|
@ -382,6 +387,48 @@ def test_bucket_batch_multi_column():
|
|||
assert same_shape_output == same_shape_expected_output
|
||||
assert variable_shape_output == variable_shape_expected_output
|
||||
|
||||
|
||||
def test_bucket_batch_three_columns():
|
||||
dataset = ds.GeneratorDataset((lambda: generate_3_columns(10)), ["same_shape", "same_shape2", "variable_shape"])
|
||||
|
||||
column_names = ["same_shape2"]
|
||||
bucket_boundaries = [6, 12]
|
||||
bucket_batch_sizes = [5, 5, 1]
|
||||
element_length_function = (lambda x: x[0] % 3)
|
||||
pad_info = {}
|
||||
|
||||
dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries,
|
||||
bucket_batch_sizes, element_length_function,
|
||||
pad_info)
|
||||
|
||||
same_shape_expected_output = [[[0], [1], [2], [3], [4]],
|
||||
[[5], [6], [7], [8], [9]]]
|
||||
same_shape2_expected_output = [[[1], [2], [3], [4], [5]],
|
||||
[[6], [7], [8], [9], [10]]]
|
||||
variable_shape_expected_output = [[[0, 0, 0, 0, 0],
|
||||
[0, 1, 0, 0, 0],
|
||||
[0, 1, 2, 0, 0],
|
||||
[0, 1, 2, 3, 0],
|
||||
[0, 1, 2, 3, 4]],
|
||||
[[0, 1, 2, 3, 4, 5, 0, 0, 0, 0],
|
||||
[0, 1, 2, 3, 4, 5, 6, 0, 0, 0],
|
||||
[0, 1, 2, 3, 4, 5, 6, 7, 0, 0],
|
||||
[0, 1, 2, 3, 4, 5, 6, 7, 8, 0],
|
||||
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]]
|
||||
|
||||
same_shape_output = []
|
||||
same_shape2_output = []
|
||||
variable_shape_output = []
|
||||
for data in dataset.create_dict_iterator(num_epochs=1):
|
||||
same_shape_output.append(data["same_shape"].tolist())
|
||||
same_shape2_output.append(data["same_shape2"].tolist())
|
||||
variable_shape_output.append(data["variable_shape"].tolist())
|
||||
|
||||
assert same_shape_output == same_shape_expected_output
|
||||
assert same_shape2_output == same_shape2_expected_output
|
||||
assert variable_shape_output == variable_shape_expected_output
|
||||
|
||||
|
||||
def test_bucket_batch_get_dataset_size():
|
||||
dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"])
|
||||
|
||||
|
@ -402,6 +449,25 @@ def test_bucket_batch_get_dataset_size():
|
|||
assert data_size == num_rows
|
||||
|
||||
|
||||
def test_bucket_batch_invalid_column():
|
||||
dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"])
|
||||
|
||||
column_names = ["invalid_column"]
|
||||
bucket_boundaries = [1, 2, 3]
|
||||
bucket_batch_sizes = [3, 3, 2, 2]
|
||||
element_length_function = (lambda x: x[0] % 4)
|
||||
|
||||
dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries,
|
||||
bucket_batch_sizes, element_length_function)
|
||||
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
num_rows = 0
|
||||
for _ in dataset.create_dict_iterator(num_epochs=1):
|
||||
num_rows += 1
|
||||
|
||||
assert "BucketBatchByLength: Couldn't find the specified column in the dataset" in str(info.value)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_bucket_batch_invalid_input()
|
||||
test_bucket_batch_multi_bucket_no_padding()
|
||||
|
@ -413,4 +479,6 @@ if __name__ == '__main__':
|
|||
test_bucket_batch_drop_remainder()
|
||||
test_bucket_batch_default_length_function()
|
||||
test_bucket_batch_multi_column()
|
||||
test_bucket_batch_three_columns()
|
||||
test_bucket_batch_get_dataset_size()
|
||||
test_bucket_batch_invalid_column()
|
||||
|
|
Loading…
Reference in New Issue