Return -1 as unknown batch size for callable batch_size param
This commit is contained in:
parent
ca533b0ab8
commit
3cb854abc7
|
@ -786,7 +786,7 @@ Status TreeGetters::GetBatchSize(int64_t *batch_size) {
|
||||||
std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
|
std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
|
||||||
RETURN_UNEXPECTED_IF_NULL(root);
|
RETURN_UNEXPECTED_IF_NULL(root);
|
||||||
*batch_size = root->GetTreeBatchSize();
|
*batch_size = root->GetTreeBatchSize();
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(*batch_size != -1, "GetBatchSize: Failed to find the batch size in Dataset pipeline.");
|
CHECK_FAIL_RETURN_UNEXPECTED(*batch_size != 0, "GetBatchSize: Failed to find the batch size in Dataset pipeline.");
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -171,6 +171,26 @@ def test_variable_size_batch():
|
||||||
4, 4, False, add_one_by_epoch, 2, 2)), "\nMULTIPROC2 MAP FAILED\n"
|
4, 4, False, add_one_by_epoch, 2, 2)), "\nMULTIPROC2 MAP FAILED\n"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_batchsize_on_callable_batchsize():
|
||||||
|
"""
|
||||||
|
Feature: Test get_batch_size() func which can get batch size of pipeline.
|
||||||
|
Description: Use get_batch_size() func on dataset pipeline with dynamic batch size.
|
||||||
|
Expectation: Return -1 which means the batch size is unknown.
|
||||||
|
"""
|
||||||
|
def gen(num):
|
||||||
|
for i in range(num):
|
||||||
|
yield (np.array([i]),)
|
||||||
|
|
||||||
|
def add_one_by_batch_num(batch_info):
|
||||||
|
return batch_info.get_batch_num() + 1
|
||||||
|
|
||||||
|
data1 = ds.GeneratorDataset((lambda: gen(10)), ["num"])
|
||||||
|
data1 = data1.batch(batch_size=add_one_by_batch_num, drop_remainder=True)
|
||||||
|
|
||||||
|
batchsize = data1.get_batch_size()
|
||||||
|
assert batchsize == -1
|
||||||
|
|
||||||
|
|
||||||
def test_basic_batch_map():
|
def test_basic_batch_map():
|
||||||
def check_res(arr1, arr2):
|
def check_res(arr1, arr2):
|
||||||
for ind, _ in enumerate(arr1):
|
for ind, _ in enumerate(arr1):
|
||||||
|
@ -275,9 +295,9 @@ def test_var_batch_multi_col_map():
|
||||||
return (batchInfo.get_batch_num() + 1) * (batchInfo.get_epoch_num() + 1)
|
return (batchInfo.get_batch_num() + 1) * (batchInfo.get_epoch_num() + 1)
|
||||||
|
|
||||||
# multiply first col by batch_num, multiply second col by -batch_num
|
# multiply first col by batch_num, multiply second col by -batch_num
|
||||||
def map_func(col1, col2, batchInfo):
|
def map_func(col1, col2, batch_info):
|
||||||
return ([np.copy((1 + batchInfo.get_batch_num()) * arr) for arr in col1],
|
return ([np.copy((1 + batch_info.get_batch_num()) * arr) for arr in col1],
|
||||||
[np.copy(-(1 + batchInfo.get_batch_num()) * arr) for arr in col2])
|
[np.copy(-(1 + batch_info.get_batch_num()) * arr) for arr in col2])
|
||||||
|
|
||||||
def batch_map_config(num, r, fbatch, fmap, col_names, res):
|
def batch_map_config(num, r, fbatch, fmap, col_names, res):
|
||||||
data1 = ds.GeneratorDataset((lambda: gen_3_cols(num)), ["col1", "col2", "col3"]) \
|
data1 = ds.GeneratorDataset((lambda: gen_3_cols(num)), ["col1", "col2", "col3"]) \
|
||||||
|
@ -445,6 +465,9 @@ if __name__ == '__main__':
|
||||||
logger.info("Running test_var_batch_map.py test_variable_size_batch() function")
|
logger.info("Running test_var_batch_map.py test_variable_size_batch() function")
|
||||||
test_variable_size_batch()
|
test_variable_size_batch()
|
||||||
|
|
||||||
|
logger.info("Running test_var_batch_map.py test_get_batchsize_on_callable_batchsize() function")
|
||||||
|
test_get_batchsize_on_callable_batchsize()
|
||||||
|
|
||||||
logger.info("Running test_var_batch_map.py test_basic_batch_map() function")
|
logger.info("Running test_var_batch_map.py test_basic_batch_map() function")
|
||||||
test_basic_batch_map()
|
test_basic_batch_map()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue