!5929 BucketBatchByLength column issue

Merge pull request !5929 from MahdiRahmaniHanzaki/bucket_batch_by_length_fix
This commit is contained in:
mindspore-ci-bot 2020-09-11 05:53:48 +08:00 committed by Gitee
commit a778868a5a
2 changed files with 79 additions and 2 deletions

View File

@ -155,8 +155,17 @@ Status BucketBatchByLengthOp::ObtainElementLength(int32_t *out_element_length, T
// call pyfunc here if given pyfunc, otherwise return 0th dimension of shape of
// the single column specified in length_dependent_columns_
if (element_length_function_) {
TensorRow output;
RETURN_IF_NOT_OK(element_length_function_->Compute(element, &output));
TensorRow input, output;
size_t number_of_arguments = length_dependent_columns_.size();
for (size_t i = 0; i < number_of_arguments; i++) {
auto map_item = column_name_id_map_.find(length_dependent_columns_[i]);
if (map_item == column_name_id_map_.end()) {
RETURN_STATUS_UNEXPECTED("BucketBatchByLength: Couldn't find the specified column in the dataset");
}
int32_t column_index = map_item->second;
input.push_back(element[column_index]);
}
RETURN_IF_NOT_OK(element_length_function_->Compute(input, &output));
RETURN_IF_NOT_OK(output.at(0)->GetItemAt(out_element_length, {0}));
if (*out_element_length < 0) {
RETURN_STATUS_UNEXPECTED("BucketBatchByLength: element_length_function returned negative integer");

View File

@ -36,6 +36,11 @@ def generate_2_columns(n):
yield (np.array([i]), np.array([j for j in range(i + 1)]))
def generate_3_columns(n):
for i in range(n):
yield (np.array([i]), np.array([i + 1]), np.array([j for j in range(i + 1)]))
def test_bucket_batch_invalid_input():
dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"])
@ -382,6 +387,48 @@ def test_bucket_batch_multi_column():
assert same_shape_output == same_shape_expected_output
assert variable_shape_output == variable_shape_expected_output
def test_bucket_batch_three_columns():
dataset = ds.GeneratorDataset((lambda: generate_3_columns(10)), ["same_shape", "same_shape2", "variable_shape"])
column_names = ["same_shape2"]
bucket_boundaries = [6, 12]
bucket_batch_sizes = [5, 5, 1]
element_length_function = (lambda x: x[0] % 3)
pad_info = {}
dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries,
bucket_batch_sizes, element_length_function,
pad_info)
same_shape_expected_output = [[[0], [1], [2], [3], [4]],
[[5], [6], [7], [8], [9]]]
same_shape2_expected_output = [[[1], [2], [3], [4], [5]],
[[6], [7], [8], [9], [10]]]
variable_shape_expected_output = [[[0, 0, 0, 0, 0],
[0, 1, 0, 0, 0],
[0, 1, 2, 0, 0],
[0, 1, 2, 3, 0],
[0, 1, 2, 3, 4]],
[[0, 1, 2, 3, 4, 5, 0, 0, 0, 0],
[0, 1, 2, 3, 4, 5, 6, 0, 0, 0],
[0, 1, 2, 3, 4, 5, 6, 7, 0, 0],
[0, 1, 2, 3, 4, 5, 6, 7, 8, 0],
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]]
same_shape_output = []
same_shape2_output = []
variable_shape_output = []
for data in dataset.create_dict_iterator(num_epochs=1):
same_shape_output.append(data["same_shape"].tolist())
same_shape2_output.append(data["same_shape2"].tolist())
variable_shape_output.append(data["variable_shape"].tolist())
assert same_shape_output == same_shape_expected_output
assert same_shape2_output == same_shape2_expected_output
assert variable_shape_output == variable_shape_expected_output
def test_bucket_batch_get_dataset_size():
dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"])
@ -402,6 +449,25 @@ def test_bucket_batch_get_dataset_size():
assert data_size == num_rows
def test_bucket_batch_invalid_column():
dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"])
column_names = ["invalid_column"]
bucket_boundaries = [1, 2, 3]
bucket_batch_sizes = [3, 3, 2, 2]
element_length_function = (lambda x: x[0] % 4)
dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries,
bucket_batch_sizes, element_length_function)
with pytest.raises(RuntimeError) as info:
num_rows = 0
for _ in dataset.create_dict_iterator(num_epochs=1):
num_rows += 1
assert "BucketBatchByLength: Couldn't find the specified column in the dataset" in str(info.value)
if __name__ == '__main__':
test_bucket_batch_invalid_input()
test_bucket_batch_multi_bucket_no_padding()
@ -413,4 +479,6 @@ if __name__ == '__main__':
test_bucket_batch_drop_remainder()
test_bucket_batch_default_length_function()
test_bucket_batch_multi_column()
test_bucket_batch_three_columns()
test_bucket_batch_get_dataset_size()
test_bucket_batch_invalid_column()