forked from mindspore-Ecosystem/mindspore
parent
8fe3cf6991
commit
d54ba374b9
|
@ -16,10 +16,9 @@
|
|||
#include "minddata/dataset/api/python/de_pipeline.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <set>
|
||||
#include <map>
|
||||
#include <set>
|
||||
|
||||
#include "utils/ms_utils.h"
|
||||
#include "minddata/dataset/callback/py_ds_callback.h"
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/engine/cache/cache_client.h"
|
||||
|
@ -32,15 +31,15 @@
|
|||
#include "minddata/dataset/engine/datasetops/source/celeba_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/clue_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/csv_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/coco_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/csv_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/manifest_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/mnist_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/random_data_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/text_file_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/voc_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||
#include "minddata/dataset/kernels/py_func_op.h"
|
||||
#include "minddata/dataset/util/random.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
@ -53,6 +52,7 @@
|
|||
#include "minddata/mindrecord/include/shard_writer.h"
|
||||
#include "pybind11/stl.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "utils/ms_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
@ -211,16 +211,17 @@ Status DEPipeline::GetColumnNames(py::list *output) {
|
|||
}
|
||||
|
||||
Status DEPipeline::GetNextAsMap(py::dict *output) {
|
||||
TensorMap row;
|
||||
std::vector<std::pair<std::string, std::shared_ptr<Tensor>>> vec;
|
||||
Status s;
|
||||
{
|
||||
py::gil_scoped_release gil_release;
|
||||
s = iterator_->GetNextAsMap(&row);
|
||||
s = iterator_->GetNextAsOrderedPair(&vec);
|
||||
}
|
||||
RETURN_IF_NOT_OK(s);
|
||||
// Generate Python dict as return
|
||||
for (auto el : row) {
|
||||
(*output)[common::SafeCStr(el.first)] = el.second;
|
||||
|
||||
// Generate Python dict, python dict maintains its insertion order
|
||||
for (const auto &pair : vec) {
|
||||
(*output)[common::SafeCStr(pair.first)] = pair.second;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -614,7 +615,7 @@ Status DEPipeline::FetchMetaFromTensorRow(const std::unordered_map<std::string,
|
|||
}
|
||||
if (mr_shape.empty()) {
|
||||
if (mr_type == "bytes") { // map to int32 when bytes without shape.
|
||||
mr_type == "int32";
|
||||
mr_type = "int32";
|
||||
}
|
||||
(*schema)[column_name] = {{"type", mr_type}};
|
||||
} else {
|
||||
|
@ -905,7 +906,7 @@ Status DEPipeline::ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp>
|
|||
if (py::isinstance<py::int_>(args["batch_size"])) {
|
||||
batch_size_ = ToInt(args["batch_size"]);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(batch_size_ > 0, "Error: batch_size is invalid.");
|
||||
builder = std::make_shared<BatchOp::Builder>(ToInt(args["batch_size"]));
|
||||
builder = std::make_shared<BatchOp::Builder>(batch_size_);
|
||||
} else if (py::isinstance<py::function>(args["batch_size"])) {
|
||||
builder = std::make_shared<BatchOp::Builder>(1);
|
||||
(void)builder->SetBatchSizeFunc(args["batch_size"].cast<py::function>());
|
||||
|
@ -920,17 +921,13 @@ Status DEPipeline::ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp>
|
|||
if (!value.is_none()) {
|
||||
if (key == "drop_remainder") {
|
||||
(void)builder->SetDrop(ToBool(value));
|
||||
}
|
||||
if (key == "num_parallel_workers") {
|
||||
} else if (key == "num_parallel_workers") {
|
||||
(void)builder->SetNumWorkers(ToInt(value));
|
||||
}
|
||||
if (key == "per_batch_map") {
|
||||
} else if (key == "per_batch_map") {
|
||||
(void)builder->SetBatchMapFunc(value.cast<py::function>());
|
||||
}
|
||||
if (key == "input_columns") {
|
||||
} else if (key == "input_columns") {
|
||||
(void)builder->SetColumnsToMap(ToStringVector(value));
|
||||
}
|
||||
if (key == "pad_info") {
|
||||
} else if (key == "pad_info") {
|
||||
PadInfo pad_info;
|
||||
RETURN_IF_NOT_OK(ParsePadInfo(value, &pad_info));
|
||||
(void)builder->SetPaddingMap(pad_info, true);
|
||||
|
|
|
@ -81,6 +81,40 @@ Status IteratorBase::FetchNextTensorRow(TensorRow *out_row) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status IteratorBase::GetNextAsOrderedPair(std::vector<std::pair<std::string, std::shared_ptr<Tensor>>> *vec) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(vec != nullptr && vec->empty(), "vec is null or non-empty.");
|
||||
|
||||
TensorRow curr_row;
|
||||
|
||||
RETURN_IF_NOT_OK(FetchNextTensorRow(&curr_row));
|
||||
RETURN_OK_IF_TRUE(curr_row.empty());
|
||||
|
||||
size_t num_cols = curr_row.size(); // num_cols is non-empty.
|
||||
if (col_name_id_map_.empty()) col_name_id_map_ = this->GetColumnNameMap();
|
||||
// order the column names according to their ids
|
||||
if (column_order_.empty()) {
|
||||
const int32_t invalid_col_id = -1;
|
||||
column_order_.resize(num_cols, {std::string(), invalid_col_id});
|
||||
for (const auto itr : col_name_id_map_) {
|
||||
int32_t ind = itr.second;
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(ind < num_cols && ind >= 0, "column id out of bounds.");
|
||||
column_order_[ind] = std::make_pair(itr.first, ind);
|
||||
}
|
||||
// error check, make sure the ids in col_name_id_map are continuous and starts from 0
|
||||
for (const auto &col : column_order_) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(col.second != invalid_col_id, "column ids are not continuous.");
|
||||
}
|
||||
}
|
||||
|
||||
vec->reserve(num_cols);
|
||||
|
||||
for (const auto &col : column_order_) {
|
||||
vec->emplace_back(std::make_pair(col.first, curr_row[col.second]));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Constructor of the DatasetIterator
|
||||
DatasetIterator::DatasetIterator(std::shared_ptr<ExecutionTree> exe_tree)
|
||||
: IteratorBase(),
|
||||
|
|
|
@ -19,12 +19,14 @@
|
|||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/engine/datasetops/dataset_op.h"
|
||||
#include "minddata/dataset/engine/execution_tree.h"
|
||||
#include "minddata/dataset/engine/perf/dataset_iterator_tracing.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
@ -61,6 +63,11 @@ class IteratorBase {
|
|||
// @return A unordered map from column name to shared pointer to Tensor.
|
||||
Status GetNextAsMap(TensorMap *out_map);
|
||||
|
||||
/// \breif return column_name, tensor pair in the order of its column id.
|
||||
/// \param[out] vec
|
||||
/// \return Error code
|
||||
Status GetNextAsOrderedPair(std::vector<std::pair<std::string, std::shared_ptr<Tensor>>> *vec);
|
||||
|
||||
// Getter
|
||||
// @return T/F if this iterator is completely done after getting an eof
|
||||
bool eof_handled() const { return eof_handled_; }
|
||||
|
@ -73,6 +80,7 @@ class IteratorBase {
|
|||
std::unique_ptr<DataBuffer> curr_buffer_; // holds the current buffer
|
||||
bool eof_handled_; // T/F if this op got an eof
|
||||
std::unordered_map<std::string, int32_t> col_name_id_map_;
|
||||
std::vector<std::pair<std::string, int32_t>> column_order_; // key: column name, val: column id
|
||||
};
|
||||
|
||||
// The DatasetIterator derived class is for fetching rows off the end/root of the execution tree.
|
||||
|
|
|
@ -150,12 +150,15 @@ def check_columns(columns, name):
|
|||
Exception: when the value is not correct, otherwise nothing.
|
||||
"""
|
||||
type_check(columns, (list, str), name)
|
||||
if isinstance(columns, list):
|
||||
if isinstance(columns, str):
|
||||
if not columns:
|
||||
raise ValueError("{0} should not be an empty str".format(name))
|
||||
elif isinstance(columns, list):
|
||||
if not columns:
|
||||
raise ValueError("{0} should not be empty".format(name))
|
||||
for i, column_name in enumerate(columns):
|
||||
if not column_name:
|
||||
raise ValueError("{0}[{1}] should not be empty".format(name, i))
|
||||
raise ValueError("{0}[{1}] should not be empty.".format(name, i))
|
||||
|
||||
col_names = ["{0}[{1}]".format(name, i) for i in range(len(columns))]
|
||||
type_check_list(columns, (str,), col_names)
|
||||
|
|
|
@ -503,17 +503,13 @@ def check_batch(method):
|
|||
|
||||
if input_columns is not None:
|
||||
check_columns(input_columns, "input_columns")
|
||||
if len(input_columns) != (len(ins.signature(per_batch_map).parameters) - 1):
|
||||
raise ValueError("the signature of per_batch_map should match with input columns")
|
||||
|
||||
if (per_batch_map is None) != (input_columns is None):
|
||||
# These two parameters appear together.
|
||||
raise ValueError("per_batch_map and input_columns need to be passed in together.")
|
||||
|
||||
if input_columns is not None:
|
||||
if not input_columns: # Check whether input_columns is empty.
|
||||
raise ValueError("input_columns can not be empty")
|
||||
if len(input_columns) != (len(ins.signature(per_batch_map).parameters) - 1):
|
||||
raise ValueError("the signature of per_batch_map should match with input columns")
|
||||
|
||||
if output_columns is not None:
|
||||
raise ValueError("output_columns is currently not implemented.")
|
||||
|
||||
|
|
|
@ -12,6 +12,8 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import numpy as np
|
||||
|
||||
from util import save_and_check_tuple
|
||||
|
||||
import mindspore.dataset as ds
|
||||
|
@ -155,3 +157,27 @@ def test_case_map_project_map_project():
|
|||
|
||||
filename = "project_alternate_parallel_inline_result.npz"
|
||||
save_and_check_tuple(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
|
||||
|
||||
|
||||
def test_column_order():
|
||||
"""test the output dict has maintained an insertion order."""
|
||||
|
||||
def gen_3_cols(num):
|
||||
for i in range(num):
|
||||
yield (np.array([i * 3]), np.array([i * 3 + 1]), np.array([i * 3 + 2]))
|
||||
|
||||
def test_config(num, col_order):
|
||||
dst = ds.GeneratorDataset((lambda: gen_3_cols(num)), ["col1", "col2", "col3"]).batch(batch_size=num)
|
||||
dst = dst.project(col_order)
|
||||
res = dict()
|
||||
for item in dst.create_dict_iterator(num_epochs=1):
|
||||
res = item
|
||||
return res
|
||||
|
||||
assert list(test_config(1, ["col3", "col2", "col1"]).keys()) == ["col3", "col2", "col1"]
|
||||
assert list(test_config(2, ["col1", "col2", "col3"]).keys()) == ["col1", "col2", "col3"]
|
||||
assert list(test_config(3, ["col2", "col3", "col1"]).keys()) == ["col2", "col3", "col1"]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_column_order()
|
||||
|
|
|
@ -190,14 +190,13 @@ def test_random_affine_py_exception_non_pil_images():
|
|||
Test RandomAffine: input img is ndarray and not PIL, expected to raise RuntimeError
|
||||
"""
|
||||
logger.info("test_random_affine_exception_negative_degrees")
|
||||
dataset = ds.MnistDataset(MNIST_DATA_DIR, num_parallel_workers=3)
|
||||
dataset = ds.MnistDataset(MNIST_DATA_DIR, num_samples=3, num_parallel_workers=3)
|
||||
try:
|
||||
transform = mindspore.dataset.transforms.py_transforms.Compose([py_vision.ToTensor(),
|
||||
py_vision.RandomAffine(degrees=(15, 15))])
|
||||
dataset = dataset.map(operations=transform, input_columns=["image"], num_parallel_workers=3,
|
||||
python_multiprocessing=True)
|
||||
dataset = dataset.map(operations=transform, input_columns=["image"], num_parallel_workers=3)
|
||||
for _ in dataset.create_dict_iterator(num_epochs=1):
|
||||
break
|
||||
pass
|
||||
except RuntimeError as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "Pillow image" in str(e)
|
||||
|
|
Loading…
Reference in New Issue