[MD] support python dictionary in pipeline

This commit is contained in:
mohammad 2023-02-25 14:08:15 -05:00
parent be05eb75dc
commit b00d716243
13 changed files with 867 additions and 54 deletions

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
* Copyright 2020-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -99,6 +99,13 @@ PYBIND_REGISTER(Tensor, 0, ([](const py::module *m) {
.def("__str__", &Tensor::ToString)
.def("shape", &Tensor::shape)
.def("type", &Tensor::type)
.def("as_python",
[](py::object &t) {
auto &tensor = py::cast<Tensor &>(t);
py::dict res;
THROW_IF_ERROR(tensor.GetDataAsPythonObject(&res));
return res;
})
.def("as_array", [](py::object &t) {
auto &tensor = py::cast<Tensor &>(t);
if (tensor.type().IsString()) {
@ -124,6 +131,7 @@ PYBIND_REGISTER(DataType, 0, ([](const py::module *m) {
(void)py::class_<DataType>(*m, "DataType")
.def(py::init<std::string>())
.def(py::self == py::self)
.def("is_python", &DataType::IsPython)
.def("__str__", &DataType::ToString)
.def("__deepcopy__", [](py::object &t, const py::dict &memo) { return t; });
}));
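Taken together, these bindings let the Python layer detect and unwrap dict tensors. A minimal sketch of the consuming side (the handle `t` is a hypothetical C++-backed tensor as exposed by the bindings above):

    # 't' is a tensor handle exposed by the bindings registered above
    if t.type().is_python():   # DataType reports DE_PYTHON
        obj = t.as_python()    # returns the stored Python dict (no copy)
    else:
        arr = t.as_array()     # numeric/string tensors still convert via NumPy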

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -118,6 +118,10 @@ DataType::DataType(const std::string &type_str) {
type_ = DE_STRING;
} else if (type_str == "bytes") {
type_ = DE_BYTES;
#ifdef ENABLE_PYTHON
} else if (type_str == "python") {
type_ = DE_PYTHON;
#endif
} else {
type_ = DE_UNKNOWN;
}
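For reference, the new type name is reachable from Python through the existing DataType binding; a hedged round-trip sketch (the `mindspore._c_dataengine` module name is an assumption based on MindSpore's usual binding layout, and "python" only maps to DE_PYTHON on ENABLE_PYTHON builds):

    import mindspore._c_dataengine as cde  # assumed binding module name

    t = cde.DataType("python")   # maps to DE_PYTHON on ENABLE_PYTHON builds
    print(str(t))                # -> "python"
    print(t.is_python())         # -> True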

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -52,6 +52,7 @@ class DataType {
DE_FLOAT64,
DE_STRING,
DE_BYTES,
DE_PYTHON,
NUM_OF_TYPES
};
@ -80,7 +81,8 @@ class DataType {
{"float32", 4, "float32", py::format_descriptor<float>::format(), CV_32F}, // DE_FLOAT32
{"float64", 8, "double", py::format_descriptor<double>::format(), CV_64F}, // DE_FLOAT64
{"string", 0, "str", "U", kCVInvalidType}, // DE_STRING
{"bytes", 0, "bytes", "S", kCVInvalidType} // DE_BYTES
{"bytes", 0, "bytes", "S", kCVInvalidType}, // DE_BYTES
{"python", 0, "object", "O", kCVInvalidType} // DE_PYTHON
};
#else
#if !defined(ENABLE_ANDROID) || defined(ENABLE_CLOUD_FUSION_INFERENCE)
@ -242,6 +244,8 @@ class DataType {
bool IsString() const { return type_ == DataType::DE_STRING || type_ == DataType::DE_BYTES; }
bool IsPython() const { return type_ == DataType::DE_PYTHON; }
Type value() const { return type_; }
private:

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019-2022 Huawei Technologies Co., Ltd
* Copyright 2019-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -72,6 +72,12 @@ Tensor::Tensor(Tensor &&other) noexcept
data_(other.GetMutableBuffer()),
data_end_(other.data_end_),
data_allocator_(std::move(other.data_allocator_)) {
#ifdef ENABLE_PYTHON
if (type_.value() == DataType::DE_PYTHON) {
py::gil_scoped_acquire gil_acquire;
python_dict_ = (other.python_dict_);
}
#endif
other.Invalidate();
}
@ -83,6 +89,12 @@ Tensor &Tensor::operator=(Tensor &&other) noexcept {
data_end_ = other.data_end_;
data_allocator_ = std::move(other.data_allocator_);
yuv_shape_ = other.yuv_shape_;
#ifdef ENABLE_PYTHON
if (type_.value() == DataType::DE_PYTHON) {
py::gil_scoped_acquire gil_acquire;
python_dict_ = (other.python_dict_);
}
#endif
other.Invalidate();
}
return *this;
@ -231,6 +243,21 @@ Status Tensor::CreateFromNpArray(const py::array &arr, std::shared_ptr<Tensor> *
}
return Status::OK();
}
Status Tensor::CreateFromPythonObject(py::object obj, std::shared_ptr<Tensor> *out) {
RETURN_UNEXPECTED_IF_NULL(out);
DataType type = DataType(DataType::DE_PYTHON);
const TensorAlloc *alloc = GlobalContext::Instance()->tensor_allocator();
*out = std::allocate_shared<Tensor>(*alloc, TensorShape({0}), type);
CHECK_FAIL_RETURN_UNEXPECTED(*out != nullptr, "Failed to create a tensor for python object.");
{
py::gil_scoped_acquire gil_acquire;
(*out)->python_dict_ = obj;
}
return Status::OK();
}
#endif
#ifndef ENABLE_ANDROID
@ -384,9 +411,33 @@ Tensor::~Tensor() {
data_end_ = nullptr;
}
}
}
#ifdef ENABLE_PYTHON
try {
if (Py_IsInitialized()) {
if (static_cast<bool>(python_dict_)) { // if it contains data
// Acquire Python GIL
py::gil_scoped_acquire gil_acquire;
if (python_dict_.ref_count() == 1) { // if we aren't referencing it anywhere else
python_dict_.dec_ref(); // manually drop the last reference so Python can garbage-collect the object
python_dict_ = py::none(); // reset the wrapper so it no longer points at a dead object (avoids a segfault)
}
}
}
} catch (py::error_already_set &e) {
// ignore exceptions as everything could be shutting down at this point
}
#endif
}
bool Tensor::operator==(const Tensor &rhs) const {
#ifdef ENABLE_PYTHON
if (type_.value() == DataType::DE_PYTHON) { // we are holding a python object
if (static_cast<bool>(python_dict_) && static_cast<bool>(rhs.python_dict_) && python_dict_ == rhs.python_dict_) {
return true;
}
return false;
}
#endif
// 1. different shape 2. different type 3. one data_ is nullptr and the other is not
if (shape_ != rhs.shape() || type_ != rhs.type_ || (data_ == nullptr && rhs.data_ != nullptr) ||
(data_ != nullptr && rhs.data_ == nullptr)) {
@ -473,6 +524,15 @@ void Tensor::Print(std::ostream &out) const {
out << ", Type: " << type_ << ")\n";
if (data_) {
PrintRecursive(out, 0, std::vector<dsize_t>{});
#ifdef ENABLE_PYTHON
} else if (static_cast<bool>(python_dict_)) {
std::string s;
{
py::gil_scoped_acquire gil_acquire;
s = py::str(python_dict_);
}
out << s;
#endif
} else {
out << "[Data area is null]";
}
@ -510,6 +570,12 @@ void Tensor::Invalidate() {
data_ = nullptr;
data_end_ = nullptr;
data_allocator_ = nullptr;
#ifdef ENABLE_PYTHON
if (type_.value() == DataType::DE_PYTHON) {
py::gil_scoped_acquire gil_acquire;
python_dict_ = py::none();
}
#endif
}
template <typename T>
@ -854,6 +920,15 @@ Status Tensor::GetDataAsNumpyStrings(py::array *data) {
}
return Status::OK();
}
Status Tensor::GetDataAsPythonObject(py::dict *data) {
RETURN_UNEXPECTED_IF_NULL(data);
{
py::gil_scoped_acquire gil_acquire;
*data = python_dict_;
}
return Status::OK();
}
#endif
void Tensor::Squeeze() { shape_ = shape_.Squeeze(); }
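The lifecycle above (CreateFromPythonObject stores a reference under the GIL; the destructor drops the last reference, again under the GIL) can be observed from Python. A rough, hedged sketch; exact reference counts depend on pipeline internals, so it only prints them:

    import sys
    import mindspore.dataset as ds

    payload = {"a": 1}

    def gen():
        yield (payload,)   # the tensor wraps a reference to this dict, no deep copy

    print("before:", sys.getrefcount(payload))
    for _ in ds.GeneratorDataset(gen, ["col1"]).create_tuple_iterator(num_epochs=1, output_numpy=True):
        pass               # tensors holding the dict are destroyed as iteration proceeds
    print("after:", sys.getrefcount(payload))   # should return to the baseline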

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019-2022 Huawei Technologies Co., Ltd
* Copyright 2019-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -126,6 +126,12 @@ class DATASET_API Tensor {
/// \param[out] out Created tensor
/// \return Status Code
static Status CreateFromNpArray(const py::array &arr, TensorPtr *out);
/// Helper function to create a tensor from a Python dictionary object
/// \param[in] obj pybind11 wrapper for Python dictionary object
/// \param[out] out Created Tensor
/// \return Status
static Status CreateFromPythonObject(py::object obj, TensorPtr *out);
#endif
#ifndef ENABLE_ANDROID
@ -552,6 +558,12 @@ class DATASET_API Tensor {
}
static Status GetBufferInfo(Tensor *t, py::buffer_info *out);
/// Returns the Python dictionary stored in the tensor
/// \param[out] data pybind11 wrapper that receives the stored Python dictionary
/// \return Status code
Status GetDataAsPythonObject(py::dict *data);
#endif
Status SetYuvShape(const uint32_t &width, const uint32_t &widthStride, const uint32_t &height,
@ -837,6 +849,11 @@ class DATASET_API Tensor {
/// shape for interpretation of YUV image
std::vector<uint32_t> yuv_shape_;
#ifdef ENABLE_PYTHON
/// Store python dictionary wrapper
py::object python_dict_;
#endif
private:
friend class DETensor;

View File

@ -234,6 +234,46 @@ Status BatchOp::ConvertRowsToTensor(const std::unique_ptr<TensorQTable> *src, st
std::to_string(col) + ", expected type for this column is:" + type1 + ", got type:" + type2);
}
}
#ifdef ENABLE_PYTHON
} else if (first_type.IsPython()) {
// handle python dictionary columns differently:
// each column of the new batch will be a Python dictionary where each key
// stores the values of the corresponding rows in a Python list.
{
// Acquire Python GIL
py::gil_scoped_acquire gil_acquire;
py::dict new_dict;
size_t num_keys = 0;
for (size_t j = 0; j < batch_size; j++) {
std::shared_ptr<Tensor> old_tensor = (*src)->at(j).at(col); // row j, column col
py::dict old_dict;
RETURN_IF_NOT_OK(old_tensor->GetDataAsPythonObject(&old_dict));
if (j == 0) {
num_keys = py::len(old_dict);
for (auto key_val : old_dict) {
py::list li;
li.append(key_val.second);
new_dict[key_val.first] = li;
}
} else {
CHECK_FAIL_RETURN_UNEXPECTED(
num_keys == py::len(old_dict),
"Failed to create a batch since number of key/value pairs in dictionaries do not match. First row: " +
std::to_string(num_keys) + ", current row: " + std::to_string(py::len(new_dict)));
for (auto key_val : old_dict) {
CHECK_FAIL_RETURN_UNEXPECTED(new_dict.contains(key_val.first),
"Python dictionary keys do not match when creating a batch: " +
key_val.first.cast<std::string>() + " was not found in previous rows.");
py::list li = new_dict[key_val.first];
li.append(key_val.second);
}
}
}
RETURN_IF_NOT_OK(Tensor::CreateFromPythonObject(new_dict, &new_tensor));
}
#endif
} else { // handle string column differently
std::vector<std::string> strings;
for (dsize_t j = 0; j < batch_size; j++) {
@ -419,14 +459,20 @@ Status BatchOp::InvokeBatchMapFunc(TensorTable *input, TensorTable *output, CBat
try {
// Prepare batch map call back parameters
py::tuple input_args(input->size() + 1);
for (size_t i = 0; i < input->size(); i++) {
std::vector<py::array> np_batch;
for (std::shared_ptr<Tensor> t : input->at(i)) {
py::array np_array;
RETURN_IF_NOT_OK(t->GetDataAsNumpy(&np_array));
np_batch.push_back(std::move(np_array));
for (size_t i = 0; i < input->size(); i++) { // iterate over columns
std::vector<py::object> column_batch;
for (std::shared_ptr<Tensor> t : input->at(i)) { // iterate over rows
if (t->type().IsPython()) {
py::dict new_data;
RETURN_IF_NOT_OK(t->GetDataAsPythonObject(&new_data));
column_batch.push_back(new_data);
} else {
py::array np_array;
RETURN_IF_NOT_OK(t->GetDataAsNumpy(&np_array));
column_batch.push_back(std::move(np_array));
}
}
input_args[i] = np_batch;
input_args[i] = column_batch;
}
input_args[input->size()] = info;
// Invoke batch map func
@ -441,32 +487,45 @@ Status BatchOp::InvokeBatchMapFunc(TensorTable *input, TensorTable *output, CBat
"should be " +
std::to_string(out_col_names_.size()) +
" , but got: " + std::to_string(ret_tuple.size()));
bool all_array = true;
bool all_array_or_dict = true;
for (size_t i = 0; i < ret_tuple.size(); i++) {
if (!py::isinstance<py::array>(ret_tuple[i])) {
all_array = false;
if (!py::isinstance<py::array>(ret_tuple[i]) && !py::isinstance<py::dict>(ret_tuple[i])) {
all_array_or_dict = false;
break;
}
}
*concat_batch = all_array;
*concat_batch = all_array_or_dict;
for (size_t i = 0; i < ret_tuple.size(); i++) {
TensorRow output_batch;
// If user returns a type that is neither a list nor an array, issue a error msg.
if (!py::isinstance<py::list>(ret_tuple[i])) {
// If user returns a type that is neither a list nor a Python dictionary, issue an error msg.
if (!py::isinstance<py::list>(ret_tuple[i]) && !py::isinstance<py::dict>(ret_tuple[i])) {
MS_LOG(INFO) << "column: " << out_col_names_[i]
<< " returned by per_batch_map is not a list, this could lead to conversion failure.";
<< " returned by per_batch_map is not a list nor a Python dict, "
<< "this could lead to conversion failure.";
}
if (*concat_batch) {
std::shared_ptr<Tensor> out;
// If concat batch rows, the batch map function result should be in 1 row.
RETURN_IF_NOT_OK(Tensor::CreateFromNpArray(py::cast<py::array>(ret_tuple[i]), &out));
std::shared_ptr<Tensor> out;
if (py::isinstance<py::dict>(ret_tuple[i])) {
RETURN_IF_NOT_OK(Tensor::CreateFromPythonObject(py::cast<py::object>(ret_tuple[i]), &out));
} else {
RETURN_IF_NOT_OK(Tensor::CreateFromNpArray(py::cast<py::array>(ret_tuple[i]), &out));
}
output_batch.push_back(std::move(out));
} else {
CHECK_FAIL_RETURN_UNEXPECTED(
!py::isinstance<py::dict>(ret_tuple[i]),
"Failed to convert rows: mismatched types returned from per_batch_map function. If different types are "
"returned, all of them should be convertible to Python lists. Got: Python dict");
py::list output_list = py::cast<py::list>(ret_tuple[i]);
for (size_t j = 0; j < output_list.size(); j++) {
std::shared_ptr<Tensor> out;
RETURN_IF_NOT_OK(Tensor::CreateFromNpArray(py::cast<py::array>(output_list[j]), &out));
if (py::isinstance<py::dict>(output_list[j])) {
RETURN_IF_NOT_OK(Tensor::CreateFromPythonObject(py::cast<py::object>(output_list[j]), &out));
} else {
RETURN_IF_NOT_OK(Tensor::CreateFromNpArray(py::cast<py::array>(output_list[j]), &out));
}
output_batch.push_back(std::move(out));
}
}
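Conceptually, the dictionary branch above implements the following per-column combine step; a pure-Python sketch of the semantics (not the actual implementation):

    def batch_dicts(rows):
        # combine a list of per-row dicts into one dict of per-key lists,
        # mirroring the key checks performed by BatchOp above
        first, *rest = rows
        out = {k: [v] for k, v in first.items()}
        for r in rest:
            if r.keys() != out.keys():
                raise RuntimeError("Python dictionary keys do not match across rows")
            for k, v in r.items():
                out[k].append(v)
        return out

    assert batch_dicts([{"a": 1, "b": 2}, {"a": 3, "b": 4}]) == {"a": [1, 3], "b": [2, 4]}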

View File

@ -136,7 +136,8 @@ Status DataQueueOp::FilterMetadata(TensorRow *row) const {
Status DataQueueOp::CheckExceptions(const TensorRow &row) const {
// this method checks if the row meets the conditions to be sent to TDT
for (const auto &item : row) {
CHECK_FAIL_RETURN_UNEXPECTED(item->type().IsNumeric(), "Invalid datatype, cannot send string data to device.");
CHECK_FAIL_RETURN_UNEXPECTED(item->type().IsNumeric(),
"Invalid datatype, cannot send string, or Python dict to device.");
CHECK_FAIL_RETURN_UNEXPECTED(item->HasData(), "Invalid data, the data send to device is null.");
}
return Status::OK();
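In practice this check means dict-producing pipelines must run in non-sink mode; an abbreviated usage note (variable names follow the ST test added later in this commit):

    # host-side (non-sink) training accepts dict columns
    model.train(num_epochs, data, dataset_sink_mode=False)   # OK
    # the device queue rejects them, per the check above
    model.train(num_epochs, data, dataset_sink_mode=True)    # raises RuntimeError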

View File

@ -99,7 +99,7 @@ Status GeneratorOp::Init() {
Status GeneratorOp::PyRowToTensorRow(py::object py_data, TensorRow *tensor_row) {
if (!py::isinstance<py::tuple>(py_data)) {
RETURN_STATUS_ERROR(StatusCode::kMDPyFuncException,
"Invalid python function, the 'source' of 'GeneratorDataset' should return a tuple of NumPy "
"Invalid Python function, the 'source' of 'GeneratorDataset' should return a tuple of NumPy "
"arrays, but got " +
std::string(py_data.get_type().str()));
}
@ -108,7 +108,7 @@ Status GeneratorOp::PyRowToTensorRow(py::object py_data, TensorRow *tensor_row)
if (py_row.size() != column_names_.size()) {
RETURN_STATUS_ERROR(
StatusCode::kMDPyFuncException,
"Invalid python function, the 'source' of 'GeneratorDataset' should return same number of NumPy arrays as "
"Invalid Python function, the 'source' of 'GeneratorDataset' should return same number of NumPy arrays as "
"specified in column_names, the size of column_names is:" +
std::to_string(column_names_.size()) +
" and number of returned NumPy array is:" + std::to_string(py_row.size()));
@ -116,18 +116,23 @@ Status GeneratorOp::PyRowToTensorRow(py::object py_data, TensorRow *tensor_row)
// Iterate over two containers simultaneously for memory copy
for (int i = 0; i < py_row.size(); ++i) {
py::object ret_py_ele = py_row[i];
if (!py::isinstance<py::array>(ret_py_ele)) {
RETURN_STATUS_ERROR(StatusCode::kMDPyFuncException,
"Invalid python function, 'GeneratorDataset' should return a tuple of NumPy arrays, "
"but got " +
std::string(ret_py_ele.get_type().str()));
if (!py::isinstance<py::array>(ret_py_ele) && !py::isinstance<py::dict>(ret_py_ele)) {
RETURN_STATUS_ERROR(
StatusCode::kMDPyFuncException,
"Invalid Python function, 'GeneratorDataset' should return a tuple of NumPy arrays or dictionaries, "
"but got " +
std::string(ret_py_ele.get_type().str()));
}
std::shared_ptr<Tensor> tensor;
RETURN_IF_NOT_OK(Tensor::CreateFromNpArray(ret_py_ele.cast<py::array>(), &tensor));
if (py::isinstance<py::dict>(ret_py_ele)) {
RETURN_IF_NOT_OK(Tensor::CreateFromPythonObject(ret_py_ele.cast<py::dict>(), &tensor));
} else {
RETURN_IF_NOT_OK(Tensor::CreateFromNpArray(ret_py_ele.cast<py::array>(), &tensor));
}
if ((!column_types_.empty()) && (column_types_[i] != DataType::DE_UNKNOWN) &&
(column_types_[i] != tensor->type())) {
RETURN_STATUS_ERROR(StatusCode::kMDPyFuncException,
"Invalid python function, type of returned data in 'GeneratorDataset' should be same with "
"Invalid Python function, type of returned data in 'GeneratorDataset' should be same with "
"specified column_types, but the type of returned data: " +
std::string(ret_py_ele.get_type().str()) +
", specified column type: " + column_types_[i].ToString());
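A minimal sketch of the generator contract after this change (standard GeneratorDataset API; the column names are arbitrary):

    import numpy as np
    import mindspore.dataset as ds

    def gen():
        for i in range(3):
            # a column value may now be a Python dict as well as a NumPy array
            yield ({"index": i, "parity": i % 2}, np.array([i]))

    data = ds.GeneratorDataset(gen, ["meta", "value"])
    for meta, value in data.create_tuple_iterator(num_epochs=1, output_numpy=True):
        assert isinstance(meta, dict)         # dict columns survive as dicts
        assert isinstance(value, np.ndarray)  # other columns behave as before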

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019-2022 Huawei Technologies Co., Ltd
* Copyright 2019-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -36,6 +36,16 @@ Status ConvertNumpyToTensor(const py::object &py_obj, TensorRow *output) {
return Status::OK();
}
Status ConvertPythonToTensor(py::object py_obj, TensorRow *output) {
// Python objects such as dictionaries are wrapped in a tensor.
// Note that the tensor only holds a reference to the Python object,
// while the object itself is kept alive by the Python layer.
std::shared_ptr<Tensor> out;
RETURN_IF_NOT_OK(Tensor::CreateFromPythonObject(py_obj, &out));
output->push_back(out);
return Status::OK();
}
Status PyFuncOp::Compute(const TensorRow &input, TensorRow *output) {
IO_CHECK_VECTOR(input, output);
Status ret = Status(StatusCode::kSuccess, "PyFunc Call Succeed");
@ -48,14 +58,20 @@ Status PyFuncOp::Compute(const TensorRow &input, TensorRow *output) {
}
try {
// Transform input tensors into a vector of numpy arrays or Python dicts
py::tuple input_args(input.size());
py::object ret_py_obj;
if (input.size() > 0) {
py::tuple input_args(input.size());
for (size_t i = 0; i < input.size(); i++) {
py::array new_data;
RETURN_IF_NOT_OK(input.at(i)->GetDataAsNumpy(&new_data));
// possible memcpy here
input_args[i] = new_data;
if (input.at(i)->type().IsPython()) {
py::dict new_data;
RETURN_IF_NOT_OK(input.at(i)->GetDataAsPythonObject(&new_data));
input_args[i] = new_data;
} else {
py::array new_data;
RETURN_IF_NOT_OK(input.at(i)->GetDataAsNumpy(&new_data));
// possible memcpy here
input_args[i] = new_data;
}
}
// Invoke python function
ret_py_obj = this->py_func_ptr_(*input_args);
@ -78,15 +94,23 @@ Status PyFuncOp::Compute(const TensorRow &input, TensorRow *output) {
py::object ret_py_ele = ret_py_tuple[i];
// Object is none if pyfunc timeout
if (ret_py_ele.is_none()) {
MS_LOG(INFO) << "Expect pyfunc to return numpy array(s), but got None. If python_multiprocessing is "
"True, it maybe due to pyfunc execution timeout.";
MS_LOG(INFO) << "Expected pyfunc to return NumPy array(s) or Python dict(s), but got None. "
"If python_multiprocessing is True, it may be due to pyfunc execution timeout.";
goto TimeoutError;
} else if (py::isinstance<py::dict>(ret_py_ele)) {
RETURN_IF_NOT_OK(ConvertPythonToTensor(ret_py_ele, output));
} else {
RETURN_IF_NOT_OK(ConvertNumpyToTensor(ret_py_ele, output));
}
RETURN_IF_NOT_OK(ConvertNumpyToTensor(ret_py_ele, output));
}
} else {
// In case of a n-1 mapping, the return value will be a numpy array
RETURN_IF_NOT_OK(ConvertNumpyToTensor(ret_py_obj, output));
// In case of an n-to-1 mapping, the return value will be a numpy array or a python object
// Note that for Python dictionaries, only a reference will be stored in the tensor.
if (py::isinstance<py::dict>(ret_py_obj)) {
RETURN_IF_NOT_OK(ConvertPythonToTensor(ret_py_obj, output));
} else {
RETURN_IF_NOT_OK(ConvertNumpyToTensor(ret_py_obj, output));
}
}
}
} catch (const py::error_already_set &e) {
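End to end, the PyFuncOp changes mean `map` workers can both return and receive dicts; a hedged sketch (standard dataset API; the helper names are arbitrary):

    import numpy as np
    import mindspore.dataset as ds

    def to_dict(x):
        # a pyfunc may now return a dict; the tensor stores only a reference
        return {"value": x, "square": np.square(x)}

    def from_dict(d):
        # a downstream pyfunc receives the dict back, not a converted array
        return d["square"]

    def gen():
        for i in range(3):
            yield (np.array([i]),)

    data = ds.GeneratorDataset(gen, ["col1"])
    data = data.map(to_dict, input_columns=["col1"])
    data = data.map(from_dict, input_columns=["col1"])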

View File

@ -129,11 +129,8 @@ def _fill_worker_indices(workers, indices, idx):
def _convert_row(row):
"""
Convert Op return value to numpy
Convert Op return value to numpy, or leave it unchanged if it is already a dict
"""
if isinstance(row, dict):
raise TypeError("Input data is expected to be " \
"int, float, str, bytes, numpy.ndarray, Tensor or list/tuple of them, but got dict.")
# convert single item to np.array
prim_type = (int, float, str, bytes, np.ndarray, Tensor)
@ -147,6 +144,9 @@ def _convert_row(row):
"int or float or str, but got {}.".format(item.dtype))
return tuple([item])
if isinstance(row, dict):
return tuple([row])
value = []
# convert each item to np.array
idx = 0
@ -155,8 +155,7 @@ def _convert_row(row):
if isinstance(x, Tensor): # mindspore.Tensor
value.append(x.asnumpy())
elif isinstance(x, dict):
raise TypeError("The {}th item of input data is expected to be " \
"int, float, str, bytes, numpy.ndarray, Tensor, but got dict.".format(idx))
value.append(x)
else:
item = np.array(x, copy=False)
if item.dtype == 'object':
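The effect of the `_convert_row` change is that dicts become opaque pass-through values; a simplified pure-Python sketch of the per-item rule (not the full implementation):

    import numpy as np

    def convert_item(x):
        if isinstance(x, dict):
            return x            # dicts now pass through untouched
        return np.array(x)      # everything else still becomes an ndarray

    assert convert_item({"a": 1}) == {"a": 1}
    assert (convert_item([1, 2]) == np.array([1, 2])).all()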

View File

@ -1,4 +1,4 @@
# Copyright 2019-2022 Huawei Technologies Co., Ltd
# Copyright 2019-2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Built-in iterators.
"""
"""Built-in iterators"""
from abc import abstractmethod
from copy import deepcopy
import collections.abc
import json
import os
import signal
@ -183,12 +184,36 @@ class Iterator:
"""
self._iterator.Reset(step, epoch)
def __convert_python_to_tensor(self, obj):
"""
Attempts to recursively convert a python object to tensor(s).
Args:
obj (any): the python object to be converted
"""
if isinstance(obj, (np.ndarray, int, float, bool, str)):
if self._do_copy:
return Tensor(np.asarray(obj))
return Tensor.from_numpy(np.asarray(obj))
if isinstance(obj, dict):
return {key: self.__convert_python_to_tensor(val) for key, val in obj.items()}
if isinstance(obj, collections.abc.Iterable):
return [self.__convert_python_to_tensor(item) for item in obj]
# if we can't convert it to Tensor, return the object as is
if self._do_copy:
return deepcopy(obj)
return obj
def _transform_md_to_output(self, t):
if self._output_numpy:
if t.type().is_python():
return t.as_python()
return t.as_array()
return self._transform_md_to_tensor(t)
def _transform_md_to_tensor(self, t):
if t.type().is_python():
return self.__convert_python_to_tensor(t.as_python())
array = t.as_array()
if self._do_copy:
return Tensor(array)
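The recursion above converts leaves to `mindspore.Tensor` while preserving container structure; a small illustrative sketch of the conversion rules in `__convert_python_to_tensor`:

    import mindspore.dataset as ds
    from mindspore.common import Tensor

    def gen():
        yield ({"integer": 5, "tuple": (1, 2), "nested": {"a": 0}},)

    data = ds.GeneratorDataset(gen, ["col1"])
    for row in data.create_tuple_iterator(num_epochs=1):   # output_numpy=False
        d = row[0]
        assert isinstance(d["integer"], Tensor)        # scalar leaf -> Tensor
        assert isinstance(d["tuple"], list)            # iterables -> list of Tensors
        assert isinstance(d["tuple"][0], Tensor)
        assert isinstance(d["nested"]["a"], Tensor)    # nested dicts converted recursively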

View File

@ -0,0 +1,94 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore.train import Model
import mindspore.dataset as ds
from mindspore import log as logger
from mindspore.common import Tensor
# pylint: disable=no-value-for-parameter
def create_dataset(size, needs_batch):
"""
Create dataset for train or test
"""
def my_func(x):
arr = np.zeros((2, 2))
return ({"originally_tensor": x, "originally_numpy": arr, "originally_dict": {"dd": x},
"originally_int": 1, "originally_bool": True, "originally_float": 1.0}, x, arr)
data_path = os.path.join("/home/workspace/mindspore_dataset/mnist", "train")
data = ds.MnistDataset(data_path, num_parallel_workers=8, num_samples=size)
data = data.project("image")
data = data.map(operations=my_func, input_columns=["image"],
output_columns=["dict", "originally_tensor", "originally_numpy"])
if needs_batch:
data = data.batch(2)
return data
def create_model():
"""
Define and return a simple model
"""
class Net(nn.Cell):
def construct(self, x, y, z):
assert isinstance(x, dict)
assert isinstance(y, Tensor)
assert isinstance(z, Tensor)
return x
net = Net()
model_ = Model(net)
return model_
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.forked
@pytest.mark.parametrize("needs_batch", (False, True))
def test_python_dict_in_pipeline(needs_batch):
"""
Feature: Dataset pipeline contains a Python dict object
Description: A dict object is created and sent to the model by dataset pipeline
Expectation: Python dict object is successfully sent to the model
"""
logger.info("test_python_dict_in_pipeline - dict object testing")
num_epochs = 2
dataset_size = 50
data = create_dataset(dataset_size, needs_batch)
model = create_model()
# non-sink mode supports python dictionary
model.train(num_epochs, data, dataset_sink_mode=False)
# sink mode doesn't support python dict as input
with pytest.raises(RuntimeError) as error_info:
model.train(num_epochs, data, dataset_sink_mode=True)
assert "The python type <class 'numpy.object_'> cannot be converted to MindSpore type." in str(
error_info.value)
if __name__ == '__main__':
test_python_dict_in_pipeline(True)

View File

@ -0,0 +1,498 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Test generic support of Python dictionaries in dataset pipeline
"""
import gc
from time import sleep
import numpy as np
import pytest
import mindspore.dataset as ds
from mindspore import log as logger
from mindspore.common import Tensor
def index_generator(ds_size):
for i in range(ds_size):
yield i
def dict_generator(ds_size):
for i in range(ds_size):
yield {'integer': i, 'boolean': True, 'string': "MY_EMPTY_STR", "tuple": (1, 2, 3)}
def simple_pyfunc(x):
return x
def build_dict(x):
return {"integer": x, "a": x**2, "b": 1}
def remove_dict(x):
return x["integer"]
def remove_dict_wrong_key(x):
return x["non-existing"]
def build_exp_dict(x):
return {"value": np.power(x, 1), "square": np.power(x, 2), "cube": np.power(x, 3)}
def create_dict_batch(col1, batch_info):
ret = [build_exp_dict(x) for x in col1]
return (ret,)
def modify_dict_batch(col1, batch_info):
def convert(x):
new_dict = x
new_dict["integer"] = np.power(new_dict["integer"], 2)
new_dict["boolean"] = 1
return new_dict
new_dicts = [convert(x) for x in col1]
return (new_dicts,)
@pytest.mark.parametrize("my_iterator", ("tuple", "dict"))
@pytest.mark.parametrize("output_numpy", (False, True))
def test_dict_generator(my_iterator, output_numpy):
"""
Feature: Dataset pipeline creates a Python dict object using a generator operation.
Description: Values maintained in the dict object are converted to Tensor appropriately.
Expectation: Python dict object is successfully maintained and converted in the dataset pipeline.
"""
logger.info("test_dict_generator -- Generator(dicts) --> rename()")
dataset_size = 5
data1 = ds.GeneratorDataset(dict_generator(dataset_size), ["col1"])
if my_iterator == "tuple":
itr = data1.create_tuple_iterator(
num_epochs=1, output_numpy=output_numpy)
else:
itr = data1.create_dict_iterator(
num_epochs=1, output_numpy=output_numpy)
for d in itr:
gc.collect()  # force a GC pass to verify pipeline-held dicts are not prematurely collected
if my_iterator == "tuple":
data = d[0]
else:
data = d["col1"]
assert isinstance(data, dict)
if output_numpy:
assert isinstance(data["integer"], int)
assert isinstance(data["boolean"], bool)
assert isinstance(data["string"], str)
assert isinstance(data["tuple"], tuple)
else: # tensor
assert isinstance(data["integer"], Tensor)
assert isinstance(data["boolean"], Tensor)
assert isinstance(data["string"], Tensor)
assert isinstance(data["tuple"], list)
assert isinstance(data["tuple"][0], Tensor)
def test_dict_generator_map_1():
"""
Feature: Dataset pipeline contains a Python dict object.
Description: Generator operation creates dictionaries while the next operation (map) removes them.
Expectation: Python dict objects are successfully created, maintained, and deleted in the dataset pipeline.
"""
logger.info("test_dict_generator_map_1 -- Generator(dicts) --> map(remove_dicts) --> rename()")
dataset_size = 5
data1 = ds.GeneratorDataset(lambda: dict_generator(dataset_size), ["col1"])
data1 = data1.map(remove_dict)
data1 = data1.rename(["col1"], ["renamed_col1"])
count = 0
itr = data1.create_dict_iterator(num_epochs=2, output_numpy=True)
for _ in range(2):
for i, d in enumerate(itr):
gc.collect()
count += 1
assert isinstance(d["renamed_col1"], np.ndarray)
assert d["renamed_col1"] == np.array([i])
assert count == 10
def test_dict_generator_map_2():
"""
Feature: Dataset pipeline contains a Python dict object.
Description: Generator operation creates dictionaries while the following map operation's pyfunc accesses them.
Expectation: Python dict objects are successfully created, maintained, and sent to user.
"""
logger.info(
"test_dict_generator_map_2 -- Generator(dicts) --> map(simple_pyfunc) --> rename()")
dataset_size = 5
data1 = ds.GeneratorDataset(lambda: dict_generator(dataset_size), ["col1"])
data1 = data1.map(simple_pyfunc)
data1 = data1.rename(["col1"], ["renamed_col1"])
count = 0
itr = data1.create_dict_iterator(num_epochs=2, output_numpy=True)
for _ in range(2):
for d in itr:
gc.collect()
count += 1
assert isinstance(d["renamed_col1"], dict)
assert isinstance(d["renamed_col1"]["integer"], int)
assert isinstance(d["renamed_col1"]["boolean"], bool)
assert isinstance(d["renamed_col1"]["string"], str)
assert isinstance(d["renamed_col1"]["tuple"], tuple)
assert count == 10
def test_dict_generator_map_3():
"""
Feature: Dataset pipeline contains a Python dict object.
Description: Generator operation creates dictionaries while the following map operation's pyfunc
tries to access a non-existing key.
Expectation: Appropriate error is raised in the dataset pipeline.
"""
logger.info(
"test_dict_generator_map_3 -- Generator(dicts) --> map(remove_dict_wrong_key) --> rename()")
dataset_size = 5
data1 = ds.GeneratorDataset(dict_generator(dataset_size), ["col1"])
data1 = data1.map(remove_dict_wrong_key)
data1 = data1.rename(["col1"], ["renamed_col1"])
with pytest.raises(RuntimeError) as error_info:
for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
pass
assert "KeyError" in str(error_info.value)
def test_dict_generator_batch_1():
"""
Feature: Dataset pipeline contains a Python dict object.
Description: Batch operation automatically groups the values for each dictionary key into lists.
Expectation: Python dict objects are successfully created, maintained, and sent to user.
"""
logger.info(
"test_dict_generator_batch_1 -- Generator(dicts) --> rename() --> batch()")
dataset_size = 5
data1 = ds.GeneratorDataset(lambda: dict_generator(dataset_size), ["col1"])
data1 = data1.rename(["col1"], ["renamed_col1"])
data1 = data1.batch(2, drop_remainder=True)
itr = data1.create_dict_iterator(num_epochs=2, output_numpy=True)
count = 0
for _ in range(2):
for d in itr:
gc.collect()
count += 1
assert isinstance(d["renamed_col1"], dict)
assert isinstance(d["renamed_col1"]["integer"], list)
assert isinstance(d["renamed_col1"]["integer"][0], int)
assert isinstance(d["renamed_col1"]["boolean"], list)
assert isinstance(d["renamed_col1"]["boolean"][0], bool)
assert isinstance(d["renamed_col1"]["string"], list)
assert isinstance(d["renamed_col1"]["string"][0], str)
assert isinstance(d["renamed_col1"]["tuple"], list)
assert isinstance(d["renamed_col1"]["tuple"][0], tuple)
assert count == 4
def test_dict_generator_batch_2():
"""
Feature: Dataset pipeline contains a Python dict object.
Description: Batch operation's per_batch_map adds dictionaries to the dataset pipeline.
Expectation: Python dict objects are successfully created, maintained, and sent to user.
"""
# input: int, with per_batch_map creating dict
logger.info(
"test_dict_generator_batch_2 -- Generator() --> batch(create_dict_batch)")
dataset_size = 5
data1 = ds.GeneratorDataset(lambda: index_generator(dataset_size), ["col1"])
data1 = data1.batch(2, per_batch_map=create_dict_batch,
drop_remainder=True)
itr = data1.create_dict_iterator(num_epochs=2, output_numpy=True)
count = 0
for _ in range(2):
for d in itr:
gc.collect()
count += 1
assert isinstance(d["col1"], dict)
assert isinstance(d["col1"]["value"], list)
assert isinstance(d["col1"]["value"][0], np.int64)
assert isinstance(d["col1"]["square"], list)
assert isinstance(d["col1"]["square"][0], np.int64)
assert isinstance(d["col1"]["cube"], list)
assert isinstance(d["col1"]["cube"][0], np.int64)
assert count == 4
def test_dict_generator_batch_3():
"""
Feature: Dataset pipeline contains python dict objects.
Description: Batch operation's per_batch_map modifies existing dict objects in the pipeline.
Expectation: Python dict objects are successfully created, maintained, and sent to user.
"""
logger.info(
"test_dict_generator_batch_3 -- Generator(dict_generator) --> rename() --> batch(modify_dict_batch)")
dataset_size = 5
data1 = ds.GeneratorDataset(lambda: dict_generator(dataset_size), ["col1"])
data1 = data1.rename(["col1"], ["renamed_col1"])
data1 = data1.batch(2, per_batch_map=modify_dict_batch,
drop_remainder=True)
itr = data1.create_dict_iterator(num_epochs=2, output_numpy=True)
counter = 0
for _ in range(2):
for d in itr:
gc.collect()
counter += 1
assert isinstance(d["renamed_col1"], dict)
assert isinstance(d["renamed_col1"]["integer"], list)
assert isinstance(d["renamed_col1"]["integer"][0], np.int64)
assert isinstance(d["renamed_col1"]["boolean"], list)
assert isinstance(d["renamed_col1"]["boolean"][0], int) # modified!
assert isinstance(d["renamed_col1"]["string"], list)
assert isinstance(d["renamed_col1"]["string"][0], str)
assert isinstance(d["renamed_col1"]["tuple"], list)
assert isinstance(d["renamed_col1"]["tuple"][0], tuple)
assert isinstance(d["renamed_col1"]["tuple"][0][0], int)
assert counter == 4
def wrong_batch1(col1, col2, batch_info):
return {"a": 1}, col2 # 1 dict vs list of dicts
def wrong_batch2(col1, col2, batch_info):
return {"a": 1}, {"a"} # 1 dict vs 1 set
def wrong_batch3(col1, col2, batch_info):
return {"a": 1}, [1] # 1 dict vs a list (not a numpy array)
def wrong_batch4(col1, col2, batch_info):
return {"a": 1}, [np.array([1]), np.array([1])] # 1 dict vs list (not a numpy array)
def wrong_batch5(col1, col2, batch_info):
return col1, np.array([1]) # 1 list of dicts vs 1 np (insufficient data in np to split)
@pytest.mark.parametrize("wrong_dict_batch", [wrong_batch1, wrong_batch2, wrong_batch3, wrong_batch4, wrong_batch5])
def test_dict_generator_batch_4(wrong_dict_batch):
"""
Feature: Dataset pipeline contains python dict objects.
Description: Batch operation's per_batch_map returns mismatched combinations of dicts and other types.
Expectation: Appropriate error is raised in the dataset pipeline.
"""
logger.info(
"test_dict_generator_batch_4 -- Generator(dict_generator) x 2 --> zip() --> batch()")
dataset_size = 5
data1 = ds.GeneratorDataset(dict_generator(dataset_size), ["col1"])
data2 = ds.GeneratorDataset(dict_generator(dataset_size), ["col2"])
data3 = ds.zip((data1, data2))
data3 = data3.batch(2, per_batch_map=wrong_dict_batch,
drop_remainder=True)
with pytest.raises(RuntimeError) as error_info:
for _ in data3.create_dict_iterator(num_epochs=1, output_numpy=True):
pass
# pylint: disable=comparison-with-callable
if wrong_dict_batch == wrong_batch5:
assert "Invalid data, column: col2 expects: 2 rows returned from 'per_batch_map'" in str(error_info.value)
else:
assert "mismatched types returned from per_batch_map" in str(error_info.value)
def correct_batch1(col1, col2, batch_info):
return {"a": 1}, {"a": 2} # 1 dict vs 1 dict
def correct_batch2(col1, col2, batch_info):
return col2, col1 # 1 list of dicts vs 1 list of dicts
def correct_batch3(col1, col2, batch_info):
return {"a": 1}, np.array([1, 2, 3]) # 1 dict vs 1 np
def correct_batch4(col1, col2, batch_info):
return col1, [1, 2] # 1 list of dicts vs 1 list of ints
def correct_batch5(col1, col2, batch_info):
return col1, np.array([1] * len(col2)) # 1 list of dicts vs 1 np (sufficient data)
@pytest.mark.parametrize("my_batch", [correct_batch1, correct_batch2, correct_batch3, correct_batch4, correct_batch5])
def test_dict_generator_batch_5(my_batch):
"""
Feature: Dataset pipeline contains python dict objects.
Description: Batch operation's per_batch_map returns valid combinations of dicts, arrays, and lists.
Expectation: All rows are batched successfully without errors.
"""
logger.info(
"test_dict_generator_batch_5 -- Generator(dict_generator) x 2 --> zip() --> batch()")
dataset_size = 5
data1 = ds.GeneratorDataset(dict_generator(dataset_size), ["col1"])
data2 = ds.GeneratorDataset(dict_generator(dataset_size), ["col2"])
data3 = ds.zip((data1, data2))
data3 = data3.batch(2, per_batch_map=my_batch,
drop_remainder=True)
counter = 0
for _ in data3.create_dict_iterator(num_epochs=1, output_numpy=True):
counter += 1
assert counter == 2
def test_dict_advanced_pyfunc_dict():
"""
Feature: Dataset pipeline contains Python dict objects.
Description: Various generator, map, and batch operations are used to create and remove dict objects.
Expectation: Python dict objects are successfully created, maintained, and sent to user.
"""
logger.info("test_dict_advanced_pyfunc_dict")
dataset_size = 125
def my_batch_map(x1, x2, x3, y):
return (x1, x2, x3)
def my_delay_f(x1, x2, x3):
gc.collect()
sleep(0.01) # sleep for 0.01s
return (x1, x2, x3)
data1 = ds.GeneratorDataset(index_generator(dataset_size), ["data1"])
data2 = ds.GeneratorDataset(dict_generator(dataset_size), ["data2"])
data3 = ds.GeneratorDataset(index_generator(dataset_size), ["data3"])
data4 = ds.zip((data1, data2, data3))
data4 = data4.map(build_dict, ["data1"])
data4 = data4.map(remove_dict, ["data2"])
data4 = data4.map(build_dict, ["data2"])
data4 = data4.skip(3)
data4 = data4.repeat(2)
data4 = data4.map(build_dict, ["data3"])
data4 = data4.map(remove_dict, ["data2"])
data4 = data4.map(build_dict, ["data2"])
data4 = data4.map(remove_dict, ["data2"])
data4 = data4.take(40)
data4 = data4.map(my_delay_f, ["data1", "data2", "data3"])
data4 = data4.rename(["data1"], ["data1new"])
data4 = data4.batch(2, per_batch_map=my_batch_map)
data4 = data4.batch(2, drop_remainder=False)
count = 0
for d in data4.create_dict_iterator(num_epochs=1, output_numpy=True):
gc.collect()
count += 1
assert len(d) == 3 # 3 columns
assert isinstance(d["data1new"], dict)
np.testing.assert_array_equal(d["data1new"]["b"], np.array([[1, 1], [1, 1]]))
assert count == 10
@pytest.mark.parametrize("my_iterator", ("tuple", "dict"))
@pytest.mark.parametrize("output_numpy", (False, True))
def test_dict_generator_mixed(my_iterator, output_numpy):
"""
Feature: Dataset pipeline creates a Python dict object using a generator operation.
Description: Values maintained in the dict object are converted to Tensor appropriately.
Expectation: Python dict object is successfully maintained and converted in the dataset pipeline.
"""
logger.info("test_dict_generator_mixed -- Generator(dicts) --> rename()")
def mixed_dict_generator(ds_size):
for i in range(ds_size):
yield ({'integer': i, 'boolean': True, 'string': "MY_EMPTY_STR", "tuple": (1, 2, 3)}, True, 4, "String")
dataset_size = 15
data1 = ds.GeneratorDataset(mixed_dict_generator(dataset_size), ["col1", "col2", "col3", "col4"])
if my_iterator == "tuple":
itr = data1.create_tuple_iterator(
num_epochs=1, output_numpy=output_numpy)
else:
itr = data1.create_dict_iterator(
num_epochs=1, output_numpy=output_numpy)
count = 0
for data in itr:
gc.collect()  # force a GC pass to verify pipeline-held dicts are not prematurely collected
count += 1
if my_iterator == "tuple":
if output_numpy:
assert isinstance(data[0], dict)
assert isinstance(data[0]["integer"], int)
assert isinstance(data[1], np.ndarray)
assert isinstance(data[2], np.ndarray)
assert isinstance(data[3], np.ndarray)
else: # tensor
assert isinstance(data[0], dict)
assert isinstance(data[0]["integer"], Tensor)
assert isinstance(data[1], Tensor)
assert isinstance(data[2], Tensor)
assert isinstance(data[3], Tensor)
else: # dict iterator
if output_numpy:
assert isinstance(data["col1"], dict)
assert isinstance(data["col1"]["integer"], int)
assert isinstance(data["col2"], np.ndarray)
assert isinstance(data["col3"], np.ndarray)
assert isinstance(data["col4"], np.ndarray)
else: # tensor
assert isinstance(data["col1"], dict)
assert isinstance(data["col1"]["integer"], Tensor)
assert isinstance(data["col2"], Tensor)
assert isinstance(data["col3"], Tensor)
assert isinstance(data["col4"], Tensor)
assert count == 15
def test_dict_generator_nested_dicts():
"""
Feature: Dataset pipeline contains a Python dict object.
Description: Generator operation creates nested dictionaries.
Expectation: Python dict objects are successfully created, maintained, and deleted in the dataset pipeline.
"""
logger.info("test_dict_generator_nested_dicts -- Generator(nested_dicts)")
dataset_size = 5
def nested_dict_generator(ds_size):
for i in range(ds_size):
yield {"integer": i, "dict": {"a": 0, "b": 1}}
data1 = ds.GeneratorDataset(lambda: nested_dict_generator(dataset_size), ["col1"])
count = 0
itr = data1.create_dict_iterator(num_epochs=2, output_numpy=True)
for _ in range(2):
for d in itr:
gc.collect()
count += 1
assert isinstance(d["col1"], dict)
assert isinstance(d["col1"]["integer"], int)
assert isinstance(d["col1"]["dict"], dict)
assert isinstance(d["col1"]["dict"]["a"], int)
assert count == 10
if __name__ == '__main__':
test_dict_generator("tuple", False)
test_dict_generator_map_1()
test_dict_generator_map_2()
test_dict_generator_map_3()
test_dict_generator_batch_1()
test_dict_generator_batch_2()
test_dict_generator_batch_3()
test_dict_generator_batch_4(wrong_batch1)
test_dict_generator_batch_5(correct_batch1)
test_dict_advanced_pyfunc_dict()
test_dict_generator_mixed("tuple", False)
test_dict_generator_nested_dicts()