add concat batch function

This commit is contained in:
jiangshuqiang 2022-06-15 16:58:11 +08:00
parent ce82191fbe
commit 90e58867ec
3 changed files with 66 additions and 13 deletions

View File

@ -164,7 +164,8 @@ void BatchOp::Print(std::ostream &out, bool show_all) const {
}
}
Status BatchOp::BatchRows(const std::unique_ptr<TensorQTable> *src, TensorRow *dest, dsize_t batch_size) {
Status BatchOp::BatchRows(const std::unique_ptr<TensorQTable> *src, TensorRow *dest, dsize_t batch_size,
bool concat_batch) {
RETURN_UNEXPECTED_IF_NULL(src);
RETURN_UNEXPECTED_IF_NULL(dest);
if ((*src)->size() != batch_size) {
@ -176,7 +177,10 @@ Status BatchOp::BatchRows(const std::unique_ptr<TensorQTable> *src, TensorRow *d
(*src)->pop_front();
for (const auto &tensor : (*dest)) {
RETURN_IF_NOT_OK(tensor->ExpandDim(0));
// When concatenating batch rows, the result's dimension should not be expanded.
if (!concat_batch) {
RETURN_IF_NOT_OK(tensor->ExpandDim(0));
}
}
return Status::OK();
}
@ -256,7 +260,7 @@ Status BatchOp::MakeBatchedRow(std::pair<std::unique_ptr<TensorQTable>, CBatchIn
if (pad_) {
RETURN_IF_NOT_OK(PadColumns(&table_pair.first, pad_info_, column_name_id_map_));
} // do padding if needed
RETURN_IF_NOT_OK(BatchRows(&table_pair.first, new_row, table_pair.first->size()));
RETURN_IF_NOT_OK(BatchRows(&table_pair.first, new_row, table_pair.first->size(), concat_batch_));
return Status::OK();
}
@ -273,7 +277,6 @@ Status BatchOp::MapColumns(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo>
RETURN_UNEXPECTED_IF_NULL(table_pair->first);
std::unique_ptr<TensorQTable> in_q_table = std::move(table_pair->first);
size_t num_rows = in_q_table->size();
auto out_q_table = std::make_unique<TensorQTable>(num_rows, TensorRow(column_name_id_map_.size(), nullptr));
TensorTable in_cols(in_col_names_.size(), TensorRow(num_rows, nullptr)), out_cols;
std::unordered_map<std::string, size_t> in_col_name_id; // name of columns that need to be fed to per-batch_map
@ -285,21 +288,33 @@ Status BatchOp::MapColumns(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo>
for (size_t i = 0; i < num_rows; i++) {
in_cols[col_itr->second][i] = std::move((*in_q_table)[i][itr.second]);
}
} else { // col needs to be placed into the out table
}
}
RETURN_IF_NOT_OK(InvokeBatchMapFunc(&in_cols, &out_cols, table_pair->second));
auto out_q_table = std::make_unique<TensorQTable>(num_rows, TensorRow(column_name_id_map_.size(), nullptr));
// When concatenating batch rows, the output table holds a single row.
if (concat_batch_) {
out_q_table = std::make_unique<TensorQTable>(1, TensorRow(column_name_id_map_.size(), nullptr));
}
for (const auto &itr : child_map_) {
auto col_itr = in_col_name_id.find(itr.first);
if (col_itr == in_col_name_id.end()) { // col needs to be prepared for per_batch_map
// col needs to be placed into the out table
size_t col_id = column_name_id_map_[itr.first];
for (size_t i = 0; i < num_rows; i++) {
(*out_q_table)[i][col_id] = std::move((*in_q_table)[i][itr.second]);
}
}
}
in_q_table.reset(); // release the input table
RETURN_IF_NOT_OK(InvokeBatchMapFunc(&in_cols, &out_cols, table_pair->second));
for (size_t i = 0; i < out_cols.size(); i++) {
size_t col_id = column_name_id_map_[out_col_names_[i]];
size_t row_id = 0;
CHECK_FAIL_RETURN_UNEXPECTED(num_rows == out_cols[i].size(),
CHECK_FAIL_RETURN_UNEXPECTED(num_rows == out_cols[i].size() || concat_batch_,
"Invalid data, column: " + out_col_names_[i] +
" expects: " + std::to_string(num_rows) +
" rows returned from 'per_batch_map', got: " + std::to_string(out_cols[i].size()));
@ -398,12 +413,19 @@ Status BatchOp::InvokeBatchMapFunc(TensorTable *input, TensorTable *output, CBat
<< " returned by per_batch_map is not a list, this could lead to conversion failure.";
}
py::list output_list = py::cast<py::list>(ret_tuple[i]);
for (size_t j = 0; j < output_list.size(); j++) {
if (py::isinstance<py::array>(ret_tuple[i])) {
concat_batch_ = true;
std::shared_ptr<Tensor> out;
RETURN_IF_NOT_OK(Tensor::CreateFromNpArray(py::cast<py::array>(output_list[j]), &out));
// When concatenating batch rows, the per_batch_map result is a single row.
RETURN_IF_NOT_OK(Tensor::CreateFromNpArray(py::cast<py::array>(ret_tuple[i]), &out));
output_batch.push_back(std::move(out));
} else {
py::list output_list = py::cast<py::list>(ret_tuple[i]);
for (size_t j = 0; j < output_list.size(); j++) {
std::shared_ptr<Tensor> out;
RETURN_IF_NOT_OK(Tensor::CreateFromNpArray(py::cast<py::array>(output_list[j]), &out));
output_batch.push_back(std::move(out));
}
}
output->push_back(std::move(output_batch));
}

View File

@ -199,7 +199,8 @@ class BatchOp : public ParallelOp<std::pair<std::unique_ptr<TensorQTable>, CBatc
// @param int32_t size - batch_size
// @param const std::unordered_map<std::string, int32_t>& column_name_id_map - column names to index mapping
// @return Status The status code returned
static Status BatchRows(const std::unique_ptr<TensorQTable> *src, TensorRow *dest, dsize_t batch_size);
static Status BatchRows(const std::unique_ptr<TensorQTable> *src, TensorRow *dest, dsize_t batch_size,
bool concat_batch = false);
// @param table
// @param const PadInfo &pad_info pad info
@ -280,6 +281,7 @@ class BatchOp : public ParallelOp<std::pair<std::unique_ptr<TensorQTable>, CBatc
Status ComputeColMap() override;
int32_t start_batch_size_;
bool concat_batch_ = false; // bool for whether to concat batch rows
const bool drop_; // bool for whether to drop remainder or not
const bool pad_; // bool for whether to perform padding on tensor
std::vector<std::string> in_col_names_; // input column name for per_batch_map

View File

@ -463,6 +463,35 @@ def test_multi_col_map():
in batch_map_config(2, 2, split_col, ["col-1"], ["col_x", "col_y"])
def test_multi_col_concat_map():
    """
    Feature: Batch op
    Description: Test Batch op with multiple columns with concat per_batch_map args with valid inputs
    Expectation: Output is equal to the expected output for valid input
    """

    def gen_2_cols(num):
        # Yield `num` rows of single-element arrays: (i, i**2) for i = 1..num.
        for i in range(1, 1 + num):
            yield np.array([i]), np.array([i ** 2])

    def concat_col(col1, col2, batch_info):
        # per_batch_map callback: concatenate each column's rows into one 1-D
        # array, so the batched output carries a single concatenated row per
        # column instead of a stacked (batch, ...) tensor.
        # np.concatenate already returns an ndarray — no extra np.array() copy needed.
        return tuple(np.concatenate(list(arg), axis=0) for arg in (col1, col2))

    dst = ds.GeneratorDataset((lambda: gen_2_cols(3)), ["col1", "col2"])
    dst = dst.batch(batch_size=3, input_columns=["col1", "col2"], output_columns=["col1", "col2"],
                    per_batch_map=concat_col)
    res = []
    for row in dst.create_dict_iterator(num_epochs=1, output_numpy=True):
        res.append(row)
    # Concatenated columns: [1, 2, 3] and [1, 4, 9] (squares), in one batch row.
    assert np.array_equal(res[0]["col1"], [1, 2, 3]) and np.array_equal(res[0]["col2"], [1, 4, 9])
def test_exceptions_2():
"""
Feature: Batch op