add concat batch function
This commit is contained in:
parent
ce82191fbe
commit
90e58867ec
|
@ -164,7 +164,8 @@ void BatchOp::Print(std::ostream &out, bool show_all) const {
|
|||
}
|
||||
}
|
||||
|
||||
Status BatchOp::BatchRows(const std::unique_ptr<TensorQTable> *src, TensorRow *dest, dsize_t batch_size) {
|
||||
Status BatchOp::BatchRows(const std::unique_ptr<TensorQTable> *src, TensorRow *dest, dsize_t batch_size,
|
||||
bool concat_batch) {
|
||||
RETURN_UNEXPECTED_IF_NULL(src);
|
||||
RETURN_UNEXPECTED_IF_NULL(dest);
|
||||
if ((*src)->size() != batch_size) {
|
||||
|
@ -176,7 +177,10 @@ Status BatchOp::BatchRows(const std::unique_ptr<TensorQTable> *src, TensorRow *d
|
|||
(*src)->pop_front();
|
||||
|
||||
for (const auto &tensor : (*dest)) {
|
||||
RETURN_IF_NOT_OK(tensor->ExpandDim(0));
|
||||
// If concat batch rows, the result should not be expend dimension.
|
||||
if (!concat_batch) {
|
||||
RETURN_IF_NOT_OK(tensor->ExpandDim(0));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -256,7 +260,7 @@ Status BatchOp::MakeBatchedRow(std::pair<std::unique_ptr<TensorQTable>, CBatchIn
|
|||
if (pad_) {
|
||||
RETURN_IF_NOT_OK(PadColumns(&table_pair.first, pad_info_, column_name_id_map_));
|
||||
} // do padding if needed
|
||||
RETURN_IF_NOT_OK(BatchRows(&table_pair.first, new_row, table_pair.first->size()));
|
||||
RETURN_IF_NOT_OK(BatchRows(&table_pair.first, new_row, table_pair.first->size(), concat_batch_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -273,7 +277,6 @@ Status BatchOp::MapColumns(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo>
|
|||
RETURN_UNEXPECTED_IF_NULL(table_pair->first);
|
||||
std::unique_ptr<TensorQTable> in_q_table = std::move(table_pair->first);
|
||||
size_t num_rows = in_q_table->size();
|
||||
auto out_q_table = std::make_unique<TensorQTable>(num_rows, TensorRow(column_name_id_map_.size(), nullptr));
|
||||
TensorTable in_cols(in_col_names_.size(), TensorRow(num_rows, nullptr)), out_cols;
|
||||
|
||||
std::unordered_map<std::string, size_t> in_col_name_id; // name of columns that need to be fed to per-batch_map
|
||||
|
@ -285,21 +288,33 @@ Status BatchOp::MapColumns(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo>
|
|||
for (size_t i = 0; i < num_rows; i++) {
|
||||
in_cols[col_itr->second][i] = std::move((*in_q_table)[i][itr.second]);
|
||||
}
|
||||
} else { // col needs to be placed into the out table
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(InvokeBatchMapFunc(&in_cols, &out_cols, table_pair->second));
|
||||
|
||||
auto out_q_table = std::make_unique<TensorQTable>(num_rows, TensorRow(column_name_id_map_.size(), nullptr));
|
||||
// If concat batch rows, the num_rows should be 1.
|
||||
if (concat_batch_) {
|
||||
out_q_table = std::make_unique<TensorQTable>(1, TensorRow(column_name_id_map_.size(), nullptr));
|
||||
}
|
||||
|
||||
for (const auto &itr : child_map_) {
|
||||
auto col_itr = in_col_name_id.find(itr.first);
|
||||
if (col_itr == in_col_name_id.end()) { // col needs to be prepared for per_batch_map
|
||||
// col needs to be placed into the out table
|
||||
size_t col_id = column_name_id_map_[itr.first];
|
||||
for (size_t i = 0; i < num_rows; i++) {
|
||||
(*out_q_table)[i][col_id] = std::move((*in_q_table)[i][itr.second]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
in_q_table.reset(); // release the input table
|
||||
RETURN_IF_NOT_OK(InvokeBatchMapFunc(&in_cols, &out_cols, table_pair->second));
|
||||
|
||||
for (size_t i = 0; i < out_cols.size(); i++) {
|
||||
size_t col_id = column_name_id_map_[out_col_names_[i]];
|
||||
size_t row_id = 0;
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(num_rows == out_cols[i].size(),
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(num_rows == out_cols[i].size() || concat_batch_,
|
||||
"Invalid data, column: " + out_col_names_[i] +
|
||||
" expects: " + std::to_string(num_rows) +
|
||||
" rows returned from 'per_batch_map', got: " + std::to_string(out_cols[i].size()));
|
||||
|
@ -398,12 +413,19 @@ Status BatchOp::InvokeBatchMapFunc(TensorTable *input, TensorTable *output, CBat
|
|||
<< " returned by per_batch_map is not a list, this could lead to conversion failure.";
|
||||
}
|
||||
|
||||
py::list output_list = py::cast<py::list>(ret_tuple[i]);
|
||||
|
||||
for (size_t j = 0; j < output_list.size(); j++) {
|
||||
if (py::isinstance<py::array>(ret_tuple[i])) {
|
||||
concat_batch_ = true;
|
||||
std::shared_ptr<Tensor> out;
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromNpArray(py::cast<py::array>(output_list[j]), &out));
|
||||
// If concat batch rows, the batch map function result should be in 1 row.
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromNpArray(py::cast<py::array>(ret_tuple[i]), &out));
|
||||
output_batch.push_back(std::move(out));
|
||||
} else {
|
||||
py::list output_list = py::cast<py::list>(ret_tuple[i]);
|
||||
for (size_t j = 0; j < output_list.size(); j++) {
|
||||
std::shared_ptr<Tensor> out;
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromNpArray(py::cast<py::array>(output_list[j]), &out));
|
||||
output_batch.push_back(std::move(out));
|
||||
}
|
||||
}
|
||||
output->push_back(std::move(output_batch));
|
||||
}
|
||||
|
|
|
@ -199,7 +199,8 @@ class BatchOp : public ParallelOp<std::pair<std::unique_ptr<TensorQTable>, CBatc
|
|||
// @param int32_t size - batch_size
|
||||
// @param const std::unordered_map<std::string, int32_t>& column_name_id_map - column names to index mapping
|
||||
// @return Status The status code returned
|
||||
static Status BatchRows(const std::unique_ptr<TensorQTable> *src, TensorRow *dest, dsize_t batch_size);
|
||||
static Status BatchRows(const std::unique_ptr<TensorQTable> *src, TensorRow *dest, dsize_t batch_size,
|
||||
bool concat_batch = false);
|
||||
|
||||
// @param table
|
||||
// @param const PadInfo &pad_info pad info
|
||||
|
@ -280,6 +281,7 @@ class BatchOp : public ParallelOp<std::pair<std::unique_ptr<TensorQTable>, CBatc
|
|||
Status ComputeColMap() override;
|
||||
|
||||
int32_t start_batch_size_;
|
||||
bool concat_batch_ = false; // bool for whether to concat batch rows
|
||||
const bool drop_; // bool for whether to drop remainder or not
|
||||
const bool pad_; // bool for whether to perform padding on tensor
|
||||
std::vector<std::string> in_col_names_; // input column name for per_batch_map
|
||||
|
|
|
@ -463,6 +463,35 @@ def test_multi_col_map():
|
|||
in batch_map_config(2, 2, split_col, ["col-1"], ["col_x", "col_y"])
|
||||
|
||||
|
||||
def test_multi_col_concat_map():
|
||||
"""
|
||||
Feature: Batch op
|
||||
Description: Test Batch op with multiple columns with concat per_batch_map args with valid inputs
|
||||
Expectation: Output is equal to the expected output for valid input
|
||||
"""
|
||||
def gen_2_cols(num):
|
||||
for i in range(1, 1 + num):
|
||||
yield np.array([i]), np.array([i ** 2])
|
||||
|
||||
def concat_col(col1, col2, batch_info):
|
||||
arg_list = []
|
||||
for arg in [col1, col2]:
|
||||
rows = []
|
||||
for value in arg:
|
||||
rows.append(value)
|
||||
arg_list.append(np.array(np.concatenate(rows, axis=0)))
|
||||
return tuple(arg_list)
|
||||
|
||||
dst = ds.GeneratorDataset((lambda: gen_2_cols(3)), ["col1", "col2"])
|
||||
dst = dst.batch(batch_size=3, input_columns=["col1", "col2"], output_columns=["col1", "col2"],
|
||||
per_batch_map=concat_col)
|
||||
res = []
|
||||
for row in dst.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
res.append(row)
|
||||
|
||||
assert np.array_equal(res[0]["col1"], [1, 2, 3]) and np.array_equal(res[0]["col2"], [1, 4, 9])
|
||||
|
||||
|
||||
def test_exceptions_2():
|
||||
"""
|
||||
Feature: Batch op
|
||||
|
|
Loading…
Reference in New Issue