diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc index 0f1f8c7ad07..5ec1270fd10 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc @@ -305,6 +305,9 @@ Status BatchOp::MapColumns(std::pair, CBatchInfo> for (size_t i = 0; i < out_cols.size(); i++) { size_t col_id = column_name_id_map_[out_col_names_[i]]; size_t row_id = 0; + CHECK_FAIL_RETURN_UNEXPECTED(num_rows == out_cols[i].size(), + "column: " + out_col_names_[i] + " expects: " + std::to_string(num_rows) + + " rows returned from per_batch_map, gets: " + std::to_string(out_cols[i].size())); for (auto &t_row : *out_q_table) { t_row[col_id] = out_cols[i][row_id++]; } @@ -334,7 +337,7 @@ Status BatchOp::InvokeBatchSizeFunc(int32_t *batch_size, CBatchInfo info) { // Acquire Python GIL py::gil_scoped_acquire gil_acquire; if (Py_IsInitialized() == 0) { - return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); + return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized."); } try { py::object size = batch_size_func_(info); @@ -350,7 +353,7 @@ Status BatchOp::InvokeBatchSizeFunc(int32_t *batch_size, CBatchInfo info) { "Invalid parameter, batch size function should return an integer greater than 0."); } } - return Status(StatusCode::kOK, "Batch size func call succeed"); + return Status(StatusCode::kOK, "Batch size func call succeed."); } Status BatchOp::InvokeBatchMapFunc(TensorTable *input, TensorTable *output, CBatchInfo info) { @@ -358,7 +361,7 @@ Status BatchOp::InvokeBatchMapFunc(TensorTable *input, TensorTable *output, CBat // Acquire Python GIL py::gil_scoped_acquire gil_acquire; if (Py_IsInitialized() == 0) { - return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); + return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter 
is finalized."); } try { // Prepare batch map call back parameters @@ -377,11 +380,24 @@ Status BatchOp::InvokeBatchMapFunc(TensorTable *input, TensorTable *output, CBat py::object ret_py_obj = batch_map_func_(*input_args); // Parse batch map return value py::tuple ret_tuple = py::cast(ret_py_obj); - CHECK_FAIL_RETURN_UNEXPECTED(py::isinstance(ret_tuple), "Batch map function should return a tuple"); - CHECK_FAIL_RETURN_UNEXPECTED(ret_tuple.size() == out_col_names_.size(), "Incorrect number of columns returned."); + CHECK_FAIL_RETURN_UNEXPECTED(py::isinstance(ret_tuple), "Batch map function should return a tuple."); + CHECK_FAIL_RETURN_UNEXPECTED( + ret_tuple.size() == out_col_names_.size(), + "Incorrect number of columns returned. expects: " + std::to_string(out_col_names_.size()) + + " gets: " + std::to_string(ret_tuple.size())); for (size_t i = 0; i < ret_tuple.size(); i++) { TensorRow output_batch; + // If user returns a type that is neither a list nor an array, issue an error msg. + if (py::isinstance(ret_tuple[i])) { + MS_LOG(WARNING) << "column: " << out_col_names_[i] + << " returned by per_batch_map is a np.array. Please use list instead."; + } else if (!py::isinstance(ret_tuple[i])) { + MS_LOG(ERROR) << "column: " << out_col_names_[i] + << " returned by per_batch_map is not a list, this could lead to conversion failure."; + } + py::list output_list = py::cast(ret_tuple[i]); + for (size_t j = 0; j < output_list.size(); j++) { std::shared_ptr out; RETURN_IF_NOT_OK(Tensor::CreateFromNpArray(py::cast(output_list[j]), &out)); diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 92bc8111681..bb351e598a3 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -287,7 +287,9 @@ class Dataset: per_batch_map (callable, optional): Per batch map callable. A callable which takes (list[Tensor], list[Tensor], ..., BatchInfo) as input parameters. 
Each list[Tensor] represents a batch of Tensors on a given column. The number of lists should match with number of entries in input_columns. - The last parameter of the callable should always be a BatchInfo object. + The last parameter of the callable should always be a BatchInfo object. Per_batch_map should return + (list[Tensor], list[Tensor], ...). The length of each list in output should be same as the input. + output_columns is required if the number of output lists is different from input. input_columns (list[str], optional): List of names of the input columns. The size of the list should match with signature of per_batch_map callable. output_columns (list[str], optional): List of names assigned to the columns @@ -1462,7 +1464,7 @@ class Dataset: data (Any): The data passed to the callback, user defined (default=None). """ if (not isinstance(num_batch, int) and num_batch is not None) or \ - (isinstance(num_batch, int) and num_batch <= 0): + (isinstance(num_batch, int) and num_batch <= 0): # throwing exception, disable all sync_wait in pipeline self.disable_sync() raise RuntimeError("Sync_update batch size can only be positive, got : {}.".format(num_batch)) diff --git a/tests/ut/python/dataset/test_var_batch_map.py b/tests/ut/python/dataset/test_var_batch_map.py index e0b8e34aef9..0a26b0621a7 100644 --- a/tests/ut/python/dataset/test_var_batch_map.py +++ b/tests/ut/python/dataset/test_var_batch_map.py @@ -382,9 +382,17 @@ def test_exceptions_2(): def simple_copy(colList, batchInfo): return ([np.copy(arr) for arr in colList],) - def test_wrong_col_name(gen_num, batch_size): - data1 = ds.GeneratorDataset((lambda: gen(gen_num)), ["num"]).batch(batch_size, input_columns=["num1"], - per_batch_map=simple_copy) + def concat_copy(colList, batchInfo): + # this will duplicate the number of rows returned, which would be wrong! 
+ return ([np.copy(arr) for arr in colList] * 2,) + + def shrink_copy(colList, batchInfo): + # this will halve the number of rows returned, which would be wrong! + return ([np.copy(arr) for arr in colList][0:int(len(colList) / 2)],) + + def test_exceptions_config(gen_num, batch_size, in_cols, per_batch_map): + data1 = ds.GeneratorDataset((lambda: gen(gen_num)), ["num"]).batch(batch_size, input_columns=in_cols, + per_batch_map=per_batch_map) try: for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True): pass @@ -393,7 +401,9 @@ def test_exceptions_2(): return str(e) # test exception where column name is incorrect - assert "error. col:num1 doesn't exist" in test_wrong_col_name(4, 2) + assert "error. col:num1 doesn't exist" in test_exceptions_config(4, 2, ["num1"], simple_copy) + assert "expects: 2 rows returned from per_batch_map, gets: 4" in test_exceptions_config(4, 2, ["num"], concat_copy) + assert "expects: 4 rows returned from per_batch_map, gets: 2" in test_exceptions_config(4, 4, ["num"], shrink_copy) if __name__ == '__main__':