forked from mindspore-Ecosystem/mindspore
per_batch_map input column optimization
This commit is contained in:
parent
fd50ac5dbf
commit
0dfd42153f
|
@ -249,7 +249,7 @@ Status BatchOp::WorkerEntry(int32_t workerId) {
|
|||
Status BatchOp::MakeBatchedRow(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> table_pair, TensorRow *new_row) {
|
||||
RETURN_UNEXPECTED_IF_NULL(table_pair.first);
|
||||
#ifdef ENABLE_PYTHON
|
||||
if (!in_col_names_.empty()) {
|
||||
if (batch_map_func_) {
|
||||
RETURN_IF_NOT_OK(MapColumns(&table_pair));
|
||||
} // pass it through pyfun
|
||||
#endif
|
||||
|
@ -513,6 +513,27 @@ Status BatchOp::ComputeColMap() {
|
|||
CHECK_FAIL_RETURN_UNEXPECTED(!(child_[0]->column_name_id_map().empty()),
|
||||
"Invalid data, the column of the previous operator of the batch cannot be empty.");
|
||||
|
||||
// per_batch_map set, input_columns not specified: takes the ENABLE_PYTHON branch below
|
||||
// per_batch_map set, input_columns specified: skips both the ENABLE_PYTHON branch
|
||||
// and the in_col_names_.empty() branch
|
||||
// per_batch_map not set, input_columns not specified: takes the in_col_names_.empty() branch
|
||||
// per_batch_map not set, input_columns specified: ERROR
|
||||
#ifdef ENABLE_PYTHON
|
||||
// when per_batch_map is set and input_columns is not specified, input_columns is inferred from the child's column names, ordered by column index
|
||||
if (batch_map_func_ && in_col_names_.empty()) {
|
||||
auto column_name = child_[0]->column_name_id_map();
|
||||
std::vector<std::pair<std::string, int32_t>> tmp;
|
||||
std::copy(column_name.begin(), column_name.end(), std::back_inserter(tmp));
|
||||
std::sort(tmp.begin(), tmp.end(),
|
||||
[=](const std::pair<std::string, int32_t> &a, const std::pair<std::string, int32_t> &b) {
|
||||
return a.second < b.second;
|
||||
});
|
||||
for (auto &it : tmp) {
|
||||
in_col_names_.emplace_back(it.first);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
if (in_col_names_.empty()) { // if per_batch_map is not set, do not need to deal with out_col_names
|
||||
column_name_id_map_ = child_[0]->column_name_id_map();
|
||||
return Status::OK();
|
||||
|
|
|
@ -282,7 +282,7 @@ class BatchOp : public ParallelOp<std::pair<std::unique_ptr<TensorQTable>, CBatc
|
|||
int32_t start_batch_size_;
|
||||
const bool drop_; // bool for whether to drop remainder or not
|
||||
const bool pad_; // bool for whether to perform padding on tensor
|
||||
const std::vector<std::string> in_col_names_; // input column name for per_batch_map
|
||||
std::vector<std::string> in_col_names_; // input column name for per_batch_map
|
||||
std::vector<std::string> out_col_names_; // output column name for per_batch_map
|
||||
PadInfo pad_info_; // column names to perform padding on
|
||||
std::unique_ptr<ChildIterator> child_iterator_; // child iterator for fetching TensorRows 1 by 1
|
||||
|
|
|
@ -86,8 +86,8 @@ Status BatchNode::ValidateParams() {
|
|||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
if (batch_map_func_ && in_col_names_.empty()) {
|
||||
std::string err_msg = "Batch: 'in_col_names' cannot be empty when per_batch_map is used.";
|
||||
if (!in_col_names_.empty() && !batch_map_func_) {
|
||||
std::string err_msg = "Batch: per_batch_map needs to be specified when input_columns is set.";
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -1258,9 +1258,9 @@ def check_batch(method):
|
|||
for k, v in param_dict.get('pad_info').items():
|
||||
check_pad_info(k, v)
|
||||
|
||||
if (per_batch_map is None) != (input_columns is None):
|
||||
# These two parameters appear together.
|
||||
raise ValueError("per_batch_map and input_columns need to be passed in together.")
|
||||
if (input_columns is not None) and (per_batch_map is None):
|
||||
# input_columns must be None when per_batch_map is not set
|
||||
raise ValueError("input_columns can be specified only when per_batch_map is set.")
|
||||
|
||||
if input_columns is not None:
|
||||
check_columns(input_columns, "input_columns")
|
||||
|
|
|
@ -12,6 +12,8 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import numpy as np
|
||||
|
||||
import mindspore.dataset as ds
|
||||
from mindspore import log as logger
|
||||
from util import save_and_check_dict
|
||||
|
@ -531,8 +533,7 @@ def test_batch_exception_14():
|
|||
try:
|
||||
_ = data1.batch(batch_size=batch_size, input_columns=input_columns)
|
||||
except ValueError as e:
|
||||
assert "per_batch_map and input_columns need to be passed in together." in str(
|
||||
e)
|
||||
assert "input_columns can be specified only when per_batch_map is set." in str(e)
|
||||
|
||||
|
||||
def test_batch_exception_15():
|
||||
|
@ -553,6 +554,64 @@ def test_batch_exception_15():
|
|||
assert "batch_size is not within the required interval of [1, 2147483647]" in err_msg
|
||||
|
||||
|
||||
def test_no_input_columns_01():
    """
    Feature: Batch op
    Description: Test with per_batch_map has value but input_columns has no value
    Expectation: Output is equal to the expected output
    """
    def gen_2_cols(num):
        # Yield rows (i, i**2) for i in 1..num, each value wrapped in a 1-element array.
        for i in range(1, 1 + num):
            yield (np.array([i]), np.array([i ** 2]))

    def swap_col(col1, col2, batch_info):
        # batch_info is accepted because per_batch_map passes it, but it is unused here.
        return ([np.copy(a) for a in col2], [np.copy(b) for b in col1])

    def batch_map_config(num, s, f, col_order=None):
        # Exceptions are intentionally NOT caught: returning str(e) here would let the
        # caller index into the error string and hide the real failure behind a TypeError.
        dst = ds.GeneratorDataset((lambda: gen_2_cols(num)), ["col1", "col2"])
        dst = dst.batch(batch_size=s, per_batch_map=f, column_order=col_order)
        return [row for row in dst.create_dict_iterator(num_epochs=1, output_numpy=True)]

    # 3 rows with batch_size 3 -> exactly one batch; swap_col exchanges the two columns.
    batches = batch_map_config(3, 3, swap_col)
    assert len(batches) == 1
    res = batches[0]
    assert np.array_equal(res["col1"], [[1], [4], [9]])
    assert np.array_equal(res["col2"], [[1], [2], [3]])
|
||||
|
||||
|
||||
def test_no_input_columns_02():
    """
    Feature: Batch op
    Description: Test per_batch_map has value but input_columns has no value and given output_columns parameter
    Expectation: Output is equal to the expected output
    """
    def gen_2_cols(num):
        # Yield rows (i, i**2) for i in 1..num, each value wrapped in a 1-element array.
        for i in range(1, 1 + num):
            yield (np.array([i]), np.array([i ** 2]))

    def split_col(col1, col2, batch_info):
        # Keep col1 as-is and emit col2 twice: once copied, once negated.
        # batch_info is accepted because per_batch_map passes it, but it is unused here.
        return (col1, [np.copy(arr) for arr in col2], [np.copy(-arr) for arr in col2])

    def batch_map_config(num, s, f, out_nms, col_order=None):
        # Exceptions are intentionally NOT caught: returning str(e) here would let the
        # caller index into the error string and hide the real failure behind a TypeError.
        dst = ds.GeneratorDataset((lambda: gen_2_cols(num)), ["col1", "col2"])
        dst = dst.batch(batch_size=s, per_batch_map=f, output_columns=out_nms, column_order=col_order)
        return [row for row in dst.create_dict_iterator(num_epochs=1, output_numpy=True)]

    # split 2 col into 3 cols
    batches = batch_map_config(3, 3, split_col, ["col1", "col_x2", "col_y2"])
    assert len(batches) == 1
    res = batches[0]
    assert np.array_equal(res["col1"], [[1], [2], [3]])
    assert np.array_equal(res["col_x2"], [[1], [4], [9]])
    assert np.array_equal(res["col_y2"], [[-1], [-4], [-9]])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_batch_01()
|
||||
test_batch_02()
|
||||
|
@ -581,4 +640,6 @@ if __name__ == '__main__':
|
|||
test_batch_exception_13()
|
||||
test_batch_exception_14()
|
||||
test_batch_exception_15()
|
||||
test_no_input_columns_01()
|
||||
test_no_input_columns_02()
|
||||
logger.info('\n')
|
||||
|
|
Loading…
Reference in New Issue