diff --git a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc index c134ea44fa6..34df321d00e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc @@ -786,7 +786,7 @@ Status TreeGetters::GetBatchSize(int64_t *batch_size) { std::shared_ptr root = std::shared_ptr(tree_adapter_->GetRoot()); RETURN_UNEXPECTED_IF_NULL(root); *batch_size = root->GetTreeBatchSize(); - CHECK_FAIL_RETURN_UNEXPECTED(*batch_size != -1, "GetBatchSize: Failed to find the batch size in Dataset pipeline."); + CHECK_FAIL_RETURN_UNEXPECTED(*batch_size != 0, "GetBatchSize: Failed to find the batch size in Dataset pipeline."); return Status::OK(); } diff --git a/tests/ut/python/dataset/test_var_batch_map.py b/tests/ut/python/dataset/test_var_batch_map.py index 07e3dbd91b2..e55b4888c4e 100644 --- a/tests/ut/python/dataset/test_var_batch_map.py +++ b/tests/ut/python/dataset/test_var_batch_map.py @@ -171,6 +171,26 @@ def test_variable_size_batch(): 4, 4, False, add_one_by_epoch, 2, 2)), "\nMULTIPROC2 MAP FAILED\n" +def test_get_batchsize_on_callable_batchsize(): + """ + Feature: Test get_batch_size() func which can get batch size of pipeline. + Description: Use get_batch_size() func on dataset pipeline with dynamic batch size. + Expectation: Return -1 which means the batch size is unknown. + """ + def gen(num): + for i in range(num): + yield (np.array([i]),) + + def add_one_by_batch_num(batch_info): + return batch_info.get_batch_num() + 1 + + data1 = ds.GeneratorDataset((lambda: gen(10)), ["num"]) + data1 = data1.batch(batch_size=add_one_by_batch_num, drop_remainder=True) + + batchsize = data1.get_batch_size() + assert batchsize == -1 + + def test_basic_batch_map(): def check_res(arr1, arr2): for ind, _ in enumerate(arr1): @@ -275,9 +295,9 @@ def test_var_batch_multi_col_map(): return (batchInfo.get_batch_num() + 1) * (batchInfo.get_epoch_num() + 1) # multiply first col by batch_num, multiply second col by -batch_num - def map_func(col1, col2, batchInfo): - return ([np.copy((1 + batchInfo.get_batch_num()) * arr) for arr in col1], - [np.copy(-(1 + batchInfo.get_batch_num()) * arr) for arr in col2]) + def map_func(col1, col2, batch_info): + return ([np.copy((1 + batch_info.get_batch_num()) * arr) for arr in col1], + [np.copy(-(1 + batch_info.get_batch_num()) * arr) for arr in col2]) def batch_map_config(num, r, fbatch, fmap, col_names, res): data1 = ds.GeneratorDataset((lambda: gen_3_cols(num)), ["col1", "col2", "col3"]) \ @@ -445,6 +465,9 @@ if __name__ == '__main__': logger.info("Running test_var_batch_map.py test_variable_size_batch() function") test_variable_size_batch() + logger.info("Running test_var_batch_map.py test_get_batchsize_on_callable_batchsize() function") + test_get_batchsize_on_callable_batchsize() + logger.info("Running test_var_batch_map.py test_basic_batch_map() function") test_basic_batch_map()