Zirui Wu 2020-12-10 12:19:44 -05:00
parent 2c5123300f
commit 77521e78d2
3 changed files with 39 additions and 11 deletions

View File

@@ -305,6 +305,9 @@ Status BatchOp::MapColumns(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo>
for (size_t i = 0; i < out_cols.size(); i++) {
size_t col_id = column_name_id_map_[out_col_names_[i]];
size_t row_id = 0;
+CHECK_FAIL_RETURN_UNEXPECTED(num_rows == out_cols[i].size(),
+"column: " + out_col_names_[i] + " expects: " + std::to_string(num_rows) +
+" rows returned from per_batch_map, gets: " + std::to_string(out_cols[i].size()));
for (auto &t_row : *out_q_table) {
t_row[col_id] = out_cols[i][row_id++];
}
@@ -334,7 +337,7 @@ Status BatchOp::InvokeBatchSizeFunc(int32_t *batch_size, CBatchInfo info) {
// Acquire Python GIL
py::gil_scoped_acquire gil_acquire;
if (Py_IsInitialized() == 0) {
-return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
+return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized.");
}
try {
py::object size = batch_size_func_(info);
@@ -350,7 +353,7 @@ Status BatchOp::InvokeBatchSizeFunc(int32_t *batch_size, CBatchInfo info) {
"Invalid parameter, batch size function should return an integer greater than 0.");
}
}
-return Status(StatusCode::kOK, "Batch size func call succeed");
+return Status(StatusCode::kOK, "Batch size func call succeed.");
}
Status BatchOp::InvokeBatchMapFunc(TensorTable *input, TensorTable *output, CBatchInfo info) {
@@ -358,7 +361,7 @@ Status BatchOp::InvokeBatchMapFunc(TensorTable *input, TensorTable *output, CBat
// Acquire Python GIL
py::gil_scoped_acquire gil_acquire;
if (Py_IsInitialized() == 0) {
-return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
+return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized.");
}
try {
// Prepare batch map call back parameters
@@ -377,11 +380,24 @@ Status BatchOp::InvokeBatchMapFunc(TensorTable *input, TensorTable *output, CBat
py::object ret_py_obj = batch_map_func_(*input_args);
// Parse batch map return value
py::tuple ret_tuple = py::cast<py::tuple>(ret_py_obj);
-CHECK_FAIL_RETURN_UNEXPECTED(py::isinstance<py::tuple>(ret_tuple), "Batch map function should return a tuple");
-CHECK_FAIL_RETURN_UNEXPECTED(ret_tuple.size() == out_col_names_.size(), "Incorrect number of columns returned.");
+CHECK_FAIL_RETURN_UNEXPECTED(py::isinstance<py::tuple>(ret_tuple), "Batch map function should return a tuple.");
+CHECK_FAIL_RETURN_UNEXPECTED(
+ret_tuple.size() == out_col_names_.size(),
+"Incorrect number of columns returned. expects: " + std::to_string(out_col_names_.size()) +
+" gets: " + std::to_string(ret_tuple.size()));
for (size_t i = 0; i < ret_tuple.size(); i++) {
TensorRow output_batch;
+// If the user returns a type that is neither a list nor an array, issue an error message.
+if (py::isinstance<py::array>(ret_tuple[i])) {
+MS_LOG(WARNING) << "column: " << out_col_names_[i]
+<< " returned by per_batch_map is a np.array. Please use list instead.";
+} else if (!py::isinstance<py::list>(ret_tuple[i])) {
+MS_LOG(ERROR) << "column: " << out_col_names_[i]
+<< " returned by per_batch_map is not a list, this could lead to conversion failure.";
+}
py::list output_list = py::cast<py::list>(ret_tuple[i]);
for (size_t j = 0; j < output_list.size(); j++) {
std::shared_ptr<Tensor> out;
RETURN_IF_NOT_OK(Tensor::CreateFromNpArray(py::cast<py::array>(output_list[j]), &out));
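Taken together, the new checks in this file expect per_batch_map to hand back a tuple with one Python list per output column, each list keeping exactly the row count it received; a column returned as a bare np.array now logs a warning, and a changed row count fails the batch. Below is a minimal sketch of a conforming callable. The pipeline mirrors the test file further down (simple_copy, the "num" column, and the create_dict_iterator call appear there); the body of gen and the batch size are illustrative, since the diff does not show them.

import numpy as np
import mindspore.dataset as ds

def gen(num):
    # Illustrative stand-in for the test file's gen: one-column rows.
    for i in range(num):
        yield (np.array([i]),)

def simple_copy(col_num, batch_info):
    # One list in, one list out, same number of rows, returned as a plain
    # Python list (not a np.array), so none of the new checks fire.
    return ([np.copy(arr) for arr in col_num],)

data = ds.GeneratorDataset(lambda: gen(4), ["num"])
data = data.batch(2, input_columns=["num"], per_batch_map=simple_copy)
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
    pass  # a concat_copy/shrink_copy-style map would fail the row-count check instead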

View File

@@ -287,7 +287,9 @@ class Dataset:
per_batch_map (callable, optional): Per batch map callable. A callable which takes
(list[Tensor], list[Tensor], ..., BatchInfo) as input parameters. Each list[Tensor] represents a batch
of Tensors on a given column. The number of lists should match with number of entries in input_columns.
-The last parameter of the callable should always be a BatchInfo object.
+The last parameter of the callable should always be a BatchInfo object. per_batch_map should return
+(list[Tensor], list[Tensor], ...). The length of each output list should match that of its input list.
+output_columns is required if the number of output lists differs from the number of input lists.
input_columns (list[str], optional): List of names of the input columns. The size of the list should
match with signature of per_batch_map callable.
output_columns (list[str], optional): List of names assigned to the columns
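As a concrete illustration of the requirement described above, a hedged sketch: the column names, the split_batch callable, the gen source, and the batch size here are made up for this example and are not part of the commit. When per_batch_map returns more lists than it was given, output_columns must name every resulting column.

import numpy as np
import mindspore.dataset as ds

def split_batch(col_x, batch_info):
    # One input list, two output lists; both keep the incoming row count.
    doubled = [v * 2 for v in col_x]
    squared = [v ** 2 for v in col_x]
    return (doubled, squared)

def gen(num):
    for i in range(num):
        yield (np.array([i]),)

data = ds.GeneratorDataset(lambda: gen(6), ["x"])
# Two output lists from one input column, so output_columns is required here.
data = data.batch(3, input_columns=["x"], per_batch_map=split_batch,
                  output_columns=["x_doubled", "x_squared"])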
@@ -1462,7 +1464,7 @@ class Dataset:
data (Any): The data passed to the callback, user defined (default=None).
"""
if (not isinstance(num_batch, int) and num_batch is not None) or \
(isinstance(num_batch, int) and num_batch <= 0):
# throwing exception, disable all sync_wait in pipeline
self.disable_sync()
raise RuntimeError("Sync_update batch size can only be positive, got : {}.".format(num_batch))

View File

@@ -382,9 +382,17 @@ def test_exceptions_2():
def simple_copy(colList, batchInfo):
return ([np.copy(arr) for arr in colList],)
-def test_wrong_col_name(gen_num, batch_size):
-data1 = ds.GeneratorDataset((lambda: gen(gen_num)), ["num"]).batch(batch_size, input_columns=["num1"],
-per_batch_map=simple_copy)
+def concat_copy(colList, batchInfo):
+# this will duplicate the number of rows returned, which would be wrong!
+return ([np.copy(arr) for arr in colList] * 2,)
+def shrink_copy(colList, batchInfo):
+# this will halve the number of rows returned, which would be wrong!
+return ([np.copy(arr) for arr in colList][0:int(len(colList) / 2)],)
+def test_exceptions_config(gen_num, batch_size, in_cols, per_batch_map):
+data1 = ds.GeneratorDataset((lambda: gen(gen_num)), ["num"]).batch(batch_size, input_columns=in_cols,
+per_batch_map=per_batch_map)
try:
for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
pass
@@ -393,7 +401,9 @@ def test_exceptions_2():
return str(e)
# test exception where column name is incorrect
assert "error. col:num1 doesn't exist" in test_wrong_col_name(4, 2)
assert "error. col:num1 doesn't exist" in test_exceptions_config(4, 2, ["num1"], simple_copy)
assert "expects: 2 rows returned from per_batch_map, gets: 4" in test_exceptions_config(4, 2, ["num"], concat_copy)
assert "expects: 4 rows returned from per_batch_map, gets: 2" in test_exceptions_config(4, 4, ["num"], shrink_copy)
if __name__ == '__main__':