forked from mindspore-Ecosystem/mindspore
staging
This commit is contained in:
parent
2c5123300f
commit
77521e78d2
|
@ -305,6 +305,9 @@ Status BatchOp::MapColumns(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo>
|
||||||
for (size_t i = 0; i < out_cols.size(); i++) {
|
for (size_t i = 0; i < out_cols.size(); i++) {
|
||||||
size_t col_id = column_name_id_map_[out_col_names_[i]];
|
size_t col_id = column_name_id_map_[out_col_names_[i]];
|
||||||
size_t row_id = 0;
|
size_t row_id = 0;
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(num_rows == out_cols[i].size(),
|
||||||
|
"column: " + out_col_names_[i] + " expects: " + std::to_string(num_rows) +
|
||||||
|
" rows returned from per_batch_map, gets: " + std::to_string(out_cols[i].size()));
|
||||||
for (auto &t_row : *out_q_table) {
|
for (auto &t_row : *out_q_table) {
|
||||||
t_row[col_id] = out_cols[i][row_id++];
|
t_row[col_id] = out_cols[i][row_id++];
|
||||||
}
|
}
|
||||||
|
@ -334,7 +337,7 @@ Status BatchOp::InvokeBatchSizeFunc(int32_t *batch_size, CBatchInfo info) {
|
||||||
// Acquire Python GIL
|
// Acquire Python GIL
|
||||||
py::gil_scoped_acquire gil_acquire;
|
py::gil_scoped_acquire gil_acquire;
|
||||||
if (Py_IsInitialized() == 0) {
|
if (Py_IsInitialized() == 0) {
|
||||||
return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
|
return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized.");
|
||||||
}
|
}
|
||||||
try {
|
try {
|
||||||
py::object size = batch_size_func_(info);
|
py::object size = batch_size_func_(info);
|
||||||
|
@ -350,7 +353,7 @@ Status BatchOp::InvokeBatchSizeFunc(int32_t *batch_size, CBatchInfo info) {
|
||||||
"Invalid parameter, batch size function should return an integer greater than 0.");
|
"Invalid parameter, batch size function should return an integer greater than 0.");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return Status(StatusCode::kOK, "Batch size func call succeed");
|
return Status(StatusCode::kOK, "Batch size func call succeed.");
|
||||||
}
|
}
|
||||||
|
|
||||||
Status BatchOp::InvokeBatchMapFunc(TensorTable *input, TensorTable *output, CBatchInfo info) {
|
Status BatchOp::InvokeBatchMapFunc(TensorTable *input, TensorTable *output, CBatchInfo info) {
|
||||||
|
@ -358,7 +361,7 @@ Status BatchOp::InvokeBatchMapFunc(TensorTable *input, TensorTable *output, CBat
|
||||||
// Acquire Python GIL
|
// Acquire Python GIL
|
||||||
py::gil_scoped_acquire gil_acquire;
|
py::gil_scoped_acquire gil_acquire;
|
||||||
if (Py_IsInitialized() == 0) {
|
if (Py_IsInitialized() == 0) {
|
||||||
return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
|
return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized.");
|
||||||
}
|
}
|
||||||
try {
|
try {
|
||||||
// Prepare batch map call back parameters
|
// Prepare batch map call back parameters
|
||||||
|
@ -377,11 +380,24 @@ Status BatchOp::InvokeBatchMapFunc(TensorTable *input, TensorTable *output, CBat
|
||||||
py::object ret_py_obj = batch_map_func_(*input_args);
|
py::object ret_py_obj = batch_map_func_(*input_args);
|
||||||
// Parse batch map return value
|
// Parse batch map return value
|
||||||
py::tuple ret_tuple = py::cast<py::tuple>(ret_py_obj);
|
py::tuple ret_tuple = py::cast<py::tuple>(ret_py_obj);
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(py::isinstance<py::tuple>(ret_tuple), "Batch map function should return a tuple");
|
CHECK_FAIL_RETURN_UNEXPECTED(py::isinstance<py::tuple>(ret_tuple), "Batch map function should return a tuple.");
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(ret_tuple.size() == out_col_names_.size(), "Incorrect number of columns returned.");
|
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||||
|
ret_tuple.size() == out_col_names_.size(),
|
||||||
|
"Incorrect number of columns returned. expects: " + std::to_string(out_col_names_.size()) +
|
||||||
|
" gets: " + std::to_string(ret_tuple.size()));
|
||||||
for (size_t i = 0; i < ret_tuple.size(); i++) {
|
for (size_t i = 0; i < ret_tuple.size(); i++) {
|
||||||
TensorRow output_batch;
|
TensorRow output_batch;
|
||||||
|
// If user returns a type that is neither a list nor an array, issue an error msg.
|
||||||
|
if (py::isinstance<py::array>(ret_tuple[i])) {
|
||||||
|
MS_LOG(WARNING) << "column: " << out_col_names_[i]
|
||||||
|
<< " returned by per_batch_map is a np.array. Please use list instead.";
|
||||||
|
} else if (!py::isinstance<py::list>(ret_tuple[i])) {
|
||||||
|
MS_LOG(ERROR) << "column: " << out_col_names_[i]
|
||||||
|
<< " returned by per_batch_map is not a list, this could lead to conversion failure.";
|
||||||
|
}
|
||||||
|
|
||||||
py::list output_list = py::cast<py::list>(ret_tuple[i]);
|
py::list output_list = py::cast<py::list>(ret_tuple[i]);
|
||||||
|
|
||||||
for (size_t j = 0; j < output_list.size(); j++) {
|
for (size_t j = 0; j < output_list.size(); j++) {
|
||||||
std::shared_ptr<Tensor> out;
|
std::shared_ptr<Tensor> out;
|
||||||
RETURN_IF_NOT_OK(Tensor::CreateFromNpArray(py::cast<py::array>(output_list[j]), &out));
|
RETURN_IF_NOT_OK(Tensor::CreateFromNpArray(py::cast<py::array>(output_list[j]), &out));
|
||||||
|
|
|
@ -287,7 +287,9 @@ class Dataset:
|
||||||
per_batch_map (callable, optional): Per batch map callable. A callable which takes
|
per_batch_map (callable, optional): Per batch map callable. A callable which takes
|
||||||
(list[Tensor], list[Tensor], ..., BatchInfo) as input parameters. Each list[Tensor] represents a batch
|
(list[Tensor], list[Tensor], ..., BatchInfo) as input parameters. Each list[Tensor] represents a batch
|
||||||
of Tensors on a given column. The number of lists should match with number of entries in input_columns.
|
of Tensors on a given column. The number of lists should match with number of entries in input_columns.
|
||||||
The last parameter of the callable should always be a BatchInfo object.
|
The last parameter of the callable should always be a BatchInfo object. Per_batch_map should return
|
||||||
|
(list[Tensor], list[Tensor], ...). The length of each list in output should be the same as the input.
|
||||||
|
output_columns is required if the number of output lists is different from input.
|
||||||
input_columns (list[str], optional): List of names of the input columns. The size of the list should
|
input_columns (list[str], optional): List of names of the input columns. The size of the list should
|
||||||
match with signature of per_batch_map callable.
|
match with signature of per_batch_map callable.
|
||||||
output_columns (list[str], optional): List of names assigned to the columns
|
output_columns (list[str], optional): List of names assigned to the columns
|
||||||
|
@ -1462,7 +1464,7 @@ class Dataset:
|
||||||
data (Any): The data passed to the callback, user defined (default=None).
|
data (Any): The data passed to the callback, user defined (default=None).
|
||||||
"""
|
"""
|
||||||
if (not isinstance(num_batch, int) and num_batch is not None) or \
|
if (not isinstance(num_batch, int) and num_batch is not None) or \
|
||||||
(isinstance(num_batch, int) and num_batch <= 0):
|
(isinstance(num_batch, int) and num_batch <= 0):
|
||||||
# throwing exception, disable all sync_wait in pipeline
|
# throwing exception, disable all sync_wait in pipeline
|
||||||
self.disable_sync()
|
self.disable_sync()
|
||||||
raise RuntimeError("Sync_update batch size can only be positive, got : {}.".format(num_batch))
|
raise RuntimeError("Sync_update batch size can only be positive, got : {}.".format(num_batch))
|
||||||
|
|
|
@ -382,9 +382,17 @@ def test_exceptions_2():
|
||||||
def simple_copy(colList, batchInfo):
|
def simple_copy(colList, batchInfo):
|
||||||
return ([np.copy(arr) for arr in colList],)
|
return ([np.copy(arr) for arr in colList],)
|
||||||
|
|
||||||
def test_wrong_col_name(gen_num, batch_size):
|
def concat_copy(colList, batchInfo):
|
||||||
data1 = ds.GeneratorDataset((lambda: gen(gen_num)), ["num"]).batch(batch_size, input_columns=["num1"],
|
# this will double the number of rows returned, which would be wrong!
|
||||||
per_batch_map=simple_copy)
|
return ([np.copy(arr) for arr in colList] * 2,)
|
||||||
|
|
||||||
|
def shrink_copy(colList, batchInfo):
|
||||||
|
# this will halve the number of rows returned, which would be wrong!
|
||||||
|
return ([np.copy(arr) for arr in colList][0:int(len(colList) / 2)],)
|
||||||
|
|
||||||
|
def test_exceptions_config(gen_num, batch_size, in_cols, per_batch_map):
|
||||||
|
data1 = ds.GeneratorDataset((lambda: gen(gen_num)), ["num"]).batch(batch_size, input_columns=in_cols,
|
||||||
|
per_batch_map=per_batch_map)
|
||||||
try:
|
try:
|
||||||
for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||||
pass
|
pass
|
||||||
|
@ -393,7 +401,9 @@ def test_exceptions_2():
|
||||||
return str(e)
|
return str(e)
|
||||||
|
|
||||||
# test exception where column name is incorrect
|
# test exception where column name is incorrect
|
||||||
assert "error. col:num1 doesn't exist" in test_wrong_col_name(4, 2)
|
assert "error. col:num1 doesn't exist" in test_exceptions_config(4, 2, ["num1"], simple_copy)
|
||||||
|
assert "expects: 2 rows returned from per_batch_map, gets: 4" in test_exceptions_config(4, 2, ["num"], concat_copy)
|
||||||
|
assert "expects: 4 rows returned from per_batch_map, gets: 2" in test_exceptions_config(4, 4, ["num"], shrink_copy)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
Loading…
Reference in New Issue