forked from mindspore-Ecosystem/mindspore
Branch: staging
commit 77521e78d2 (parent 2c5123300f)
@@ -305,6 +305,9 @@ Status BatchOp::MapColumns(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo>
   for (size_t i = 0; i < out_cols.size(); i++) {
     size_t col_id = column_name_id_map_[out_col_names_[i]];
     size_t row_id = 0;
+    CHECK_FAIL_RETURN_UNEXPECTED(num_rows == out_cols[i].size(),
+                                 "column: " + out_col_names_[i] + " expects: " + std::to_string(num_rows) +
+                                   " rows returned from per_batch_map, gets: " + std::to_string(out_cols[i].size()));
     for (auto &t_row : *out_q_table) {
       t_row[col_id] = out_cols[i][row_id++];
     }
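As a rough illustration of the contract this new check enforces, here is a minimal sketch (the gen and add_one helpers and the "col1" column name are invented for illustration, not part of this commit): per_batch_map must hand back exactly as many rows per column as it received, otherwise BatchOp::MapColumns now fails with the "expects ... rows returned from per_batch_map" message.

import numpy as np
import mindspore.dataset as ds

def gen():
    for i in range(8):
        yield (np.array([i]),)

# Row-preserving per_batch_map: one output row per input row, so the
# num_rows == out_cols[i].size() check above is satisfied.
def add_one(col, batch_info):
    return ([arr + 1 for arr in col],)

data = ds.GeneratorDataset(gen, ["col1"])
data = data.batch(4, input_columns=["col1"], per_batch_map=add_one)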
@@ -334,7 +337,7 @@ Status BatchOp::InvokeBatchSizeFunc(int32_t *batch_size, CBatchInfo info) {
   // Acquire Python GIL
   py::gil_scoped_acquire gil_acquire;
   if (Py_IsInitialized() == 0) {
-    return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
+    return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized.");
   }
   try {
     py::object size = batch_size_func_(info);
@@ -350,7 +353,7 @@ Status BatchOp::InvokeBatchSizeFunc(int32_t *batch_size, CBatchInfo info) {
                                    "Invalid parameter, batch size function should return an integer greater than 0.");
     }
   }
-  return Status(StatusCode::kOK, "Batch size func call succeed");
+  return Status(StatusCode::kOK, "Batch size func call succeed.");
 }
 
 Status BatchOp::InvokeBatchMapFunc(TensorTable *input, TensorTable *output, CBatchInfo info) {
@@ -358,7 +361,7 @@ Status BatchOp::InvokeBatchMapFunc(TensorTable *input, TensorTable *output, CBatchInfo info) {
   // Acquire Python GIL
   py::gil_scoped_acquire gil_acquire;
   if (Py_IsInitialized() == 0) {
-    return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
+    return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized.");
   }
   try {
     // Prepare batch map call back parameters
@@ -377,11 +380,24 @@ Status BatchOp::InvokeBatchMapFunc(TensorTable *input, TensorTable *output, CBatchInfo info) {
     py::object ret_py_obj = batch_map_func_(*input_args);
     // Parse batch map return value
     py::tuple ret_tuple = py::cast<py::tuple>(ret_py_obj);
-    CHECK_FAIL_RETURN_UNEXPECTED(py::isinstance<py::tuple>(ret_tuple), "Batch map function should return a tuple");
-    CHECK_FAIL_RETURN_UNEXPECTED(ret_tuple.size() == out_col_names_.size(), "Incorrect number of columns returned.");
+    CHECK_FAIL_RETURN_UNEXPECTED(py::isinstance<py::tuple>(ret_tuple), "Batch map function should return a tuple.");
+    CHECK_FAIL_RETURN_UNEXPECTED(
+      ret_tuple.size() == out_col_names_.size(),
+      "Incorrect number of columns returned. expects: " + std::to_string(out_col_names_.size()) +
+        " gets: " + std::to_string(ret_tuple.size()));
     for (size_t i = 0; i < ret_tuple.size(); i++) {
       TensorRow output_batch;
+      // If user returns a type that is neither a list nor an array, issue an error msg.
+      if (py::isinstance<py::array>(ret_tuple[i])) {
+        MS_LOG(WARNING) << "column: " << out_col_names_[i]
+                        << " returned by per_batch_map is a np.array. Please use list instead.";
+      } else if (!py::isinstance<py::list>(ret_tuple[i])) {
+        MS_LOG(ERROR) << "column: " << out_col_names_[i]
+                      << " returned by per_batch_map is not a list, this could lead to conversion failure.";
+      }
+
       py::list output_list = py::cast<py::list>(ret_tuple[i]);
+
       for (size_t j = 0; j < output_list.size(); j++) {
         std::shared_ptr<Tensor> out;
         RETURN_IF_NOT_OK(Tensor::CreateFromNpArray(py::cast<py::array>(output_list[j]), &out));
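A hedged sketch of what the WARNING/ERROR paths above steer users toward (the copy_rows name and column shape are illustrative, not from this commit): each entry of the returned tuple should be a Python list of np.ndarray rows, not one stacked np.ndarray.

import numpy as np

# Recommended shape of a per_batch_map return value, per the checks above:
# a tuple holding one Python list of np.ndarray per output column.
def copy_rows(col, batch_info):
    return (list(col),)

# Returning the column as a single np.ndarray instead, e.g.
#     return (np.stack(col),)
# now logs the "returned by per_batch_map is a np.array" WARNING above.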
@@ -287,7 +287,9 @@ class Dataset:
             per_batch_map (callable, optional): Per batch map callable. A callable which takes
                 (list[Tensor], list[Tensor], ..., BatchInfo) as input parameters. Each list[Tensor] represents a batch
                 of Tensors on a given column. The number of lists should match the number of entries in input_columns.
-                The last parameter of the callable should always be a BatchInfo object.
+                The last parameter of the callable should always be a BatchInfo object. per_batch_map should return
+                (list[Tensor], list[Tensor], ...). The length of each output list should be the same as that of the input.
+                output_columns is required if the number of output lists differs from the number of input lists.
             input_columns (list[str], optional): List of names of the input columns. The size of the list should
                 match the signature of the per_batch_map callable.
             output_columns (list[str], optional): List of names assigned to the columns
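To make the expanded docstring concrete, a minimal usage sketch (the gen and split helpers and the column names are invented for illustration): one input column goes in, two columns come out, so output_columns must be supplied.

import numpy as np
import mindspore.dataset as ds

def gen():
    for i in range(4):
        yield (np.array([i]),)

# One list[Tensor] in, two list[Tensor] out; each output list keeps the input
# length, and output_columns names the two resulting columns.
def split(col, batch_info):
    return (list(col), [arr * 2 for arr in col])

data = ds.GeneratorDataset(gen, ["num"])
data = data.batch(2, input_columns=["num"], output_columns=["num", "num_x2"],
                  per_batch_map=split)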
@@ -1462,7 +1464,7 @@ class Dataset:
             data (Any): The data passed to the callback, user defined (default=None).
         """
         if (not isinstance(num_batch, int) and num_batch is not None) or \
-           (isinstance(num_batch, int) and num_batch <= 0):
+                (isinstance(num_batch, int) and num_batch <= 0):
             # throwing exception, disable all sync_wait in pipeline
             self.disable_sync()
             raise RuntimeError("Sync_update batch size can only be positive, got : {}.".format(num_batch))
@@ -382,9 +382,17 @@ def test_exceptions_2():
     def simple_copy(colList, batchInfo):
         return ([np.copy(arr) for arr in colList],)
 
-    def test_wrong_col_name(gen_num, batch_size):
-        data1 = ds.GeneratorDataset((lambda: gen(gen_num)), ["num"]).batch(batch_size, input_columns=["num1"],
-                                                                           per_batch_map=simple_copy)
+    def concat_copy(colList, batchInfo):
+        # this will duplicate the number of rows returned, which would be wrong!
+        return ([np.copy(arr) for arr in colList] * 2,)
+
+    def shrink_copy(colList, batchInfo):
+        # this will halve the number of rows returned, which would be wrong!
+        return ([np.copy(arr) for arr in colList][0:int(len(colList) / 2)],)
+
+    def test_exceptions_config(gen_num, batch_size, in_cols, per_batch_map):
+        data1 = ds.GeneratorDataset((lambda: gen(gen_num)), ["num"]).batch(batch_size, input_columns=in_cols,
+                                                                           per_batch_map=per_batch_map)
         try:
             for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
                 pass
@@ -393,7 +401,9 @@ def test_exceptions_2():
             return str(e)
 
     # test exception where column name is incorrect
-    assert "error. col:num1 doesn't exist" in test_wrong_col_name(4, 2)
+    assert "error. col:num1 doesn't exist" in test_exceptions_config(4, 2, ["num1"], simple_copy)
+    assert "expects: 2 rows returned from per_batch_map, gets: 4" in test_exceptions_config(4, 2, ["num"], concat_copy)
+    assert "expects: 4 rows returned from per_batch_map, gets: 2" in test_exceptions_config(4, 4, ["num"], shrink_copy)
 
 
 if __name__ == '__main__':