Return -1 as unknown batch size for callable batch_size param

This commit is contained in:
luoyang 2022-05-11 19:29:29 +08:00
parent ca533b0ab8
commit 3cb854abc7
2 changed files with 27 additions and 4 deletions

View File

@ -786,7 +786,7 @@ Status TreeGetters::GetBatchSize(int64_t *batch_size) {
std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
RETURN_UNEXPECTED_IF_NULL(root);
*batch_size = root->GetTreeBatchSize();
CHECK_FAIL_RETURN_UNEXPECTED(*batch_size != -1, "GetBatchSize: Failed to find the batch size in Dataset pipeline.");
CHECK_FAIL_RETURN_UNEXPECTED(*batch_size != 0, "GetBatchSize: Failed to find the batch size in Dataset pipeline.");
return Status::OK();
}

View File

@ -171,6 +171,26 @@ def test_variable_size_batch():
4, 4, False, add_one_by_epoch, 2, 2)), "\nMULTIPROC2 MAP FAILED\n"
def test_get_batchsize_on_callable_batchsize():
    """
    Feature: get_batch_size() on a dataset pipeline
    Description: Query get_batch_size() on a pipeline whose batch op takes a callable batch_size
    Expectation: -1 is returned, meaning the batch size is unknown
    """

    def gen(num):
        # Emit `num` single-column rows, each wrapping its index in a one-element array.
        for value in range(num):
            yield (np.array([value]),)

    def add_one_by_batch_num(batch_info):
        # The batch size is derived from the batch number reported by batch_info,
        # so it is produced by a function call rather than a fixed integer.
        return batch_info.get_batch_num() + 1

    source = ds.GeneratorDataset((lambda: gen(10)), ["num"])
    batched = source.batch(batch_size=add_one_by_batch_num, drop_remainder=True)

    # A callable batch_size gives the pipeline no fixed size to report.
    assert batched.get_batch_size() == -1
def test_basic_batch_map():
def check_res(arr1, arr2):
for ind, _ in enumerate(arr1):
@ -275,9 +295,9 @@ def test_var_batch_multi_col_map():
return (batchInfo.get_batch_num() + 1) * (batchInfo.get_epoch_num() + 1)
# multiply first col by batch_num, multiply second col by -batch_num
def map_func(col1, col2, batchInfo):
return ([np.copy((1 + batchInfo.get_batch_num()) * arr) for arr in col1],
[np.copy(-(1 + batchInfo.get_batch_num()) * arr) for arr in col2])
def map_func(col1, col2, batch_info):
return ([np.copy((1 + batch_info.get_batch_num()) * arr) for arr in col1],
[np.copy(-(1 + batch_info.get_batch_num()) * arr) for arr in col2])
def batch_map_config(num, r, fbatch, fmap, col_names, res):
data1 = ds.GeneratorDataset((lambda: gen_3_cols(num)), ["col1", "col2", "col3"]) \
@ -445,6 +465,9 @@ if __name__ == '__main__':
logger.info("Running test_var_batch_map.py test_variable_size_batch() function")
test_variable_size_batch()
logger.info("Running test_var_batch_map.py test_get_batchsize_on_callable_batchsize() function")
test_get_batchsize_on_callable_batchsize()
logger.info("Running test_var_batch_map.py test_basic_batch_map() function")
test_basic_batch_map()