diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.cc index 9f4866f92e6..30da571089e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.cc @@ -155,8 +155,17 @@ Status BucketBatchByLengthOp::ObtainElementLength(int32_t *out_element_length, T // call pyfunc here if given pyfunc, otherwise return 0th dimension of shape of // the single column specified in length_dependent_columns_ if (element_length_function_) { - TensorRow output; - RETURN_IF_NOT_OK(element_length_function_->Compute(element, &output)); + TensorRow input, output; + size_t number_of_arguments = length_dependent_columns_.size(); + for (size_t i = 0; i < number_of_arguments; i++) { + auto map_item = column_name_id_map_.find(length_dependent_columns_[i]); + if (map_item == column_name_id_map_.end()) { + RETURN_STATUS_UNEXPECTED("BucketBatchByLength: Couldn't find the specified column in the dataset"); + } + int32_t column_index = map_item->second; + input.push_back(element[column_index]); + } + RETURN_IF_NOT_OK(element_length_function_->Compute(input, &output)); RETURN_IF_NOT_OK(output.at(0)->GetItemAt(out_element_length, {0})); if (*out_element_length < 0) { RETURN_STATUS_UNEXPECTED("BucketBatchByLength: element_length_function returned negative integer"); diff --git a/tests/ut/python/dataset/test_bucket_batch_by_length.py b/tests/ut/python/dataset/test_bucket_batch_by_length.py index d0f102b8aef..bc5993fd212 100644 --- a/tests/ut/python/dataset/test_bucket_batch_by_length.py +++ b/tests/ut/python/dataset/test_bucket_batch_by_length.py @@ -36,6 +36,11 @@ def generate_2_columns(n): yield (np.array([i]), np.array([j for j in range(i + 1)])) +def generate_3_columns(n): + for i in range(n): + yield (np.array([i]), np.array([i + 1]), np.array([j for j in range(i + 1)])) + + def test_bucket_batch_invalid_input(): dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"]) @@ -382,6 +387,48 @@ def test_bucket_batch_multi_column(): assert same_shape_output == same_shape_expected_output assert variable_shape_output == variable_shape_expected_output + +def test_bucket_batch_three_columns(): + dataset = ds.GeneratorDataset((lambda: generate_3_columns(10)), ["same_shape", "same_shape2", "variable_shape"]) + + column_names = ["same_shape2"] + bucket_boundaries = [6, 12] + bucket_batch_sizes = [5, 5, 1] + element_length_function = (lambda x: x[0] % 3) + pad_info = {} + + dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries, + bucket_batch_sizes, element_length_function, + pad_info) + + same_shape_expected_output = [[[0], [1], [2], [3], [4]], + [[5], [6], [7], [8], [9]]] + same_shape2_expected_output = [[[1], [2], [3], [4], [5]], + [[6], [7], [8], [9], [10]]] + variable_shape_expected_output = [[[0, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 1, 2, 0, 0], + [0, 1, 2, 3, 0], + [0, 1, 2, 3, 4]], + [[0, 1, 2, 3, 4, 5, 0, 0, 0, 0], + [0, 1, 2, 3, 4, 5, 6, 0, 0, 0], + [0, 1, 2, 3, 4, 5, 6, 7, 0, 0], + [0, 1, 2, 3, 4, 5, 6, 7, 8, 0], + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]] + + same_shape_output = [] + same_shape2_output = [] + variable_shape_output = [] + for data in dataset.create_dict_iterator(num_epochs=1): + same_shape_output.append(data["same_shape"].tolist()) + same_shape2_output.append(data["same_shape2"].tolist()) + variable_shape_output.append(data["variable_shape"].tolist()) + + assert same_shape_output == same_shape_expected_output + assert same_shape2_output == same_shape2_expected_output + assert variable_shape_output == variable_shape_expected_output + + def test_bucket_batch_get_dataset_size(): dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"]) @@ -402,6 +449,25 @@ def test_bucket_batch_get_dataset_size(): assert data_size == num_rows +def test_bucket_batch_invalid_column(): + dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"]) + + column_names = ["invalid_column"] + bucket_boundaries = [1, 2, 3] + bucket_batch_sizes = [3, 3, 2, 2] + element_length_function = (lambda x: x[0] % 4) + + dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries, + bucket_batch_sizes, element_length_function) + + with pytest.raises(RuntimeError) as info: + num_rows = 0 + for _ in dataset.create_dict_iterator(num_epochs=1): + num_rows += 1 + + assert "BucketBatchByLength: Couldn't find the specified column in the dataset" in str(info.value) + + if __name__ == '__main__': test_bucket_batch_invalid_input() test_bucket_batch_multi_bucket_no_padding() @@ -413,4 +479,6 @@ if __name__ == '__main__': test_bucket_batch_drop_remainder() test_bucket_batch_default_length_function() test_bucket_batch_multi_column() + test_bucket_batch_three_columns() test_bucket_batch_get_dataset_size() + test_bucket_batch_invalid_column()