per_batch_map input column optimization

This commit is contained in:
liu-yongqi-63 2022-06-08 11:16:34 +08:00
parent fd50ac5dbf
commit 0dfd42153f
5 changed files with 91 additions and 9 deletions

View File

@ -249,7 +249,7 @@ Status BatchOp::WorkerEntry(int32_t workerId) {
Status BatchOp::MakeBatchedRow(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> table_pair, TensorRow *new_row) {
RETURN_UNEXPECTED_IF_NULL(table_pair.first);
#ifdef ENABLE_PYTHON
if (!in_col_names_.empty()) {
if (batch_map_func_) {
RETURN_IF_NOT_OK(MapColumns(&table_pair));
} // pass it through pyfun
#endif
@ -513,6 +513,27 @@ Status BatchOp::ComputeColMap() {
CHECK_FAIL_RETURN_UNEXPECTED(!(child_[0]->column_name_id_map().empty()),
"Invalid data, the column of the previous operator of the batch cannot be empty.");
// when per_batch_map is set and input_columns is not specified: take the ENABLE_PYTHON branch
// (the input columns are deduced there)
// when per_batch_map is set and input_columns is specified: take neither the ENABLE_PYTHON branch
// nor the in_col_names_.empty() branch
// when per_batch_map is not set and input_columns is not specified: take the in_col_names_.empty() branch
// when per_batch_map is not set and input_columns is specified: ERROR (rejected during validation)
#ifdef ENABLE_PYTHON
// when per_batch_map is set and input_columns is not specified, input_columns will be automatically speculated
if (batch_map_func_ && in_col_names_.empty()) {
auto column_name = child_[0]->column_name_id_map();
std::vector<std::pair<std::string, int32_t>> tmp;
std::copy(column_name.begin(), column_name.end(), std::back_inserter(tmp));
std::sort(tmp.begin(), tmp.end(),
[=](const std::pair<std::string, int32_t> &a, const std::pair<std::string, int32_t> &b) {
return a.second < b.second;
});
for (auto &it : tmp) {
in_col_names_.emplace_back(it.first);
}
}
#endif
if (in_col_names_.empty()) { // if per_batch_map is not set, do not need to deal with out_col_names
column_name_id_map_ = child_[0]->column_name_id_map();
return Status::OK();

View File

@ -282,7 +282,7 @@ class BatchOp : public ParallelOp<std::pair<std::unique_ptr<TensorQTable>, CBatc
int32_t start_batch_size_;
const bool drop_; // bool for whether to drop remainder or not
const bool pad_; // bool for whether to perform padding on tensor
const std::vector<std::string> in_col_names_; // input column name for per_batch_map
std::vector<std::string> in_col_names_; // input column name for per_batch_map
std::vector<std::string> out_col_names_; // output column name for per_batch_map
PadInfo pad_info_; // column names to perform padding on
std::unique_ptr<ChildIterator> child_iterator_; // child iterator for fetching TensorRows 1 by 1

View File

@ -86,8 +86,8 @@ Status BatchNode::ValidateParams() {
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
if (batch_map_func_ && in_col_names_.empty()) {
std::string err_msg = "Batch: 'in_col_names' cannot be empty when per_batch_map is used.";
if (!in_col_names_.empty() && !batch_map_func_) {
std::string err_msg = "Batch: per_batch_map needs to be specified when input_columns is set.";
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
#endif

View File

@ -1258,9 +1258,9 @@ def check_batch(method):
for k, v in param_dict.get('pad_info').items():
check_pad_info(k, v)
if (per_batch_map is None) != (input_columns is None):
# These two parameters appear together.
raise ValueError("per_batch_map and input_columns need to be passed in together.")
if (input_columns is not None) and (per_batch_map is None):
# input_columns must be None when per_batch_map is not set
raise ValueError("input_columns can be specified only when per_batch_map is set.")
if input_columns is not None:
check_columns(input_columns, "input_columns")

View File

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import mindspore.dataset as ds
from mindspore import log as logger
from util import save_and_check_dict
@ -531,8 +533,7 @@ def test_batch_exception_14():
try:
_ = data1.batch(batch_size=batch_size, input_columns=input_columns)
except ValueError as e:
assert "per_batch_map and input_columns need to be passed in together." in str(
e)
assert "input_columns can be specified only when per_batch_map is set." in str(e)
def test_batch_exception_15():
@ -553,6 +554,64 @@ def test_batch_exception_15():
assert "batch_size is not within the required interval of [1, 2147483647]" in err_msg
def test_no_input_columns_01():
    """
    Feature: Batch op
    Description: per_batch_map is provided while input_columns is left unset,
        so the input columns must be deduced automatically by the batch op
    Expectation: Output is equal to the expected output
    """
    def two_col_generator(count):
        # Yield pairs (i, i**2) as single-element numpy arrays.
        for value in range(1, count + 1):
            yield (np.array([value]), np.array([value ** 2]))

    def swap_col(col1, col2, batch_info):
        # Exchange the two columns, copying every batch element.
        return ([np.copy(item) for item in col2], [np.copy(item) for item in col1])

    def run_batch(count, batch_size, mapper, col_order=None):
        # Build the pipeline and collect all rows; on failure return the error text.
        try:
            pipeline = ds.GeneratorDataset((lambda: two_col_generator(count)), ["col1", "col2"])
            pipeline = pipeline.batch(batch_size=batch_size, per_batch_map=mapper, column_order=col_order)
            return [row for row in pipeline.create_dict_iterator(num_epochs=1, output_numpy=True)]
        except (ValueError, RuntimeError, TypeError) as err:
            return str(err)

    first = run_batch(3, 3, swap_col)[0]
    # Columns were swapped: "col1" now holds the squares, "col2" the bases.
    assert np.array_equal(first["col1"], [[1], [4], [9]]) and np.array_equal(first["col2"], [[1], [2], [3]])
def test_no_input_columns_02():
    """
    Feature: Batch op
    Description: per_batch_map and output_columns are provided while
        input_columns is left unset
    Expectation: Output is equal to the expected output
    """
    def two_col_generator(count):
        # Yield pairs (i, i**2) as single-element numpy arrays.
        for value in range(1, count + 1):
            yield (np.array([value]), np.array([value ** 2]))

    def split_col(col1, col2, batch_info):
        # Expand two columns into three: col1 untouched, a copy of col2, and its negation.
        return (col1, [np.copy(arr) for arr in col2], [np.copy(-arr) for arr in col2])

    def run_batch(count, batch_size, mapper, out_nms, col_order=None):
        # Build the pipeline and collect all rows; on failure return the error text.
        try:
            pipeline = ds.GeneratorDataset((lambda: two_col_generator(count)), ["col1", "col2"])
            pipeline = pipeline.batch(batch_size=batch_size, per_batch_map=mapper,
                                      output_columns=out_nms, column_order=col_order)
            return [row for row in pipeline.create_dict_iterator(num_epochs=1, output_numpy=True)]
        except (ValueError, RuntimeError, TypeError) as err:
            return str(err)

    # The mapper splits 2 columns into 3.
    first = run_batch(3, 3, split_col, ["col1", "col_x2", "col_y2"])[0]
    assert np.array_equal(first["col1"], [[1], [2], [3]])
    assert np.array_equal(first["col_x2"], [[1], [4], [9]]) and np.array_equal(first["col_y2"], [[-1], [-4], [-9]])
if __name__ == '__main__':
test_batch_01()
test_batch_02()
@ -581,4 +640,6 @@ if __name__ == '__main__':
test_batch_exception_13()
test_batch_exception_14()
test_batch_exception_15()
test_no_input_columns_01()
test_no_input_columns_02()
logger.info('\n')