forked from mindspore-Ecosystem/mindspore
- Bug when empty strings sent to Python
- Support accepting Numpy of str as input - Support batching strings - Core logic of batch&pad is static - Make Pad a utility function
This commit is contained in:
parent
bc7a3a1bef
commit
f837ddc956
|
@ -642,18 +642,8 @@ Status DEPipeline::ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp>
|
|||
(void)builder->SetColumnsToMap(ToStringVector(value));
|
||||
}
|
||||
if (key == "pad_info") {
|
||||
std::map<std::string, std::pair<TensorShape, float>> pad_info;
|
||||
for (auto p : py::reinterpret_borrow<py::dict>(value)) {
|
||||
if (!p.second.is_none()) {
|
||||
py::tuple tp = py::reinterpret_borrow<py::tuple>(p.second);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(tp.size() == 2, "tuple in pad_info must be (list,int) or (list,float)");
|
||||
TensorShape shape = tp[0].is_none() ? TensorShape::CreateUnknownRankShape() : TensorShape(tp[0]);
|
||||
float pad_val = tp[1].is_none() ? 0 : ToFloat(tp[1]);
|
||||
(void)pad_info.insert({ToString(p.first), {shape, pad_val}});
|
||||
} else { // tuple is None
|
||||
(void)pad_info.insert({ToString(p.first), {TensorShape({}), 0}});
|
||||
}
|
||||
}
|
||||
PadInfo pad_info;
|
||||
RETURN_IF_NOT_OK(ParsePadInfo(value, &pad_info));
|
||||
(void)builder->SetPaddingMap(pad_info, true);
|
||||
}
|
||||
}
|
||||
|
@ -1166,5 +1156,31 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset
|
|||
*ptr = op;
|
||||
return Status::OK();
|
||||
}
|
||||
Status DEPipeline::ParsePadInfo(py::handle value, PadInfo *pad_info) {
|
||||
for (auto p : py::reinterpret_borrow<py::dict>(value)) {
|
||||
if (!p.second.is_none()) {
|
||||
auto tp = py::reinterpret_borrow<py::tuple>(p.second);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(tp.size() == 2, "tuple in pad_info must be (list,int) or (list,float)");
|
||||
TensorShape shape = tp[0].is_none() ? TensorShape::CreateUnknownRankShape() : TensorShape(tp[0]);
|
||||
std::shared_ptr<Tensor> pad_val = nullptr;
|
||||
if (py::isinstance<py::str>(tp[1])) {
|
||||
std::string pad_val_string = tp[1].is_none() ? "" : ToString(tp[1]);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
Tensor::CreateTensor(&pad_val, std::vector<std::string>{pad_val_string}, TensorShape::CreateScalar()),
|
||||
"Cannot create pad_value Tensor");
|
||||
} else {
|
||||
float pad_val_float = tp[1].is_none() ? 0 : ToFloat(tp[1]);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(Tensor::CreateTensor(&pad_val, TensorImpl::kFlexible, TensorShape::CreateScalar(),
|
||||
DataType(DataType::DE_FLOAT32)),
|
||||
"Cannot create pad_value Tensor");
|
||||
pad_val->SetItemAt<float>({}, pad_val_float);
|
||||
}
|
||||
(void)pad_info->insert({ToString(p.first), {shape, pad_val}});
|
||||
} else { // tuple is None
|
||||
(void)pad_info->insert({ToString(p.first), {TensorShape({}), nullptr}});
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -169,6 +169,8 @@ class DEPipeline {
|
|||
// Validate required args passed to storage op.
|
||||
Status ValidateArgStorageOp(const py::dict &args);
|
||||
|
||||
static Status ParsePadInfo(py::handle value, PadInfo *pad_info);
|
||||
|
||||
int batch_size_;
|
||||
int repeat_num_;
|
||||
int num_rows_;
|
||||
|
|
|
@ -138,7 +138,7 @@ DataType DataType::FromNpArray(const py::array &arr) {
|
|||
return DataType(DataType::DE_FLOAT32);
|
||||
} else if (py::isinstance<py::array_t<std::double_t>>(arr)) {
|
||||
return DataType(DataType::DE_FLOAT64);
|
||||
} else if (arr.dtype().kind() == 'S') {
|
||||
} else if (arr.dtype().kind() == 'S' || arr.dtype().kind() == 'U') {
|
||||
return DataType(DataType::DE_STRING);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Cannot convert from numpy type. Unknown data type is returned!";
|
||||
|
|
|
@ -229,7 +229,12 @@ Status Tensor::CreateTensorFromNumpyString(std::shared_ptr<Tensor> *ptr, py::arr
|
|||
}
|
||||
arr.resize({arr.size()}); // flatten the py::array so we can iterate once
|
||||
std::vector<std::string> strings;
|
||||
std::for_each(arr.begin(), arr.end(), [&strings](const auto &s) { strings.emplace_back(py::cast<py::bytes>(s)); });
|
||||
|
||||
if (arr.dtype().kind() == 'U') {
|
||||
std::for_each(arr.begin(), arr.end(), [&strings](const auto &s) { strings.emplace_back(py::cast<py::str>(s)); });
|
||||
} else {
|
||||
std::for_each(arr.begin(), arr.end(), [&strings](const auto &s) { strings.emplace_back(py::cast<py::bytes>(s)); });
|
||||
}
|
||||
|
||||
arr.resize(shape); // resize arr back to the original shape
|
||||
|
||||
|
@ -699,6 +704,8 @@ Status Tensor::GetDataAsNumpyStrings(py::array *data) {
|
|||
for (; itr != end<std::string_view>(); itr++) {
|
||||
max = std::max((*itr).length(), max);
|
||||
}
|
||||
// if all strings are empty, numpy stores a byte for each string |S1
|
||||
max = (max == 0 ? 1 : max);
|
||||
uint64_t total_size = shape_.NumOfElements() * max;
|
||||
char *tmp_data = reinterpret_cast<char *>(data_allocator_->allocate(total_size));
|
||||
if (tmp_data == nullptr) RETURN_STATUS_UNEXPECTED("Cannot create temp array.");
|
||||
|
@ -708,8 +715,10 @@ Status Tensor::GetDataAsNumpyStrings(py::array *data) {
|
|||
itr = begin<std::string_view>();
|
||||
uint64_t i = 0;
|
||||
for (; itr != end<std::string_view>(); itr++, i++) {
|
||||
ret_code = memcpy_s(tmp_data + i * max, total_size, (*itr).data(), (*itr).length());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(ret_code == 0, "Failed to copy string data.");
|
||||
if (!(*itr).empty()) {
|
||||
ret_code = memcpy_s(tmp_data + i * max, total_size, (*itr).data(), (*itr).length());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(ret_code == 0, "Failed to copy string data.");
|
||||
}
|
||||
}
|
||||
auto strides = shape_.Strides();
|
||||
std::transform(strides.begin(), strides.end(), strides.begin(), [&max](const auto &s) { return s * max; });
|
||||
|
@ -847,6 +856,21 @@ Status Tensor::GetStringAt(dsize_t index, uchar **string_start, offset_t *length
|
|||
*length = offset_ptr[index + 1] - start - 1; // -1 to skip the \0 from the string length
|
||||
return Status::OK();
|
||||
}
|
||||
Status Tensor::CopyLastDimAt(const std::shared_ptr<Tensor> &src, const std::vector<dsize_t> &index) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(src->type() == type_, "Source Tensor has a different type");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(index.back() == 0, "Last dim in index should be 0");
|
||||
|
||||
uint8_t type_size = type_.SizeInBytes();
|
||||
size_t len = std::min(src->shape()[-1], shape_[-1]) * type_size;
|
||||
dsize_t src_flat_ind = 0, dst_flat_ind = 0;
|
||||
RETURN_IF_NOT_OK(src->shape().ToFlatIndex(index, &src_flat_ind));
|
||||
RETURN_IF_NOT_OK(shape_.ToFlatIndex(index, &dst_flat_ind));
|
||||
|
||||
const unsigned char *src_addr = src->GetBuffer() + src_flat_ind * type_size;
|
||||
unsigned char *dst_addr = GetMutableBuffer() + dst_flat_ind * type_size;
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(memcpy_s(dst_addr, len, src_addr, len) == 0, "memcpy error");
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -529,6 +529,12 @@ class Tensor {
|
|||
return TensorIterator<T>(data_end_);
|
||||
}
|
||||
|
||||
// Copies the last dimension at `index` from Tensor `src` to this Tensor.
|
||||
// @param src Tensor
|
||||
// @param index vector to the start of the dimension. The last dim should be 0
|
||||
// @return Status
|
||||
Status CopyLastDimAt(const std::shared_ptr<Tensor> &src, const std::vector<dsize_t> &index);
|
||||
|
||||
protected:
|
||||
// A function that prints Tensor recursively, first called by print
|
||||
// @param out
|
||||
|
|
|
@ -118,7 +118,10 @@ class TensorShape {
|
|||
|
||||
bool operator!=(const TensorShape &rhs) const { return !(rhs == *this); }
|
||||
|
||||
dsize_t operator[](const dsize_t index) const { return raw_shape_[index]; }
|
||||
dsize_t operator[](const dsize_t index) const {
|
||||
if (index < 0) return raw_shape_[raw_shape_.size() + index];
|
||||
return raw_shape_[index];
|
||||
}
|
||||
|
||||
// Return the Shape as a vector
|
||||
// @return
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include "dataset/engine/data_buffer.h"
|
||||
#include "dataset/engine/db_connector.h"
|
||||
#include "dataset/engine/opt/pass.h"
|
||||
#include "dataset/kernels/data/data_utils.h"
|
||||
|
||||
using float16 = Eigen::half;
|
||||
|
||||
|
@ -53,7 +54,7 @@ Status BatchOp::Builder::SanityCheck() {
|
|||
|
||||
BatchOp::BatchOp(int32_t batch_size, bool drop, bool pad, int32_t op_queue_size, int32_t num_workers,
|
||||
const std::vector<std::string> &cols_to_map, py::function batch_size_func, py::function batch_map_func,
|
||||
std::map<std::string, std::pair<TensorShape, float>> pad_map)
|
||||
PadInfo pad_map)
|
||||
: ParallelOp(num_workers, op_queue_size),
|
||||
start_batch_size_(batch_size),
|
||||
drop_(drop),
|
||||
|
@ -75,10 +76,6 @@ Status BatchOp::operator()() {
|
|||
std::unique_ptr<TensorQTable> table = std::make_unique<TensorQTable>();
|
||||
child_iterator_ = std::make_unique<ChildIterator>(this, 0, 0);
|
||||
RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
|
||||
for (const auto &t : new_row) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(t->type().IsNumeric(),
|
||||
"[Batch ERROR] Batch does not support Tensor of type string yet.");
|
||||
}
|
||||
RETURN_IF_NOT_OK(DatasetOp::AssignColMapFromChild()); // must come after the first fetch above
|
||||
int32_t cur_batch_size = 0;
|
||||
RETURN_IF_NOT_OK(GetBatchSize(&cur_batch_size, CBatchInfo(0, 0, 0)));
|
||||
|
@ -134,49 +131,57 @@ void BatchOp::Print(std::ostream &out, bool show_all) const {
|
|||
}
|
||||
}
|
||||
|
||||
Status BatchOp::BatchRows(const std::unique_ptr<TensorQTable> *source_table,
|
||||
const std::unique_ptr<TensorQTable> *dest_table, size_t batch_size) {
|
||||
if ((*source_table)->size() < batch_size || (*source_table)->size() == 0) {
|
||||
RETURN_STATUS_UNEXPECTED("[Internal Batch ERROR] Insufficient rows in source_table\n");
|
||||
Status BatchOp::BatchRows(const std::unique_ptr<TensorQTable> *src, const std::unique_ptr<TensorQTable> *dest,
|
||||
dsize_t batch_size) {
|
||||
if ((*src)->size() != batch_size) {
|
||||
RETURN_STATUS_UNEXPECTED("[Internal Batch ERROR] Source table size does not match the batch_size");
|
||||
}
|
||||
TensorRow row = std::move((*source_table)->front());
|
||||
(*source_table)->pop_front();
|
||||
|
||||
if (batch_size == 1) {
|
||||
for (std::shared_ptr<Tensor> tensor : row) {
|
||||
TensorRow row = std::move((*src)->front());
|
||||
(*src)->pop_front();
|
||||
(*dest)->push_back(row);
|
||||
for (const auto &tensor : (*dest)->front()) {
|
||||
RETURN_IF_NOT_OK(tensor->ExpandDim(0));
|
||||
}
|
||||
(*dest_table)->push_back(row);
|
||||
} else { // batch_size > 1
|
||||
std::vector<TensorShape> row_shapes;
|
||||
TensorRow batched_row;
|
||||
for (size_t i = 0; i < row.size(); i++) { // Handle the first row popped
|
||||
row_shapes.push_back(row[i]->shape());
|
||||
std::shared_ptr<Tensor> ts;
|
||||
RETURN_IF_NOT_OK(Tensor::CreateTensor(
|
||||
&ts, TensorImpl::kFlexible, row[i]->shape().PrependDim(static_cast<int64_t>(batch_size)), row[i]->type()));
|
||||
batched_row.emplace_back(ts);
|
||||
RETURN_IF_NOT_OK(batched_row[i]->InsertTensor(std::vector<dsize_t>(1, 0), row[i])); // {j} = 0
|
||||
}
|
||||
for (size_t j = 1; j < batch_size; j++) { // Handle the rest of the rows
|
||||
row = std::move((*source_table)->front());
|
||||
(*source_table)->pop_front();
|
||||
for (size_t i = 0; i < row.size(); i++) {
|
||||
if (row[i]->shape() == row_shapes[i]) { // check the newly popped rows have the same dim as the first
|
||||
RETURN_IF_NOT_OK(batched_row[i]->InsertTensor(std::vector<dsize_t>(1, j), row[i]));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
TensorRow batched_row;
|
||||
auto num_columns = (*src)->front().size();
|
||||
for (size_t i = 0; i < num_columns; i++) {
|
||||
std::shared_ptr<Tensor> first_tensor = (*src)->at(0).at(i); // first row, column i
|
||||
TensorShape first_shape = first_tensor->shape();
|
||||
DataType first_type = first_tensor->type();
|
||||
TensorShape new_shape = first_shape.PrependDim(static_cast<int64_t>(batch_size));
|
||||
|
||||
std::shared_ptr<Tensor> new_tensor;
|
||||
if (first_type.IsNumeric()) { // numeric tensor
|
||||
RETURN_IF_NOT_OK(Tensor::CreateTensor(&new_tensor, TensorImpl::kFlexible, new_shape, first_type));
|
||||
dsize_t j = 0;
|
||||
for (auto row : **src) {
|
||||
std::shared_ptr<Tensor> old_tensor = row.at(i); // row j, column i
|
||||
if (old_tensor->shape() == first_shape) { // check the newly popped rows have the same dim as the first
|
||||
RETURN_IF_NOT_OK(new_tensor->InsertTensor({j++}, old_tensor));
|
||||
} else {
|
||||
std::string column_name;
|
||||
for (auto itr : column_name_id_map_) {
|
||||
if (static_cast<size_t>(itr.second) == i) {
|
||||
column_name = itr.first;
|
||||
break;
|
||||
}
|
||||
}
|
||||
RETURN_STATUS_UNEXPECTED("[Batch ERROR] Inconsistent TensorShapes of Column " + column_name);
|
||||
RETURN_STATUS_UNEXPECTED("[Batch ERROR] Inconsistent TensorShapes of Column " + std::to_string(i));
|
||||
}
|
||||
}
|
||||
} else { // handle string column differently
|
||||
std::vector<std::string> strings;
|
||||
for (dsize_t j = 0; j < batch_size; j++) {
|
||||
std::shared_ptr<Tensor> old_tensor = (*src)->at(j).at(i);
|
||||
for (auto itr = old_tensor->begin<std::string_view>(); itr != old_tensor->end<std::string_view>(); itr++) {
|
||||
strings.emplace_back(*itr);
|
||||
}
|
||||
}
|
||||
RETURN_IF_NOT_OK(Tensor::CreateTensor(&new_tensor, strings, new_shape));
|
||||
}
|
||||
(*dest_table)->emplace_back(batched_row);
|
||||
batched_row.emplace_back(new_tensor);
|
||||
}
|
||||
|
||||
(*dest)->emplace_back(batched_row);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -202,8 +207,8 @@ Status BatchOp::WorkerEntry(int32_t workerId) {
|
|||
Status BatchOp::MakeBatchedBuffer(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> table_pair,
|
||||
std::unique_ptr<DataBuffer> *db) {
|
||||
RETURN_UNEXPECTED_IF_NULL(table_pair.first);
|
||||
if (!pyfunc_column_names_.empty()) RETURN_IF_NOT_OK(MapColumns(&table_pair)); // pass it through pyfunc
|
||||
if (pad_) RETURN_IF_NOT_OK(PadColumns(&table_pair)); // do padding if needed
|
||||
if (!pyfunc_column_names_.empty()) RETURN_IF_NOT_OK(MapColumns(&table_pair)); // pass it through pyfunc
|
||||
if (pad_) RETURN_IF_NOT_OK(PadColumns(&table_pair.first, pad_info_, column_name_id_map_)); // do padding if needed
|
||||
(*db) = std::make_unique<DataBuffer>(table_pair.second.batch_num_, DataBuffer::kDeBFlagNone);
|
||||
std::unique_ptr<TensorQTable> dest_table = std::make_unique<TensorQTable>();
|
||||
RETURN_IF_NOT_OK(BatchRows(&table_pair.first, &dest_table, table_pair.first->size()));
|
||||
|
@ -333,74 +338,27 @@ Status BatchOp::InvokeBatchMapFunc(TensorBatchTable *input, TensorBatchTable *ou
|
|||
return Status(StatusCode::kOK);
|
||||
}
|
||||
|
||||
Status BatchOp::PadTensor(std::shared_ptr<Tensor> src, std::shared_ptr<Tensor> *dst,
|
||||
const std::vector<dsize_t> &pad_shape, float pad_val) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(src != nullptr && dst != nullptr, "tensor can't be nullptr");
|
||||
if (src->Rank() == 0 || src->shape().AsVector() == pad_shape) {
|
||||
(*dst) = src; // if no padding, copy the pointer
|
||||
} else {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(src->Rank() == pad_shape.size(), "Pad to diff rank not allowed");
|
||||
RETURN_IF_NOT_OK(Tensor::CreateTensor(dst, TensorImpl::kFlexible, TensorShape(pad_shape), src->type()));
|
||||
auto tensor_type = src->type().value();
|
||||
if (pad_val == 0) { // if pad with zero, don't care what type it is
|
||||
RETURN_IF_NOT_OK((*dst)->Zero());
|
||||
} else if (tensor_type == DataType::DE_INT8) {
|
||||
RETURN_IF_NOT_OK((*dst)->Fill<int8_t>(pad_val));
|
||||
} else if (tensor_type == DataType::DE_BOOL) {
|
||||
RETURN_IF_NOT_OK((*dst)->Fill<bool>(pad_val));
|
||||
} else if (tensor_type == DataType::DE_UINT8) {
|
||||
RETURN_IF_NOT_OK((*dst)->Fill<uint8_t>(pad_val));
|
||||
} else if (tensor_type == DataType::DE_INT16) {
|
||||
RETURN_IF_NOT_OK((*dst)->Fill<int16_t>(pad_val));
|
||||
} else if (tensor_type == DataType::DE_FLOAT16) {
|
||||
RETURN_IF_NOT_OK((*dst)->Fill<float16>(static_cast<float16>(pad_val)));
|
||||
} else if (tensor_type == DataType::DE_UINT16) {
|
||||
RETURN_IF_NOT_OK((*dst)->Fill<uint16_t>(pad_val));
|
||||
} else if (tensor_type == DataType::DE_INT32) {
|
||||
RETURN_IF_NOT_OK((*dst)->Fill<int32_t>(pad_val));
|
||||
} else if (tensor_type == DataType::DE_UINT32) {
|
||||
RETURN_IF_NOT_OK((*dst)->Fill<uint32_t>(pad_val));
|
||||
} else if (tensor_type == DataType::DE_INT64) {
|
||||
RETURN_IF_NOT_OK((*dst)->Fill<int64_t>(pad_val));
|
||||
} else if (tensor_type == DataType::DE_UINT64) {
|
||||
RETURN_IF_NOT_OK((*dst)->Fill<uint64_t>(pad_val));
|
||||
} else if (tensor_type == DataType::DE_FLOAT32) {
|
||||
RETURN_IF_NOT_OK((*dst)->Fill<float>(pad_val));
|
||||
} else if (tensor_type == DataType::DE_FLOAT64) {
|
||||
RETURN_IF_NOT_OK((*dst)->Fill<double>(pad_val));
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED("Incorrect/Unknown tensor type");
|
||||
}
|
||||
std::vector<dsize_t> cur_ind(src->Rank(), 0), src_s(src->Rank(), 1), dst_s(src->Rank(), 1);
|
||||
for (dsize_t i = src->Rank() - 2; i >= 0; i--) {
|
||||
src_s[i] = src->shape()[i + 1] * src_s[i + 1];
|
||||
dst_s[i] = pad_shape[i + 1] * dst_s[i + 1];
|
||||
}
|
||||
RETURN_IF_NOT_OK(PadHelper(src, *dst, cur_ind, src_s, dst_s, 0));
|
||||
}
|
||||
return Status::OK();
|
||||
} // namespace dataset
|
||||
|
||||
Status BatchOp::PadColumns(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> *table_pair) {
|
||||
RETURN_UNEXPECTED_IF_NULL(table_pair); // placeholder for now, might need this in the future
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(table_pair->first->front().size() == column_name_id_map_.size(),
|
||||
"col_name_map mismatch");
|
||||
std::vector<float> pad_vals(column_name_id_map_.size(), 0); // value to pad each column's tensor with, default 0
|
||||
Status BatchOp::PadColumns(std::unique_ptr<TensorQTable> *table, const PadInfo &pad_info,
|
||||
const std::unordered_map<std::string, int32_t> &column_name_id_map) {
|
||||
RETURN_UNEXPECTED_IF_NULL(table); // placeholder for now, might need this in the future
|
||||
CHECK_FAIL_RETURN_UNEXPECTED((*table)->front().size() == column_name_id_map.size(), "col_name_map mismatch");
|
||||
std::vector<std::shared_ptr<Tensor>> pad_vals(column_name_id_map.size(),
|
||||
0); // value to pad each column's tensor with, default 0
|
||||
std::set<int32_t> pad_cols;
|
||||
// padded_shape provided by user, maximum shapes of current batch of tensors
|
||||
std::vector<std::vector<dsize_t>> pad_shapes(column_name_id_map_.size()), max_shapes(column_name_id_map_.size());
|
||||
RETURN_IF_NOT_OK(UnpackPadInfo(&pad_cols, &pad_vals, &pad_shapes));
|
||||
std::vector<std::vector<dsize_t>> pad_shapes(column_name_id_map.size()), max_shapes(column_name_id_map.size());
|
||||
RETURN_IF_NOT_OK(UnpackPadInfo(pad_info, column_name_id_map, &pad_cols, &pad_vals, &pad_shapes));
|
||||
|
||||
// init each shape in max_shape to {-1,-1...} init each unspecified shape in pad_shape to -1 as well
|
||||
for (size_t col_id : pad_cols) {
|
||||
max_shapes[col_id] = std::vector<dsize_t>(table_pair->first->front()[col_id]->Rank(), -1);
|
||||
max_shapes[col_id] = std::vector<dsize_t>((*table)->front()[col_id]->Rank(), -1);
|
||||
if (pad_shapes[col_id].empty()) pad_shapes[col_id] = max_shapes[col_id]; // fill pad shape with -1
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(pad_shapes[col_id].size() == max_shapes[col_id].size(), "wrong rank in pad_shape");
|
||||
}
|
||||
|
||||
// calculate maximum shape for each column that needs to be padded
|
||||
for (const TensorRow &row : *(table_pair->first)) { // iterator each row in a batch
|
||||
for (size_t col_id : pad_cols) { // iterator each tensor in a row
|
||||
for (const TensorRow &row : **table) { // iterator each row in a batch
|
||||
for (size_t col_id : pad_cols) { // iterator each tensor in a row
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(row[col_id]->Rank() == max_shapes[col_id].size(),
|
||||
"Tensor to be padded together need to have the same rank");
|
||||
for (size_t dim = 0; dim < row[col_id]->Rank(); dim++) { // pick the largest number in each dimension
|
||||
|
@ -417,27 +375,29 @@ Status BatchOp::PadColumns(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo>
|
|||
}
|
||||
|
||||
// call pad on each tensor that needs to be padded
|
||||
for (TensorRow &row : *(table_pair->first)) {
|
||||
for (TensorRow &row : **table) {
|
||||
for (size_t col_id : pad_cols) {
|
||||
std::shared_ptr<Tensor> pad_tensor;
|
||||
RETURN_IF_NOT_OK(PadTensor(row[col_id], &pad_tensor, pad_shapes[col_id], pad_vals[col_id]));
|
||||
RETURN_IF_NOT_OK(PadEnd(row[col_id], &pad_tensor, pad_shapes[col_id], pad_vals[col_id]));
|
||||
row[col_id] = pad_tensor;
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status BatchOp::UnpackPadInfo(std::set<int32_t> *pad_cols, std::vector<float> *pad_vals,
|
||||
Status BatchOp::UnpackPadInfo(const PadInfo &pad_info,
|
||||
const std::unordered_map<std::string, int32_t> &column_name_id_map,
|
||||
std::set<int32_t> *pad_cols, std::vector<std::shared_ptr<Tensor>> *pad_vals,
|
||||
std::vector<std::vector<dsize_t>> *pad_shapes) {
|
||||
if (pad_info_.empty()) { // if pad_info empty, pad every columns automatically
|
||||
for (dsize_t col_id = 0; col_id < column_name_id_map_.size(); col_id++) {
|
||||
if (pad_info.empty()) { // if pad_info empty, pad every columns automatically
|
||||
for (dsize_t col_id = 0; col_id < column_name_id_map.size(); col_id++) {
|
||||
pad_cols->insert(col_id);
|
||||
}
|
||||
} else {
|
||||
for (auto p : pad_info_) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(column_name_id_map_.find(p.first) != column_name_id_map_.end(),
|
||||
"no column exists with name:" + p.first);
|
||||
dsize_t col_id = static_cast<dsize_t>(column_name_id_map_[p.first]);
|
||||
for (const auto &p : pad_info) {
|
||||
auto location = column_name_id_map.find(p.first);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(location != column_name_id_map.end(), "no column exists with name:" + p.first);
|
||||
auto col_id = static_cast<dsize_t>(location->second);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(col_id < pad_vals->size() && col_id < pad_shapes->size(), "col_id out of bound");
|
||||
pad_cols->insert(col_id);
|
||||
(*pad_vals)[col_id] = p.second.second; // set pad values
|
||||
|
@ -447,29 +407,6 @@ Status BatchOp::UnpackPadInfo(std::set<int32_t> *pad_cols, std::vector<float> *p
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status BatchOp::PadHelper(std::shared_ptr<Tensor> src, std::shared_ptr<Tensor> dst, std::vector<dsize_t> cur_ind,
|
||||
const std::vector<dsize_t> &src_s, const std::vector<dsize_t> &dst_s, size_t cur_dim) {
|
||||
if (cur_dim == src->Rank() - 1) { // if this is the last dimension, copy the data
|
||||
uint8_t type_size = src->type().SizeInBytes();
|
||||
size_t len = std::min(src->shape()[cur_dim], dst->shape()[cur_dim]) * type_size;
|
||||
dsize_t src_flat_ind = 0, dst_flat_ind = 0;
|
||||
for (size_t i = 0; i < src->Rank(); i++) {
|
||||
src_flat_ind += src_s[i] * cur_ind[i];
|
||||
dst_flat_ind += dst_s[i] * cur_ind[i];
|
||||
}
|
||||
unsigned char *src_addr = src->GetMutableBuffer() + src_flat_ind * type_size;
|
||||
unsigned char *dst_addr = dst->GetMutableBuffer() + dst_flat_ind * type_size;
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(memcpy_s(dst_addr, len, src_addr, len) == 0, "memcpy error");
|
||||
} else { // not the last dimension, keep doing recursion
|
||||
dsize_t min_ind = std::min(dst->shape()[cur_dim], src->shape()[cur_dim]);
|
||||
for (dsize_t i = 0; i < min_ind; i++) {
|
||||
cur_ind[cur_dim] = i;
|
||||
RETURN_IF_NOT_OK(PadHelper(src, dst, cur_ind, src_s, dst_s, cur_dim + 1));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Visitor accept method for NodePass
|
||||
Status BatchOp::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
|
|
|
@ -38,6 +38,7 @@ class DataBuffer;
|
|||
|
||||
using TensorBatch = std::vector<std::shared_ptr<Tensor>>;
|
||||
using TensorBatchTable = std::vector<TensorBatch>;
|
||||
using PadInfo = std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>>;
|
||||
|
||||
class BatchOp : public ParallelOp {
|
||||
public:
|
||||
|
@ -66,7 +67,7 @@ class BatchOp : public ParallelOp {
|
|||
return *this;
|
||||
}
|
||||
|
||||
Builder &SetPaddingMap(const std::map<std::string, std::pair<TensorShape, float>> &pad_map, bool pad = true) {
|
||||
Builder &SetPaddingMap(const PadInfo &pad_map, bool pad = true) {
|
||||
builder_pad_ = pad;
|
||||
builder_pad_map_ = pad_map;
|
||||
return *this;
|
||||
|
@ -119,7 +120,7 @@ class BatchOp : public ParallelOp {
|
|||
int32_t builder_num_workers_;
|
||||
int32_t builder_op_connector_size_;
|
||||
std::vector<std::string> builder_cols_to_map_;
|
||||
std::map<std::string, std::pair<TensorShape, float>> builder_pad_map_;
|
||||
PadInfo builder_pad_map_;
|
||||
py::function builder_batch_size_func_;
|
||||
py::function builder_batch_map_func_;
|
||||
};
|
||||
|
@ -150,8 +151,7 @@ class BatchOp : public ParallelOp {
|
|||
// @param int32_t rows_per_buf
|
||||
// @param int32_t num_workers
|
||||
BatchOp(int32_t batch_size, bool drop, bool pad, int32_t op_queue_size, int32_t num_workers,
|
||||
const std::vector<std::string> &, py::function batch_size_func, py::function batch_map_func,
|
||||
std::map<std::string, std::pair<TensorShape, float>> pad_map);
|
||||
const std::vector<std::string> &, py::function batch_size_func, py::function batch_map_func, PadInfo pad_map);
|
||||
|
||||
// BatchOp destructor
|
||||
~BatchOp() {}
|
||||
|
@ -183,15 +183,6 @@ class BatchOp : public ParallelOp {
|
|||
// @return Status - The error code return
|
||||
Status operator()() override;
|
||||
|
||||
// Pad input tensor according pad_shape, need to have same rank.
|
||||
// @param std::shared_ptr<Tensor> src - tensor to pad from
|
||||
// @param std::shared_ptr<Tensor> *dst - return tensor padded
|
||||
// @param std::vector<dsize_t> pad_shape - shape to pad to
|
||||
// @param float pad_val - value to pad with
|
||||
// @return - The error code return
|
||||
Status PadTensor(std::shared_ptr<Tensor> src, std::shared_ptr<Tensor> *dst, const std::vector<dsize_t> &pad_shape,
|
||||
float pad_val);
|
||||
|
||||
// Base-class override for NodePass visitor acceptor.
|
||||
// @param p - Pointer to the NodePass to be accepted.
|
||||
// @param modified - Whether this node visit modified the pipeline.
|
||||
|
@ -199,18 +190,6 @@ class BatchOp : public ParallelOp {
|
|||
Status Accept(NodePass *p, bool *modified) override;
|
||||
|
||||
private:
|
||||
// recursive helper function. This function could be very expensive if called on a multi-dimensional tensor
|
||||
// it is only meant to be called by PadTensor.
|
||||
// @tparam T - type of tensor and fill value
|
||||
// @param std::shared_ptr<Tensor> src - Tensor to pad from
|
||||
// @param std::shared_ptr<Tensor>* dst - Tensor to pad to, return value
|
||||
// @param std::vector<dsize_t> cur_ind - recursion helper
|
||||
// @param T pad_val - value to pad tensor with
|
||||
// @param size_t cur_dim - recursion helper
|
||||
// @return Status - The error code return
|
||||
Status PadHelper(std::shared_ptr<Tensor> src, std::shared_ptr<Tensor> dst, std::vector<dsize_t> cur_ind,
|
||||
const std::vector<dsize_t> &src_s, const std::vector<dsize_t> &dst_s, size_t cur_dim = 0);
|
||||
|
||||
// Worker thread for doing the memcpy of batch
|
||||
// @param int32_t param workerId
|
||||
// @return Status - The error code return
|
||||
|
@ -225,23 +204,33 @@ class BatchOp : public ParallelOp {
|
|||
// @param const std::unique_ptr<TensorQTable> *src - table that has the rows for batching
|
||||
// @param const std::unique_ptr<TensorQTable> *dest - dest_table to hold batched rows
|
||||
// @param int32_t size - batch_size
|
||||
// @param const std::unordered_map<std::string, int32_t>& column_name_id_map - column names to index mapping
|
||||
// @return Status - The error code return
|
||||
Status BatchRows(const std::unique_ptr<TensorQTable> *src, const std::unique_ptr<TensorQTable> *dest, size_t size);
|
||||
static Status BatchRows(const std::unique_ptr<TensorQTable> *src, const std::unique_ptr<TensorQTable> *dest,
|
||||
dsize_t batch_size);
|
||||
|
||||
// Function that calls pyfunc to perform map on batch
|
||||
// @param (std::pair<std::unique_ptr<TensorQTable>, batch_stats> *table_pair - contains un-batched tensor
|
||||
// @return Status - The error code return
|
||||
Status MapColumns(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> *table_pair);
|
||||
|
||||
// @param const PadInfo &pad_info pad info to unpack
|
||||
// @param const std::unordered_map<std::string, int32_t>& column_name_id_map - column names to index mapping
|
||||
// @param std::set<int32_t> *cols, col ids to perform pad on
|
||||
// @param std::vector<float> *vals, default padding value for each column
|
||||
// @param std::vector<std::vector<dsize_t>> *shapes, padding shape specified by user
|
||||
// @return Status - The error code return
|
||||
Status UnpackPadInfo(std::set<int32_t> *cols, std::vector<float> *vals, std::vector<std::vector<dsize_t>> *shapes);
|
||||
static Status UnpackPadInfo(const PadInfo &pad_info,
|
||||
const std::unordered_map<std::string, int32_t> &column_name_id_map,
|
||||
std::set<int32_t> *pad_cols, std::vector<std::shared_ptr<Tensor>> *pad_vals,
|
||||
std::vector<std::vector<dsize_t>> *pad_shapes);
|
||||
|
||||
// @param table_pair
|
||||
// @param table
|
||||
// @param const PadInfo &pad_info pad info
|
||||
// @param const std::unordered_map<std::string, int32_t>& column_name_id_map - column names to index mapping
|
||||
// @return Status - The error code return
|
||||
Status PadColumns(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> *table_pair);
|
||||
static Status PadColumns(std::unique_ptr<TensorQTable> *table, const PadInfo &pad_info,
|
||||
const std::unordered_map<std::string, int32_t> &column_name_id_map);
|
||||
|
||||
// the number of thread pulling from the mOutConnector of the Op below
|
||||
// @return int32_t, 1
|
||||
|
@ -264,11 +253,11 @@ class BatchOp : public ParallelOp {
|
|||
Status InvokeBatchMapFunc(TensorTable *input, TensorTable *output, CBatchInfo info);
|
||||
|
||||
int32_t start_batch_size_;
|
||||
bool drop_; // bool for whether to drop remainder or not
|
||||
bool pad_; // bool for whether to perform padding on tensor
|
||||
std::vector<std::string> pyfunc_column_names_; // Name of the columns to perform map op on
|
||||
std::map<std::string, std::pair<TensorShape, float>> pad_info_; // column names to perform padding on
|
||||
std::unique_ptr<ChildIterator> child_iterator_; // child iterator for fetching TensorRows 1 by 1
|
||||
bool drop_; // bool for whether to drop remainder or not
|
||||
bool pad_; // bool for whether to perform padding on tensor
|
||||
std::vector<std::string> pyfunc_column_names_; // Name of the columns to perform map op on
|
||||
PadInfo pad_info_; // column names to perform padding on
|
||||
std::unique_ptr<ChildIterator> child_iterator_; // child iterator for fetching TensorRows 1 by 1
|
||||
QueueList<std::pair<std::unique_ptr<TensorQTable>, CBatchInfo>> worker_queues_; // internal queue for syncing worker
|
||||
py::function batch_size_func_; // Function pointer of batch size function
|
||||
py::function batch_map_func_; // Function pointer of per batch map function
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
*/
|
||||
|
||||
#include "dataset/kernels/data/data_utils.h"
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "dataset/core/constants.h"
|
||||
#include "dataset/core/tensor.h"
|
||||
|
@ -220,5 +222,125 @@ Status ToFloat16(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *
|
|||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PadEnd(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst, const std::vector<dsize_t> &pad_shape,
|
||||
const std::shared_ptr<Tensor> &pad_val) {
|
||||
if (pad_val == nullptr) {
|
||||
if (src->type().IsNumeric()) {
|
||||
return PadEndNumeric(src, dst, pad_shape, 0);
|
||||
} else {
|
||||
return PadEndString(src, dst, pad_shape, "");
|
||||
}
|
||||
}
|
||||
if (pad_val->type().IsNumeric()) {
|
||||
float val = 0;
|
||||
RETURN_IF_NOT_OK(pad_val->GetItemAt<float>(&val, {}));
|
||||
return PadEndNumeric(src, dst, pad_shape, val);
|
||||
}
|
||||
std::string_view val;
|
||||
RETURN_IF_NOT_OK(pad_val->GetItemAt(&val, {}));
|
||||
return PadEndString(src, dst, pad_shape, std::string(val));
|
||||
}
|
||||
|
||||
Status PadEndNumeric(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst,
|
||||
const std::vector<dsize_t> &pad_shape, float pad_val) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(src != nullptr && dst != nullptr, "tensor can't be nullptr");
|
||||
if (src->Rank() == 0 || src->shape().AsVector() == pad_shape) {
|
||||
(*dst) = src; // if no padding, copy the pointer
|
||||
} else {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(src->Rank() == pad_shape.size(), "Pad to diff rank not allowed");
|
||||
RETURN_IF_NOT_OK(Tensor::CreateTensor(dst, TensorImpl::kFlexible, TensorShape(pad_shape), src->type()));
|
||||
auto tensor_type = src->type().value();
|
||||
if (pad_val == 0) { // if pad with zero, don't care what type it is
|
||||
RETURN_IF_NOT_OK((*dst)->Zero());
|
||||
} else if (tensor_type == DataType::DE_INT8) {
|
||||
RETURN_IF_NOT_OK((*dst)->Fill<int8_t>(pad_val));
|
||||
} else if (tensor_type == DataType::DE_BOOL) {
|
||||
RETURN_IF_NOT_OK((*dst)->Fill<bool>(pad_val));
|
||||
} else if (tensor_type == DataType::DE_UINT8) {
|
||||
RETURN_IF_NOT_OK((*dst)->Fill<uint8_t>(pad_val));
|
||||
} else if (tensor_type == DataType::DE_INT16) {
|
||||
RETURN_IF_NOT_OK((*dst)->Fill<int16_t>(pad_val));
|
||||
} else if (tensor_type == DataType::DE_FLOAT16) {
|
||||
RETURN_IF_NOT_OK((*dst)->Fill<float16>(static_cast<float16>(pad_val)));
|
||||
} else if (tensor_type == DataType::DE_UINT16) {
|
||||
RETURN_IF_NOT_OK((*dst)->Fill<uint16_t>(pad_val));
|
||||
} else if (tensor_type == DataType::DE_INT32) {
|
||||
RETURN_IF_NOT_OK((*dst)->Fill<int32_t>(pad_val));
|
||||
} else if (tensor_type == DataType::DE_UINT32) {
|
||||
RETURN_IF_NOT_OK((*dst)->Fill<uint32_t>(pad_val));
|
||||
} else if (tensor_type == DataType::DE_INT64) {
|
||||
RETURN_IF_NOT_OK((*dst)->Fill<int64_t>(pad_val));
|
||||
} else if (tensor_type == DataType::DE_UINT64) {
|
||||
RETURN_IF_NOT_OK((*dst)->Fill<uint64_t>(pad_val));
|
||||
} else if (tensor_type == DataType::DE_FLOAT32) {
|
||||
RETURN_IF_NOT_OK((*dst)->Fill<float>(pad_val));
|
||||
} else if (tensor_type == DataType::DE_FLOAT64) {
|
||||
RETURN_IF_NOT_OK((*dst)->Fill<double>(pad_val));
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED("Incorrect/Unknown tensor type");
|
||||
}
|
||||
std::vector<dsize_t> cur_ind(src->Rank(), 0);
|
||||
RETURN_IF_NOT_OK(PadEndNumericHelper(src, *dst, cur_ind, 0));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
Status PadEndNumericHelper(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> dst,
|
||||
std::vector<dsize_t> cur_ind, size_t cur_dim) {
|
||||
if (cur_dim == src->Rank() - 1) { // if this is the last dimension, copy the data
|
||||
dst->CopyLastDimAt(src, cur_ind);
|
||||
} else { // not the last dimension, keep doing recursion
|
||||
dsize_t min_ind = std::min(dst->shape()[cur_dim], src->shape()[cur_dim]);
|
||||
for (dsize_t i = 0; i < min_ind; i++) {
|
||||
cur_ind[cur_dim] = i;
|
||||
RETURN_IF_NOT_OK(PadEndNumericHelper(src, dst, cur_ind, cur_dim + 1));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PadEndString(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst,
|
||||
const std::vector<dsize_t> &pad_shape, const std::string &pad_val) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(src != nullptr && dst != nullptr, "tensor can't be nullptr");
|
||||
if (src->Rank() == 0 || src->shape().AsVector() == pad_shape) {
|
||||
(*dst) = src; // if no padding, copy the pointer
|
||||
} else {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(src->Rank() == pad_shape.size(), "Pad to diff rank not allowed");
|
||||
std::vector<dsize_t> cur_ind(src->Rank(), 0);
|
||||
std::vector<std::string> strings;
|
||||
RETURN_IF_NOT_OK(PadEndStringHelper(src, &strings, TensorShape(pad_shape), cur_ind, 0, pad_val));
|
||||
RETURN_IF_NOT_OK(Tensor::CreateTensor(dst, strings, TensorShape(pad_shape)));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PadEndStringHelper(const std::shared_ptr<Tensor> &src, std::vector<std::string> *dst,
|
||||
const TensorShape &dst_shape, std::vector<dsize_t> cur_ind, size_t cur_dim,
|
||||
const std::string &pad_value) {
|
||||
if (cur_dim == src->Rank() - 1) { // if this is the last dimension, copy the data
|
||||
dsize_t min_ind = std::min(dst_shape[cur_dim], src->shape()[cur_dim]);
|
||||
for (dsize_t i = 0; i < min_ind; i++) {
|
||||
cur_ind[cur_dim] = i;
|
||||
std::string_view item;
|
||||
RETURN_IF_NOT_OK(src->GetItemAt(&item, cur_ind));
|
||||
dst->emplace_back(item);
|
||||
}
|
||||
for (dsize_t i = min_ind; i < dst_shape[cur_dim]; i++) {
|
||||
dst->emplace_back(pad_value);
|
||||
}
|
||||
|
||||
} else { // not the last dimension, keep doing recursion
|
||||
dsize_t min_ind = std::min(dst_shape[cur_dim], src->shape()[cur_dim]);
|
||||
for (dsize_t i = 0; i < min_ind; i++) {
|
||||
cur_ind[cur_dim] = i;
|
||||
RETURN_IF_NOT_OK(PadEndStringHelper(src, dst, dst_shape, cur_ind, cur_dim + 1, pad_value));
|
||||
}
|
||||
dsize_t count = (dst_shape[cur_dim] - min_ind) * dst_shape.Strides()[cur_dim];
|
||||
for (dsize_t i = 0; i < count; i++) {
|
||||
dst->emplace_back(pad_value);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#define DATASET_KERNELS_DATA_DATA_UTILS_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "dataset/core/constants.h"
|
||||
#include "dataset/core/cv_tensor.h"
|
||||
|
@ -58,6 +59,59 @@ void Cast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output)
|
|||
Status ToFloat16(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
|
||||
|
||||
Status TypeCast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const DataType &data_type);
|
||||
|
||||
// Pad input tensor according pad_shape, need to have same rank.
|
||||
// Based on the type of the input tensor, PadEndNumeric/String will be called.
|
||||
// @param std::shared_ptr<Tensor> src - tensor to pad from
|
||||
// @param std::shared_ptr<Tensor> *dst - return tensor padded
|
||||
// @param std::vector<dsize_t> pad_shape - shape to pad to
|
||||
// @param std::shared_ptr<Tensor> pad_val - value to pad with in Tensor format,
|
||||
// @return - The error code return
|
||||
Status PadEnd(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst, const std::vector<dsize_t> &pad_shape,
|
||||
const std::shared_ptr<Tensor> &pad_val);
|
||||
|
||||
// Pad input numeric tensor according pad_shape, need to have same rank.
|
||||
// @param std::shared_ptr<Tensor> src - tensor to pad from
|
||||
// @param std::shared_ptr<Tensor> *dst - return tensor padded
|
||||
// @param std::vector<dsize_t> pad_shape - shape to pad to
|
||||
// @param float pad_val - value to pad with
|
||||
// @return - The error code return
|
||||
Status PadEndNumeric(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst,
|
||||
const std::vector<dsize_t> &pad_shape, float pad_val);
|
||||
|
||||
// recursive helper function for padding numric tensors. This function could be very expensive if called on a
|
||||
// multi-dimensional tensor it is only meant to be called by PadEndNumeric.
|
||||
// @tparam T - type of tensor and fill value
|
||||
// @param std::shared_ptr<Tensor> src - Tensor to pad from
|
||||
// @param std::shared_ptr<Tensor>* dst - Tensor to pad to, return value
|
||||
// @param std::vector<dsize_t> cur_ind - recursion helper
|
||||
// @param T pad_val - value to pad tensor with
|
||||
// @param size_t cur_dim - recursion helper
|
||||
// @return Status - The error code return
|
||||
Status PadEndNumericHelper(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> dst,
|
||||
std::vector<dsize_t> cur_ind, size_t cur_dim = 0);
|
||||
|
||||
// Pad input string tensor according pad_shape, need to have same rank.
|
||||
// @param std::shared_ptr<Tensor> src - tensor to pad from
|
||||
// @param std::shared_ptr<Tensor> *dst - return tensor padded
|
||||
// @param std::vector<dsize_t> pad_shape - shape to pad to
|
||||
// @param std::string pad_val - value to pad with
|
||||
// @return - The error code return
|
||||
Status PadEndString(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst,
|
||||
const std::vector<dsize_t> &pad_shape, const std::string &pad_val);
|
||||
|
||||
// recursive helper function for padding string tensors. This function could be very expensive if called on a
|
||||
// multi-dimensional tensor it is only meant to be called by PadEndNumeric.
|
||||
// @tparam T - type of tensor and fill value
|
||||
// @param std::shared_ptr<Tensor> src - Tensor to pad from
|
||||
// @param std::shared_ptr<Tensor>* dst - Tensor to pad to, return value
|
||||
// @param std::vector<dsize_t> cur_ind - recursion helper
|
||||
// @param std::string pad_val - value to pad tensor with
|
||||
// @param size_t cur_dim - recursion helper
|
||||
// @return Status - The error code return
|
||||
Status PadEndStringHelper(const std::shared_ptr<Tensor> &src, std::vector<std::string> *dst,
|
||||
const TensorShape &dst_shape, std::vector<dsize_t> cur_ind, size_t cur_dim,
|
||||
const std::string &pad_value);
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -19,10 +19,12 @@ import inspect as ins
|
|||
import os
|
||||
from functools import wraps
|
||||
from multiprocessing import cpu_count
|
||||
|
||||
import numpy as np
|
||||
from mindspore._c_expression import typing
|
||||
from . import samplers
|
||||
|
||||
from . import datasets
|
||||
from . import samplers
|
||||
|
||||
INT32_MAX = 2147483647
|
||||
valid_detype = [
|
||||
|
@ -683,7 +685,7 @@ def check_pad_info(key, val):
|
|||
check_type(dim, "dim in pad_shape", int)
|
||||
assert dim > 0, "pad shape should be positive integers"
|
||||
if val[1] is not None:
|
||||
check_type(val[1], "pad_value", (int, float))
|
||||
check_type(val[1], "pad_value", (int, float, str, bytes))
|
||||
|
||||
|
||||
def check_batch(method):
|
||||
|
|
|
@ -299,8 +299,11 @@ TEST_F(MindDataTestBatchOp, TestBatchDropTrueRepeat) {
|
|||
TEST_F(MindDataTestBatchOp, TestSimpleBatchPadding) {
|
||||
std::string schema_file = datasets_root_path_ + "/testBatchDataset";
|
||||
std::shared_ptr<BatchOp> op;
|
||||
std::map<std::string, std::pair<TensorShape, float>> m;
|
||||
m.insert({"col_1d", std::make_pair(TensorShape({4}), -1)});
|
||||
PadInfo m;
|
||||
std::shared_ptr<Tensor> pad_value;
|
||||
Tensor::CreateTensor(&pad_value, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_FLOAT32));
|
||||
pad_value->SetItemAt<float>({}, -1);
|
||||
m.insert({"col_1d", std::make_pair(TensorShape({4}), pad_value)});
|
||||
de::BatchOp::Builder(12).SetDrop(false).SetPaddingMap(m, true).Build(&op);
|
||||
auto tree = Build({Storage(schema_file), op});
|
||||
tree->Prepare();
|
||||
|
@ -308,9 +311,54 @@ TEST_F(MindDataTestBatchOp, TestSimpleBatchPadding) {
|
|||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "Return code error detected during tree launch: " << rc.ToString() << ".";
|
||||
} else {
|
||||
int64_t payload[] = {-9223372036854775807 - 1, 1, -1, -1, 2, 3, -1, -1, 4, 5, -1, -1, 6, 7, -1, -1,
|
||||
8, 9, -1, -1, 10, 11, -1, -1, 12, 13, -1, -1, 14, 15, -1, -1,
|
||||
16, 17, -1, -1, 18, 19, -1, -1, 20, 21, -1, -1, 22, 23, -1, -1};
|
||||
int64_t payload[] = {-9223372036854775807 - 1,
|
||||
1,
|
||||
-1,
|
||||
-1,
|
||||
2,
|
||||
3,
|
||||
-1,
|
||||
-1,
|
||||
4,
|
||||
5,
|
||||
-1,
|
||||
-1,
|
||||
6,
|
||||
7,
|
||||
-1,
|
||||
-1,
|
||||
8,
|
||||
9,
|
||||
-1,
|
||||
-1,
|
||||
10,
|
||||
11,
|
||||
-1,
|
||||
-1,
|
||||
12,
|
||||
13,
|
||||
-1,
|
||||
-1,
|
||||
14,
|
||||
15,
|
||||
-1,
|
||||
-1,
|
||||
16,
|
||||
17,
|
||||
-1,
|
||||
-1,
|
||||
18,
|
||||
19,
|
||||
-1,
|
||||
-1,
|
||||
20,
|
||||
21,
|
||||
-1,
|
||||
-1,
|
||||
22,
|
||||
23,
|
||||
-1,
|
||||
-1};
|
||||
std::shared_ptr<de::Tensor> t;
|
||||
rc = de::Tensor::CreateTensor(&t, TensorImpl::kFlexible, de::TensorShape({12, 4}), de::DataType(DataType::DE_INT64),
|
||||
(unsigned char *)payload);
|
||||
|
|
|
@ -12,15 +12,14 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import pytest
|
||||
import mindspore._c_dataengine as cde
|
||||
import numpy as np
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore._c_dataengine as cde
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.dataset.text import to_str
|
||||
import mindspore.dataset as ds
|
||||
from mindspore.dataset.text import to_str, to_bytes
|
||||
|
||||
|
||||
# pylint: disable=comparison-with-itself
|
||||
def test_basic():
|
||||
x = np.array([["ab", "cde", "121"], ["x", "km", "789"]], dtype='S')
|
||||
n = cde.Tensor(x)
|
||||
|
@ -28,8 +27,8 @@ def test_basic():
|
|||
np.testing.assert_array_equal(x, arr)
|
||||
|
||||
|
||||
def compare(strings):
|
||||
arr = np.array(strings, dtype='S')
|
||||
def compare(strings, dtype='S'):
|
||||
arr = np.array(strings, dtype=dtype)
|
||||
|
||||
def gen():
|
||||
(yield arr,)
|
||||
|
@ -37,25 +36,51 @@ def compare(strings):
|
|||
data = ds.GeneratorDataset(gen, column_names=["col"])
|
||||
|
||||
for d in data:
|
||||
np.testing.assert_array_equal(d[0], arr)
|
||||
np.testing.assert_array_equal(d[0], arr.astype('S'))
|
||||
|
||||
|
||||
def test_generator():
|
||||
compare(["ab"])
|
||||
compare(["", ""])
|
||||
compare([""])
|
||||
compare(["ab", ""])
|
||||
compare(["ab", "cde", "121"])
|
||||
compare([["ab", "cde", "121"], ["x", "km", "789"]])
|
||||
compare([["ab", "", "121"], ["", "km", "789"]])
|
||||
compare(["ab"], dtype='U')
|
||||
compare(["", ""], dtype='U')
|
||||
compare([""], dtype='U')
|
||||
compare(["ab", ""], dtype='U')
|
||||
compare(["", ""], dtype='U')
|
||||
compare(["", "ab"], dtype='U')
|
||||
compare(["ab", "cde", "121"], dtype='U')
|
||||
compare([["ab", "cde", "121"], ["x", "km", "789"]], dtype='U')
|
||||
compare([["ab", "", "121"], ["", "km", "789"]], dtype='U')
|
||||
|
||||
|
||||
line = np.array(["This is a text file.",
|
||||
"Be happy every day.",
|
||||
"Good luck to everyone."])
|
||||
|
||||
words = np.array([["This", "text", "file", "a"],
|
||||
["Be", "happy", "day", "b"],
|
||||
["女", "", "everyone", "c"]])
|
||||
|
||||
chinese = np.array(["今天天气太好了我们一起去外面玩吧",
|
||||
"男默女泪",
|
||||
"江州市长江大桥参加了长江大桥的通车仪式"])
|
||||
|
||||
|
||||
def test_batching_strings():
|
||||
def gen():
|
||||
yield (np.array(["ab", "cde", "121"], dtype='S'),)
|
||||
for row in chinese:
|
||||
yield (np.array(row),)
|
||||
|
||||
data = ds.GeneratorDataset(gen, column_names=["col"]).batch(10)
|
||||
data = ds.GeneratorDataset(gen, column_names=["col"])
|
||||
data = data.batch(2, drop_remainder=True)
|
||||
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
for _ in data:
|
||||
pass
|
||||
assert "[Batch ERROR] Batch does not support" in str(info.value)
|
||||
for d in data:
|
||||
np.testing.assert_array_equal(d[0], to_bytes(chinese[0:2]))
|
||||
|
||||
|
||||
def test_map():
|
||||
|
@ -67,7 +92,7 @@ def test_map():
|
|||
def split(b):
|
||||
s = to_str(b)
|
||||
splits = s.item().split()
|
||||
return np.array(splits, dtype='S')
|
||||
return np.array(splits)
|
||||
|
||||
data = data.map(input_columns=["col"], operations=split)
|
||||
expected = np.array(["ab", "cde", "121"], dtype='S')
|
||||
|
@ -91,19 +116,6 @@ def test_map2():
|
|||
np.testing.assert_array_equal(d[0], expected)
|
||||
|
||||
|
||||
line = np.array(["This is a text file.",
|
||||
"Be happy every day.",
|
||||
"Good luck to everyone."])
|
||||
|
||||
words = np.array([["This", "text", "file", "a"],
|
||||
["Be", "happy", "day", "b"],
|
||||
["女", "", "everyone", "c"]])
|
||||
|
||||
chinese = np.array(["今天天气太好了我们一起去外面玩吧",
|
||||
"男默女泪",
|
||||
"江州市长江大桥参加了长江大桥的通车仪式"])
|
||||
|
||||
|
||||
def test_tfrecord1():
|
||||
s = ds.Schema()
|
||||
s.add_column("line", "string", [])
|
||||
|
@ -181,6 +193,94 @@ def test_mindrecord():
|
|||
np.testing.assert_array_equal(chinese[i], to_str(d["chinese"]))
|
||||
|
||||
|
||||
# The following tests cases were copied from test_pad_batch but changed to strings instead
|
||||
|
||||
|
||||
# this generator function yield two columns
|
||||
# col1d: [0],[1], [2], [3]
|
||||
# col2d: [[100],[200]], [[101],[201]], [102],[202]], [103],[203]]
|
||||
def gen_2cols(num):
|
||||
for i in range(num):
|
||||
yield (np.array([str(i)]), np.array([[str(i + 100)], [str(i + 200)]]))
|
||||
|
||||
|
||||
# this generator function yield one column of variable shapes
|
||||
# col: [0], [0,1], [0,1,2], [0,1,2,3]
|
||||
def gen_var_col(num):
|
||||
for i in range(num):
|
||||
yield (np.array([str(j) for j in range(i + 1)]),)
|
||||
|
||||
|
||||
# this generator function yield two columns of variable shapes
|
||||
# col1: [0], [0,1], [0,1,2], [0,1,2,3]
|
||||
# col2: [100], [100,101], [100,101,102], [100,110,102,103]
|
||||
def gen_var_cols(num):
|
||||
for i in range(num):
|
||||
yield (np.array([str(j) for j in range(i + 1)]), np.array([str(100 + j) for j in range(i + 1)]))
|
||||
|
||||
|
||||
# this generator function yield two columns of variable shapes
|
||||
# col1: [[0]], [[0,1]], [[0,1,2]], [[0,1,2,3]]
|
||||
# col2: [[100]], [[100,101]], [[100,101,102]], [[100,110,102,103]]
|
||||
def gen_var_cols_2d(num):
|
||||
for i in range(num):
|
||||
yield (np.array([[str(j) for j in range(i + 1)]]), np.array([[str(100 + j) for j in range(i + 1)]]))
|
||||
|
||||
|
||||
def test_batch_padding_01():
|
||||
data1 = ds.GeneratorDataset((lambda: gen_2cols(2)), ["col1d", "col2d"])
|
||||
data1 = data1.batch(batch_size=2, drop_remainder=False, pad_info={"col2d": ([2, 2], b"-2"), "col1d": ([2], b"-1")})
|
||||
data1 = data1.repeat(2)
|
||||
for data in data1.create_dict_iterator():
|
||||
np.testing.assert_array_equal([[b"0", b"-1"], [b"1", b"-1"]], data["col1d"])
|
||||
np.testing.assert_array_equal([[[b"100", b"-2"], [b"200", b"-2"]], [[b"101", b"-2"], [b"201", b"-2"]]],
|
||||
data["col2d"])
|
||||
|
||||
|
||||
def test_batch_padding_02():
|
||||
data1 = ds.GeneratorDataset((lambda: gen_2cols(2)), ["col1d", "col2d"])
|
||||
data1 = data1.batch(batch_size=2, drop_remainder=False, pad_info={"col2d": ([1, 2], "")})
|
||||
data1 = data1.repeat(2)
|
||||
for data in data1.create_dict_iterator():
|
||||
np.testing.assert_array_equal([[b"0"], [b"1"]], data["col1d"])
|
||||
np.testing.assert_array_equal([[[b"100", b""]], [[b"101", b""]]], data["col2d"])
|
||||
|
||||
|
||||
def test_batch_padding_03():
|
||||
data1 = ds.GeneratorDataset((lambda: gen_var_col(4)), ["col"])
|
||||
data1 = data1.batch(batch_size=2, drop_remainder=False, pad_info={"col": (None, "PAD_VALUE")}) # pad automatically
|
||||
data1 = data1.repeat(2)
|
||||
res = dict()
|
||||
for ind, data in enumerate(data1.create_dict_iterator()):
|
||||
res[ind] = data["col"].copy()
|
||||
np.testing.assert_array_equal(res[0], [[b"0", b"PAD_VALUE"], [0, 1]])
|
||||
np.testing.assert_array_equal(res[1], [[b"0", b"1", b"2", b"PAD_VALUE"], [b"0", b"1", b"2", b"3"]])
|
||||
np.testing.assert_array_equal(res[2], [[b"0", b"PAD_VALUE"], [b"0", b"1"]])
|
||||
np.testing.assert_array_equal(res[3], [[b"0", b"1", b"2", b"PAD_VALUE"], [b"0", b"1", b"2", b"3"]])
|
||||
|
||||
|
||||
def test_batch_padding_04():
|
||||
data1 = ds.GeneratorDataset((lambda: gen_var_cols(2)), ["col1", "col2"])
|
||||
data1 = data1.batch(batch_size=2, drop_remainder=False, pad_info={}) # pad automatically
|
||||
data1 = data1.repeat(2)
|
||||
for data in data1.create_dict_iterator():
|
||||
np.testing.assert_array_equal(data["col1"], [[b"0", b""], [b"0", b"1"]])
|
||||
np.testing.assert_array_equal(data["col2"], [[b"100", b""], [b"100", b"101"]])
|
||||
|
||||
|
||||
def test_batch_padding_05():
|
||||
data1 = ds.GeneratorDataset((lambda: gen_var_cols_2d(3)), ["col1", "col2"])
|
||||
data1 = data1.batch(batch_size=3, drop_remainder=False,
|
||||
pad_info={"col2": ([2, None], "-2"), "col1": (None, "-1")}) # pad automatically
|
||||
for data in data1.create_dict_iterator():
|
||||
np.testing.assert_array_equal(data["col1"],
|
||||
[[[b"0", b"-1", b"-1"]], [[b"0", b"1", b"-1"]], [[b"0", b"1", b"2"]]])
|
||||
np.testing.assert_array_equal(data["col2"],
|
||||
[[[b"100", b"-2", b"-2"], [b"-2", b"-2", b"-2"]],
|
||||
[[b"100", b"101", b"-2"], [b"-2", b"-2", b"-2"]],
|
||||
[[b"100", b"101", b"102"], [b"-2", b"-2", b"-2"]]])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_generator()
|
||||
test_basic()
|
||||
|
@ -191,3 +291,8 @@ if __name__ == '__main__':
|
|||
test_tfrecord2()
|
||||
test_tfrecord3()
|
||||
test_mindrecord()
|
||||
test_batch_padding_01()
|
||||
test_batch_padding_02()
|
||||
test_batch_padding_03()
|
||||
test_batch_padding_04()
|
||||
test_batch_padding_05()
|
||||
|
|
Loading…
Reference in New Issue