Return -1 as unknown batch size for callable batch_size param
This commit is contained in:
parent
ca533b0ab8
commit
3cb854abc7
|
@ -786,7 +786,7 @@ Status TreeGetters::GetBatchSize(int64_t *batch_size) {
|
|||
std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
|
||||
RETURN_UNEXPECTED_IF_NULL(root);
|
||||
*batch_size = root->GetTreeBatchSize();
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(*batch_size != -1, "GetBatchSize: Failed to find the batch size in Dataset pipeline.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(*batch_size != 0, "GetBatchSize: Failed to find the batch size in Dataset pipeline.");
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -171,6 +171,26 @@ def test_variable_size_batch():
|
|||
4, 4, False, add_one_by_epoch, 2, 2)), "\nMULTIPROC2 MAP FAILED\n"
|
||||
|
||||
|
||||
def test_get_batchsize_on_callable_batchsize():
|
||||
"""
|
||||
Feature: Test get_batch_size() func which can get batch size of pipeline.
|
||||
Description: Use get_batch_size() func on dataset pipeline with dynamic batch size.
|
||||
Expectation: Return -1 which means the batch size is unknown.
|
||||
"""
|
||||
def gen(num):
|
||||
for i in range(num):
|
||||
yield (np.array([i]),)
|
||||
|
||||
def add_one_by_batch_num(batch_info):
|
||||
return batch_info.get_batch_num() + 1
|
||||
|
||||
data1 = ds.GeneratorDataset((lambda: gen(10)), ["num"])
|
||||
data1 = data1.batch(batch_size=add_one_by_batch_num, drop_remainder=True)
|
||||
|
||||
batchsize = data1.get_batch_size()
|
||||
assert batchsize == -1
|
||||
|
||||
|
||||
def test_basic_batch_map():
|
||||
def check_res(arr1, arr2):
|
||||
for ind, _ in enumerate(arr1):
|
||||
|
@ -275,9 +295,9 @@ def test_var_batch_multi_col_map():
|
|||
return (batchInfo.get_batch_num() + 1) * (batchInfo.get_epoch_num() + 1)
|
||||
|
||||
# multiply first col by batch_num, multiply second col by -batch_num
|
||||
def map_func(col1, col2, batchInfo):
|
||||
return ([np.copy((1 + batchInfo.get_batch_num()) * arr) for arr in col1],
|
||||
[np.copy(-(1 + batchInfo.get_batch_num()) * arr) for arr in col2])
|
||||
def map_func(col1, col2, batch_info):
|
||||
return ([np.copy((1 + batch_info.get_batch_num()) * arr) for arr in col1],
|
||||
[np.copy(-(1 + batch_info.get_batch_num()) * arr) for arr in col2])
|
||||
|
||||
def batch_map_config(num, r, fbatch, fmap, col_names, res):
|
||||
data1 = ds.GeneratorDataset((lambda: gen_3_cols(num)), ["col1", "col2", "col3"]) \
|
||||
|
@ -445,6 +465,9 @@ if __name__ == '__main__':
|
|||
logger.info("Running test_var_batch_map.py test_variable_size_batch() function")
|
||||
test_variable_size_batch()
|
||||
|
||||
logger.info("Running test_var_batch_map.py test_get_batchsize_on_callable_batchsize() function")
|
||||
test_get_batchsize_on_callable_batchsize()
|
||||
|
||||
logger.info("Running test_var_batch_map.py test_basic_batch_map() function")
|
||||
test_basic_batch_map()
|
||||
|
||||
|
|
Loading…
Reference in New Issue