forked from mindspore-Ecosystem/mindspore
batch with padding implemented
support for 1 specific dimension to be None, added validator fix various CI complains another round of CI fixes ci refactor parts of the code code refactor ci fix comments added, fix bugs address review comments address review comments review cmts added simple perf test script update pad code perf imprv
This commit is contained in:
parent
16930c562d
commit
c2d364a573
|
@ -207,6 +207,8 @@ int DEPipeline::GetBatchSize() const { return batch_size_; }
|
|||
|
||||
int DEPipeline::GetRepeatCount() const { return repeat_num_; }
|
||||
|
||||
float ToFloat(const py::handle &handle) { return py::reinterpret_borrow<py::float_>(handle); }
|
||||
|
||||
int ToInt(const py::handle &handle) { return py::reinterpret_borrow<py::int_>(handle); }
|
||||
|
||||
bool ToBool(const py::handle &handle) { return py::reinterpret_borrow<py::bool_>(handle); }
|
||||
|
@ -621,6 +623,21 @@ Status DEPipeline::ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp>
|
|||
if (key == "input_columns") {
|
||||
(void)builder->SetColumnsToMap(ToStringVector(value));
|
||||
}
|
||||
if (key == "pad_info") {
|
||||
std::map<std::string, std::pair<TensorShape, float>> pad_info;
|
||||
for (auto p : py::reinterpret_borrow<py::dict>(value)) {
|
||||
if (!p.second.is_none()) {
|
||||
py::tuple tp = py::reinterpret_borrow<py::tuple>(p.second);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(tp.size() == 2, "tuple in pad_info must be (list,int) or (list,float)");
|
||||
TensorShape shape = tp[0].is_none() ? TensorShape::CreateUnknownRankShape() : TensorShape(tp[0]);
|
||||
float pad_val = tp[1].is_none() ? 0 : ToFloat(tp[1]);
|
||||
(void)pad_info.insert({ToString(p.first), {shape, pad_val}});
|
||||
} else { // tuple is None
|
||||
(void)pad_info.insert({ToString(p.first), {TensorShape({}), 0}});
|
||||
}
|
||||
}
|
||||
(void)builder->SetPaddingMap(pad_info, true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -93,10 +93,10 @@ class Tensor {
|
|||
|
||||
// Copy raw data of a array based on shape and strides to the destination pointer
|
||||
// @param dst Pointer to the destination array where the content is to be copied
|
||||
// @param src Pointer to the source of stided array to be copied
|
||||
// @param src Pointer to the source of strided array to be copied
|
||||
// @param shape - shape of the source array
|
||||
// @param strides - strides of the source array
|
||||
// @param type_size - number of bytes needed to store one array elment's type
|
||||
// @param type_size - number of bytes needed to store one array element's type
|
||||
// @return Status Code
|
||||
static Status CopyStridedArray(unsigned char *dst, unsigned char *src, std::vector<dsize_t> shape,
|
||||
std::vector<dsize_t> strides, uint8_t type_size);
|
||||
|
@ -138,10 +138,10 @@ class Tensor {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// fill tensor with Zeros
|
||||
Status Zero() {
|
||||
dsize_t size = SizeInBytes();
|
||||
int retCode = memset_sp(StartAddr(), size, 0, size);
|
||||
if (retCode != 0) return Status(StatusCode::kUnexpectedError, "Failed to fill tensor with zeroes.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(memset_sp(StartAddr(), size, 0, size) == 0, "Failed to fill tensor with zeroes.");
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -154,10 +154,7 @@ class Tensor {
|
|||
int64_t cellSize = type_.SizeInBytes();
|
||||
if ((data_ != nullptr) && type_.IsCompatible<T>()) {
|
||||
for (dsize_t i = 0; i < Size(); i++) {
|
||||
int retCode = memcpy_s((data_ + i * cellSize), cellSize, &value, cellSize);
|
||||
if (retCode != 0) {
|
||||
return Status(StatusCode::kUnexpectedError, "Failed to fill tensor.");
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(memcpy_s((data_ + i * cellSize), cellSize, &value, cellSize) == 0, "memcpy err");
|
||||
}
|
||||
return Status::OK();
|
||||
} else {
|
||||
|
|
|
@ -87,8 +87,12 @@ TensorShape::TensorShape(const TensorShape &shape) : raw_shape_(*GlobalContext::
|
|||
|
||||
TensorShape::TensorShape(py::list l) : raw_shape_(*GlobalContext::Instance()->int_allocator()) {
|
||||
std::vector<dsize_t> list_c;
|
||||
for (auto i : l) {
|
||||
list_c.push_back(i.cast<int>());
|
||||
for (auto &i : l) {
|
||||
if (!i.is_none()) {
|
||||
list_c.push_back(i.cast<int>());
|
||||
} else {
|
||||
list_c.push_back(TensorShape::kDimUnknown);
|
||||
}
|
||||
}
|
||||
AddListToShape(list_c);
|
||||
}
|
||||
|
|
|
@ -65,6 +65,10 @@ class TensorShape {
|
|||
// @param shape
|
||||
TensorShape(const TensorShape &shape);
|
||||
|
||||
// construct a TensorShape via a python list
|
||||
// @param py::list l - a list object from python
|
||||
explicit TensorShape(py::list l);
|
||||
|
||||
~TensorShape() = default;
|
||||
|
||||
// Create a scalar Shape (i.e., empty shape with mKnown = true)
|
||||
|
@ -142,8 +146,6 @@ class TensorShape {
|
|||
return out;
|
||||
}
|
||||
|
||||
explicit TensorShape(py::list l);
|
||||
|
||||
py::list AsPyList();
|
||||
|
||||
// Checks if the given index is a valid index for this tensor.
|
||||
|
|
|
@ -14,15 +14,20 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "dataset/engine/datasetops/batch_op.h"
|
||||
|
||||
#include <utility>
|
||||
#include <iomanip>
|
||||
|
||||
#include "common/utils.h"
|
||||
#include "dataset/core/pybind_support.h"
|
||||
#include "dataset/engine/data_buffer.h"
|
||||
#include "dataset/engine/db_connector.h"
|
||||
|
||||
using float16 = Eigen::half;
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
BatchOp::Builder::Builder(int32_t batch_size) : builder_drop_(false) {
|
||||
BatchOp::Builder::Builder(int32_t batch_size) : builder_drop_(false), builder_pad_(false), builder_pad_map_({}) {
|
||||
builder_batch_size_ = batch_size;
|
||||
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
||||
builder_num_workers_ = cfg->num_parallel_workers();
|
||||
|
@ -31,8 +36,9 @@ BatchOp::Builder::Builder(int32_t batch_size) : builder_drop_(false) {
|
|||
|
||||
Status BatchOp::Builder::Build(std::shared_ptr<BatchOp> *ptr) {
|
||||
RETURN_IF_NOT_OK(SanityCheck());
|
||||
*ptr = std::make_shared<BatchOp>(builder_batch_size_, builder_drop_, builder_op_connector_size_, builder_num_workers_,
|
||||
builder_cols_to_map_, builder_batch_size_func_, builder_batch_map_func_);
|
||||
*ptr = std::make_shared<BatchOp>(builder_batch_size_, builder_drop_, builder_pad_, builder_op_connector_size_,
|
||||
builder_num_workers_, builder_cols_to_map_, builder_batch_size_func_,
|
||||
builder_batch_map_func_, builder_pad_map_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -44,14 +50,17 @@ Status BatchOp::Builder::SanityCheck() {
|
|||
return err.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, common::SafeCStr(err));
|
||||
}
|
||||
|
||||
BatchOp::BatchOp(int32_t batch_size, bool drop, int32_t op_queue_size, int32_t num_workers,
|
||||
const std::vector<std::string> &cols_to_map, py::function batch_size_func, py::function batch_map_func)
|
||||
BatchOp::BatchOp(int32_t batch_size, bool drop, bool pad, int32_t op_queue_size, int32_t num_workers,
|
||||
const std::vector<std::string> &cols_to_map, py::function batch_size_func, py::function batch_map_func,
|
||||
std::map<std::string, std::pair<TensorShape, float>> pad_map)
|
||||
: ParallelOp(num_workers, op_queue_size),
|
||||
start_batch_size_(batch_size),
|
||||
drop_(drop),
|
||||
input_column_names_(cols_to_map),
|
||||
pad_(pad),
|
||||
pyfunc_column_names_(cols_to_map),
|
||||
batch_size_func_(batch_size_func),
|
||||
batch_map_func_(batch_map_func) {
|
||||
batch_map_func_(batch_map_func),
|
||||
pad_info_(pad_map) {
|
||||
worker_queues_.Init(num_workers, op_queue_size);
|
||||
}
|
||||
|
||||
|
@ -181,7 +190,8 @@ Status BatchOp::WorkerEntry(int32_t workerId) {
|
|||
Status BatchOp::MakeBatchedBuffer(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> table_pair,
|
||||
std::unique_ptr<DataBuffer> *db) {
|
||||
RETURN_UNEXPECTED_IF_NULL(table_pair.first);
|
||||
if (!input_column_names_.empty()) RETURN_IF_NOT_OK(MapColumns(&table_pair)); // pass it through pyfunc
|
||||
if (!pyfunc_column_names_.empty()) RETURN_IF_NOT_OK(MapColumns(&table_pair)); // pass it through pyfunc
|
||||
if (pad_) RETURN_IF_NOT_OK(PadColumns(&table_pair)); // do padding if needed
|
||||
(*db) = std::make_unique<DataBuffer>(table_pair.second.batch_num_, DataBuffer::kDeBFlagNone);
|
||||
std::unique_ptr<TensorQTable> dest_table = std::make_unique<TensorQTable>();
|
||||
RETURN_IF_NOT_OK(BatchRows(&table_pair.first, &dest_table, table_pair.first->size()));
|
||||
|
@ -206,8 +216,8 @@ Status BatchOp::EoeReceived(int32_t) {
|
|||
|
||||
Status BatchOp::MapColumns(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> *table_pair) {
|
||||
TensorBatchTable input_table;
|
||||
input_table.reserve(input_column_names_.size());
|
||||
for (std::string col_name : input_column_names_) {
|
||||
input_table.reserve(pyfunc_column_names_.size());
|
||||
for (std::string col_name : pyfunc_column_names_) {
|
||||
if (column_name_map_.find(col_name) == column_name_map_.end()) {
|
||||
RETURN_STATUS_UNEXPECTED("column : '" + col_name + "' does not exist\n");
|
||||
}
|
||||
|
@ -225,8 +235,8 @@ Status BatchOp::MapColumns(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo>
|
|||
RETURN_IF_NOT_OK(InvokeBatchMapFunc(&input_table, &output_table, table_pair->second));
|
||||
|
||||
// Write back to TensorQTable
|
||||
for (size_t input_idx = 0; input_idx < input_column_names_.size(); input_idx++) {
|
||||
size_t col_idx = static_cast<size_t>(column_name_map_[input_column_names_[input_idx]]);
|
||||
for (size_t input_idx = 0; input_idx < pyfunc_column_names_.size(); input_idx++) {
|
||||
size_t col_idx = static_cast<size_t>(column_name_map_[pyfunc_column_names_[input_idx]]);
|
||||
size_t row_id = 0;
|
||||
for (TensorRow &row : *(table_pair->first)) {
|
||||
row[col_idx] = std::move(output_table[input_idx][row_id++]);
|
||||
|
@ -290,8 +300,8 @@ Status BatchOp::InvokeBatchMapFunc(TensorBatchTable *input, TensorBatchTable *ou
|
|||
py::object ret_py_obj = batch_map_func_(*input_args);
|
||||
// Parse batch map return value
|
||||
py::tuple ret_tuple = py::cast<py::tuple>(ret_py_obj);
|
||||
if (ret_tuple.size() != input_column_names_.size() || !py::isinstance<py::tuple>(ret_tuple)) {
|
||||
return Status(StatusCode::kPyFuncException, "Batch map function should return an tuple if size(input_columns)");
|
||||
if (ret_tuple.size() != pyfunc_column_names_.size() || !py::isinstance<py::tuple>(ret_tuple)) {
|
||||
return Status(StatusCode::kPyFuncException, "Batch map function should return a tuple");
|
||||
}
|
||||
for (size_t i = 0; i < ret_tuple.size(); i++) {
|
||||
TensorBatch output_batch;
|
||||
|
@ -311,5 +321,142 @@ Status BatchOp::InvokeBatchMapFunc(TensorBatchTable *input, TensorBatchTable *ou
|
|||
}
|
||||
return Status(StatusCode::kOK);
|
||||
}
|
||||
|
||||
Status BatchOp::PadTensor(std::shared_ptr<Tensor> src, std::shared_ptr<Tensor> *dst,
|
||||
const std::vector<dsize_t> &pad_shape, float pad_val) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(src != nullptr && dst != nullptr, "tensor can't be nullptr");
|
||||
if (src->Rank() == 0 || src->shape().AsVector() == pad_shape) {
|
||||
(*dst) = src; // if no padding, copy the pointer
|
||||
} else {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(src->Rank() == pad_shape.size(), "Pad to diff rank not allowed");
|
||||
RETURN_IF_NOT_OK(Tensor::CreateTensor(dst, TensorImpl::kFlexible, TensorShape(pad_shape), src->type()));
|
||||
auto tensor_type = src->type().value();
|
||||
if (pad_val == 0) { // if pad with zero, don't care what type it is
|
||||
RETURN_IF_NOT_OK((*dst)->Zero());
|
||||
} else if (tensor_type == DataType::DE_INT8) {
|
||||
RETURN_IF_NOT_OK((*dst)->Fill<int8_t>(pad_val));
|
||||
} else if (tensor_type == DataType::DE_BOOL) {
|
||||
RETURN_IF_NOT_OK((*dst)->Fill<bool>(pad_val));
|
||||
} else if (tensor_type == DataType::DE_UINT8) {
|
||||
RETURN_IF_NOT_OK((*dst)->Fill<uint8_t>(pad_val));
|
||||
} else if (tensor_type == DataType::DE_INT16) {
|
||||
RETURN_IF_NOT_OK((*dst)->Fill<int16_t>(pad_val));
|
||||
} else if (tensor_type == DataType::DE_FLOAT16) {
|
||||
RETURN_IF_NOT_OK((*dst)->Fill<float16>(static_cast<float16>(pad_val)));
|
||||
} else if (tensor_type == DataType::DE_UINT16) {
|
||||
RETURN_IF_NOT_OK((*dst)->Fill<uint16_t>(pad_val));
|
||||
} else if (tensor_type == DataType::DE_INT32) {
|
||||
RETURN_IF_NOT_OK((*dst)->Fill<int32_t>(pad_val));
|
||||
} else if (tensor_type == DataType::DE_UINT32) {
|
||||
RETURN_IF_NOT_OK((*dst)->Fill<uint32_t>(pad_val));
|
||||
} else if (tensor_type == DataType::DE_INT64) {
|
||||
RETURN_IF_NOT_OK((*dst)->Fill<int64_t>(pad_val));
|
||||
} else if (tensor_type == DataType::DE_UINT64) {
|
||||
RETURN_IF_NOT_OK((*dst)->Fill<uint64_t>(pad_val));
|
||||
} else if (tensor_type == DataType::DE_FLOAT32) {
|
||||
RETURN_IF_NOT_OK((*dst)->Fill<float>(pad_val));
|
||||
} else if (tensor_type == DataType::DE_FLOAT64) {
|
||||
RETURN_IF_NOT_OK((*dst)->Fill<double>(pad_val));
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED("Incorrect/Unknown tensor type");
|
||||
}
|
||||
std::vector<dsize_t> cur_ind(src->Rank(), 0), src_s(src->Rank(), 1), dst_s(src->Rank(), 1);
|
||||
for (dsize_t i = src->Rank() - 2; i >= 0; i--) {
|
||||
src_s[i] = src->shape()[i + 1] * src_s[i + 1];
|
||||
dst_s[i] = pad_shape[i + 1] * dst_s[i + 1];
|
||||
}
|
||||
RETURN_IF_NOT_OK(PadHelper(src, *dst, cur_ind, src_s, dst_s, 0));
|
||||
}
|
||||
return Status::OK();
|
||||
} // namespace dataset
|
||||
|
||||
Status BatchOp::PadColumns(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> *table_pair) {
|
||||
RETURN_UNEXPECTED_IF_NULL(table_pair); // placeholder for now, might need this in the future
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(table_pair->first->front().size() == column_name_map_.size(), "col_name_map mismatch");
|
||||
std::vector<float> pad_vals(column_name_map_.size(), 0); // value to pad each column's tensor with, default 0
|
||||
std::set<int32_t> pad_cols;
|
||||
// padded_shape provided by user, maximum shapes of current batch of tensors
|
||||
std::vector<std::vector<dsize_t>> pad_shapes(column_name_map_.size()), max_shapes(column_name_map_.size());
|
||||
RETURN_IF_NOT_OK(UnpackPadInfo(&pad_cols, &pad_vals, &pad_shapes));
|
||||
|
||||
// init each shape in max_shape to {-1,-1...} init each unspecified shape in pad_shape to -1 as well
|
||||
for (size_t col_id : pad_cols) {
|
||||
max_shapes[col_id] = std::vector<dsize_t>(table_pair->first->front()[col_id]->Rank(), -1);
|
||||
if (pad_shapes[col_id].empty()) pad_shapes[col_id] = max_shapes[col_id]; // fill pad shape with -1
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(pad_shapes[col_id].size() == max_shapes[col_id].size(), "wrong rank in pad_shape");
|
||||
}
|
||||
|
||||
// calculate maximum shape for each column that needs to be padded
|
||||
for (const TensorRow &row : *(table_pair->first)) { // iterator each row in a batch
|
||||
for (size_t col_id : pad_cols) { // iterator each tensor in a row
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(row[col_id]->Rank() == max_shapes[col_id].size(),
|
||||
"Tensor to be padded together need to have the same rank");
|
||||
for (size_t dim = 0; dim < row[col_id]->Rank(); dim++) { // pick the largest number in each dimension
|
||||
max_shapes[col_id][dim] = std::max(max_shapes[col_id][dim], row[col_id]->shape()[dim]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// if user sets a dimension to -1 (None in python), use the max value for current dimension
|
||||
for (size_t col_id : pad_cols) {
|
||||
for (size_t dim = 0; dim < pad_shapes[col_id].size(); dim++) {
|
||||
if (pad_shapes[col_id][dim] < 0) pad_shapes[col_id][dim] = max_shapes[col_id][dim];
|
||||
}
|
||||
}
|
||||
|
||||
// call pad on each tensor that needs to be padded
|
||||
for (TensorRow &row : *(table_pair->first)) {
|
||||
for (size_t col_id : pad_cols) {
|
||||
std::shared_ptr<Tensor> pad_tensor;
|
||||
RETURN_IF_NOT_OK(PadTensor(row[col_id], &pad_tensor, pad_shapes[col_id], pad_vals[col_id]));
|
||||
row[col_id] = pad_tensor;
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status BatchOp::UnpackPadInfo(std::set<int32_t> *pad_cols, std::vector<float> *pad_vals,
|
||||
std::vector<std::vector<dsize_t>> *pad_shapes) {
|
||||
if (pad_info_.empty()) { // if pad_info empty, pad every columns automatically
|
||||
for (dsize_t col_id = 0; col_id < column_name_map_.size(); col_id++) {
|
||||
pad_cols->insert(col_id);
|
||||
}
|
||||
} else {
|
||||
for (auto p : pad_info_) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(column_name_map_.find(p.first) != column_name_map_.end(),
|
||||
"no column exists with name:" + p.first);
|
||||
dsize_t col_id = static_cast<dsize_t>(column_name_map_[p.first]);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(col_id < pad_vals->size() && col_id < pad_shapes->size(), "col_id out of bound");
|
||||
pad_cols->insert(col_id);
|
||||
(*pad_vals)[col_id] = p.second.second; // set pad values
|
||||
(*pad_shapes)[col_id] = p.second.first.AsVector(); // empty vector if shape is unknown
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status BatchOp::PadHelper(std::shared_ptr<Tensor> src, std::shared_ptr<Tensor> dst, std::vector<dsize_t> cur_ind,
|
||||
const std::vector<dsize_t> &src_s, const std::vector<dsize_t> &dst_s, size_t cur_dim) {
|
||||
if (cur_dim == src->Rank() - 1) { // if this is the last dimension, copy the data
|
||||
uint8_t type_size = src->type().SizeInBytes();
|
||||
size_t len = std::min(src->shape()[cur_dim], dst->shape()[cur_dim]) * type_size;
|
||||
dsize_t src_flat_ind = 0, dst_flat_ind = 0;
|
||||
for (size_t i = 0; i < src->Rank(); i++) {
|
||||
src_flat_ind += src_s[i] * cur_ind[i];
|
||||
dst_flat_ind += dst_s[i] * cur_ind[i];
|
||||
}
|
||||
unsigned char *src_addr = src->StartAddr() + src_flat_ind * type_size;
|
||||
unsigned char *dst_addr = dst->StartAddr() + dst_flat_ind * type_size;
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(memcpy_s(dst_addr, len, src_addr, len) == 0, "memcpy error");
|
||||
} else { // not the last dimension, keep doing recursion
|
||||
dsize_t min_ind = std::min(dst->shape()[cur_dim], src->shape()[cur_dim]);
|
||||
for (dsize_t i = 0; i < min_ind; i++) {
|
||||
cur_ind[cur_dim] = i;
|
||||
RETURN_IF_NOT_OK(PadHelper(src, dst, cur_ind, src_s, dst_s, cur_dim + 1));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,8 +16,11 @@
|
|||
#ifndef DATASET_ENGINE_DATASETOPS_BATCH_OP_H_
|
||||
#define DATASET_ENGINE_DATASETOPS_BATCH_OP_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
@ -44,10 +47,6 @@ class BatchOp : public ParallelOp {
|
|||
// @param int32_t batch_size
|
||||
explicit Builder(int32_t batch_size);
|
||||
|
||||
// Builder constructor for Batch, batch size function needs to be specified
|
||||
// @param py::function batch_size_func
|
||||
explicit Builder(py::function batch_size_func);
|
||||
|
||||
// Default destructor
|
||||
~Builder() = default;
|
||||
|
||||
|
@ -67,6 +66,12 @@ class BatchOp : public ParallelOp {
|
|||
return *this;
|
||||
}
|
||||
|
||||
Builder &SetPaddingMap(const std::map<std::string, std::pair<TensorShape, float>> &pad_map, bool pad = true) {
|
||||
builder_pad_ = pad;
|
||||
builder_pad_map_ = pad_map;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// set connector size for batch
|
||||
// @param int32_t op_conn_size
|
||||
// @return Builder & reference to builder class object
|
||||
|
@ -109,11 +114,12 @@ class BatchOp : public ParallelOp {
|
|||
Status SanityCheck();
|
||||
|
||||
bool builder_drop_;
|
||||
bool builder_pad_;
|
||||
int32_t builder_batch_size_;
|
||||
int32_t builder_num_workers_;
|
||||
int32_t builder_op_connector_size_;
|
||||
std::vector<std::string> builder_cols_to_map_;
|
||||
|
||||
std::map<std::string, std::pair<TensorShape, float>> builder_pad_map_;
|
||||
py::function builder_batch_size_func_;
|
||||
py::function builder_batch_map_func_;
|
||||
};
|
||||
|
@ -143,8 +149,9 @@ class BatchOp : public ParallelOp {
|
|||
// @param int32_t op_queue_size
|
||||
// @param int32_t rows_per_buf
|
||||
// @param int32_t num_workers
|
||||
BatchOp(int32_t batch_size, bool drop, int32_t op_queue_size, int32_t num_workers, const std::vector<std::string> &,
|
||||
py::function batch_size_func, py::function batch_map_func);
|
||||
BatchOp(int32_t batch_size, bool drop, bool pad, int32_t op_queue_size, int32_t num_workers,
|
||||
const std::vector<std::string> &, py::function batch_size_func, py::function batch_map_func,
|
||||
std::map<std::string, std::pair<TensorShape, float>> pad_map);
|
||||
|
||||
// BatchOp destructor
|
||||
~BatchOp() {}
|
||||
|
@ -176,7 +183,28 @@ class BatchOp : public ParallelOp {
|
|||
// @return Status - The error code return
|
||||
Status operator()() override;
|
||||
|
||||
// Pad input tensor according pad_shape, need to have same rank.
|
||||
// @param std::shared_ptr<Tensor> src - tensor to pad from
|
||||
// @param std::shared_ptr<Tensor> *dst - return tensor padded
|
||||
// @param std::vector<dsize_t> pad_shape - shape to pad to
|
||||
// @param float pad_val - value to pad with
|
||||
// @return - The error code return
|
||||
Status PadTensor(std::shared_ptr<Tensor> src, std::shared_ptr<Tensor> *dst, const std::vector<dsize_t> &pad_shape,
|
||||
float pad_val);
|
||||
|
||||
private:
|
||||
// recursive helper function. This function could be very expensive if called on a multi-dimensional tensor
|
||||
// it is only meant to be called by PadTensor.
|
||||
// @tparam T - type of tensor and fill value
|
||||
// @param std::shared_ptr<Tensor> src - Tensor to pad from
|
||||
// @param std::shared_ptr<Tensor>* dst - Tensor to pad to, return value
|
||||
// @param std::vector<dsize_t> cur_ind - recursion helper
|
||||
// @param T pad_val - value to pad tensor with
|
||||
// @param size_t cur_dim - recursion helper
|
||||
// @return Status - The error code return
|
||||
Status PadHelper(std::shared_ptr<Tensor> src, std::shared_ptr<Tensor> dst, std::vector<dsize_t> cur_ind,
|
||||
const std::vector<dsize_t> &src_s, const std::vector<dsize_t> &dst_s, size_t cur_dim = 0);
|
||||
|
||||
// Worker thread for doing the memcpy of batch
|
||||
// @param int32_t param workerId
|
||||
// @return Status - The error code return
|
||||
|
@ -199,6 +227,16 @@ class BatchOp : public ParallelOp {
|
|||
// @return Status - The error code return
|
||||
Status MapColumns(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> *table_pair);
|
||||
|
||||
// @param std::set<int32_t> *cols, col ids to perform pad on
|
||||
// @param std::vector<float> *vals, default padding value for each column
|
||||
// @param std::vector<std::vector<dsize_t>> *shapes, padding shape specified by user
|
||||
// @return Status - The error code return
|
||||
Status UnpackPadInfo(std::set<int32_t> *cols, std::vector<float> *vals, std::vector<std::vector<dsize_t>> *shapes);
|
||||
|
||||
// @param table_pair
|
||||
// @return Status - The error code return
|
||||
Status PadColumns(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> *table_pair);
|
||||
|
||||
// the number of thread pulling from the mOutConnector of the Op below
|
||||
// @return int32_t, 1
|
||||
int32_t num_consumers() const override { return 1; }
|
||||
|
@ -220,19 +258,15 @@ class BatchOp : public ParallelOp {
|
|||
Status InvokeBatchMapFunc(TensorTable *input, TensorTable *output, CBatchInfo info);
|
||||
|
||||
int32_t start_batch_size_;
|
||||
bool drop_;
|
||||
// Name of the columns to perform map op on
|
||||
std::vector<std::string> input_column_names_;
|
||||
// Iterator for fetching
|
||||
std::unique_ptr<ChildIterator> child_iterator_;
|
||||
// Map of column_name: column_index
|
||||
std::unordered_map<std::string, int32_t> column_name_map_;
|
||||
// Internal queue for task distribution
|
||||
QueueList<std::pair<std::unique_ptr<TensorQTable>, CBatchInfo>> worker_queues_;
|
||||
// Function pointer of batch size function
|
||||
py::function batch_size_func_;
|
||||
// Function pointer of per batch map function
|
||||
py::function batch_map_func_;
|
||||
bool drop_; // bool for whether to drop remainder or not
|
||||
bool pad_; // bool for whether to perform padding on tensor
|
||||
std::vector<std::string> pyfunc_column_names_; // Name of the columns to perform map op on
|
||||
std::map<std::string, std::pair<TensorShape, float>> pad_info_; // column names to perform padding on
|
||||
std::unique_ptr<ChildIterator> child_iterator_; // child iterator for fetching TensorRows 1 by 1
|
||||
std::unordered_map<std::string, int32_t> column_name_map_; // Map of column_name: column_index
|
||||
QueueList<std::pair<std::unique_ptr<TensorQTable>, CBatchInfo>> worker_queues_; // internal queue for syncing worker
|
||||
py::function batch_size_func_; // Function pointer of batch size function
|
||||
py::function batch_map_func_; // Function pointer of per batch map function
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -40,7 +40,8 @@ from mindspore._c_expression import typing
|
|||
from mindspore import log as logger
|
||||
from . import samplers
|
||||
from .iterators import DictIterator, TupleIterator
|
||||
from .validators import check, check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, check_rename, \
|
||||
from .validators import check, check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \
|
||||
check_rename, \
|
||||
check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
|
||||
check_tfrecorddataset, check_vocdataset, check_celebadataset, check_minddataset, check_generatordataset, \
|
||||
check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset
|
||||
|
@ -163,7 +164,7 @@ class Dataset:
|
|||
|
||||
@check_batch
|
||||
def batch(self, batch_size, drop_remainder=False, num_parallel_workers=None, per_batch_map=None,
|
||||
input_columns=None):
|
||||
input_columns=None, pad_info=None):
|
||||
"""
|
||||
Combines batch_size number of consecutive rows into batches.
|
||||
|
||||
|
@ -181,7 +182,7 @@ class Dataset:
|
|||
drop_remainder (bool, optional): Determines whether or not to drop the last
|
||||
possibly incomplete batch (default=False). If True, and if there are less
|
||||
than batch_size rows available to make the last batch, then those rows will
|
||||
be dropped and not propogated to the child node.
|
||||
be dropped and not propagated to the child node.
|
||||
num_parallel_workers (int, optional): Number of workers to process the Dataset in parallel (default=None).
|
||||
per_batch_map (callable, optional): Per batch map callable. A callable which takes
|
||||
(list[Tensor], list[Tensor], ..., BatchInfo) as input parameters. Each list[Tensor] represent a batch of
|
||||
|
@ -189,6 +190,8 @@ class Dataset:
|
|||
last parameter of the callable should always be a BatchInfo object.
|
||||
input_columns (list of string, optional): List of names of the input columns. The size of the list should
|
||||
match with signature of per_batch_map callable.
|
||||
pad_info (dict, optional): Whether to perform padding on selected columns. pad_info={"col1":([224,224],0)}
|
||||
would pad column with name "col1" to a tensor of size [224,224] and fill the missing with 0.
|
||||
|
||||
Returns:
|
||||
BatchDataset, dataset batched.
|
||||
|
@ -200,7 +203,8 @@ class Dataset:
|
|||
>>> # and drops the last incomplete batch if there is one.
|
||||
>>> data = data.batch(100, True)
|
||||
"""
|
||||
return BatchDataset(self, batch_size, drop_remainder, num_parallel_workers, per_batch_map, input_columns)
|
||||
return BatchDataset(self, batch_size, drop_remainder, num_parallel_workers, per_batch_map, input_columns,
|
||||
pad_info)
|
||||
|
||||
@check_sync_wait
|
||||
def sync_wait(self, condition_name, num_batch=1, callback=None):
|
||||
|
@ -1026,13 +1030,26 @@ class BatchDataset(DatasetOp):
|
|||
|
||||
Args:
|
||||
input_dataset (Dataset): Input Dataset to be batched.
|
||||
batch_size (int): The size of the batch.
|
||||
drop_remainder (bool, optional): Whether drop the remainder batch of data (drop_remainder=False).
|
||||
If True, the last incomplete batch will be dropped.
|
||||
batch_size (int or function): The number of rows each batch is created with. An
|
||||
int or callable which takes exactly 1 parameter, BatchInfo.
|
||||
drop_remainder (bool, optional): Determines whether or not to drop the last
|
||||
possibly incomplete batch (default=False). If True, and if there are less
|
||||
than batch_size rows available to make the last batch, then those rows will
|
||||
be dropped and not propagated to the child node.
|
||||
num_parallel_workers (int, optional): Number of workers to process the Dataset in parallel (default=None).
|
||||
per_batch_map (callable, optional): Per batch map callable. A callable which takes
|
||||
(list[Tensor], list[Tensor], ..., BatchInfo) as input parameters. Each list[Tensor] represent a batch of
|
||||
Tensors on a given column. The number of lists should match with number of entries in input_columns. The
|
||||
last parameter of the callable should always be a BatchInfo object.
|
||||
input_columns (list of string, optional): List of names of the input columns. The size of the list should
|
||||
match with signature of per_batch_map callable.
|
||||
pad_info (dict, optional): Whether to perform padding on selected columns. pad_info={"col1":([224,224],0)}
|
||||
would pad column with name "col1" to a tensor of size [224,224] and fill the missing with 0.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, input_dataset, batch_size, drop_remainder=False, num_parallel_workers=None,
|
||||
per_batch_map=None, input_columns=None):
|
||||
per_batch_map=None, input_columns=None, pad_info=None):
|
||||
super().__init__(num_parallel_workers)
|
||||
|
||||
if BatchDataset._is_ancestor_of_repeat(input_dataset):
|
||||
|
@ -1044,6 +1061,7 @@ class BatchDataset(DatasetOp):
|
|||
self.drop_remainder = drop_remainder
|
||||
self.per_batch_map = per_batch_map
|
||||
self.input_columns = input_columns
|
||||
self.pad_info = pad_info
|
||||
self.input.append(input_dataset)
|
||||
input_dataset.output.append(self)
|
||||
self._input_indexs = input_dataset.input_indexs
|
||||
|
@ -1054,6 +1072,7 @@ class BatchDataset(DatasetOp):
|
|||
args["drop_remainder"] = self.drop_remainder
|
||||
args["per_batch_map"] = self.per_batch_map
|
||||
args["input_columns"] = self.input_columns
|
||||
args["pad_info"] = self.pad_info
|
||||
return args
|
||||
|
||||
def get_dataset_size(self):
|
||||
|
@ -2702,6 +2721,7 @@ class TFRecordDataset(SourceDataset):
|
|||
>>> # 3) get all rows from dataset_files with schema file "./schema.json":
|
||||
>>> tfdataset = ds.TFRecordDataset(dataset_files=dataset_files, schema="./schema.json")
|
||||
"""
|
||||
|
||||
@check_tfrecorddataset
|
||||
def __init__(self, dataset_files, schema=None, columns_list=None, num_samples=None, num_parallel_workers=None,
|
||||
shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, shard_equal_rows=False):
|
||||
|
@ -3551,6 +3571,7 @@ class CelebADataset(SourceDataset):
|
|||
args["shard_id"] = self.shard_id
|
||||
return args
|
||||
|
||||
|
||||
class TextFileDataset(SourceDataset):
|
||||
"""
|
||||
A source dataset that reads and parses datasets stored on disk in text format.
|
||||
|
|
|
@ -324,6 +324,7 @@ def check_sampler_shuffle_shard_options(param_dict):
|
|||
|
||||
def check_imagefolderdatasetv2(method):
|
||||
"""A wrapper that wrap a parameter checker to the original Dataset(ImageFolderDatasetV2)."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(*args, **kwargs):
|
||||
param_dict = make_param_dict(method, args, kwargs)
|
||||
|
@ -356,6 +357,7 @@ def check_imagefolderdatasetv2(method):
|
|||
|
||||
def check_mnist_cifar_dataset(method):
|
||||
"""A wrapper that wrap a parameter checker to the original Dataset(ManifestDataset, Cifar10/100Dataset)."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(*args, **kwargs):
|
||||
param_dict = make_param_dict(method, args, kwargs)
|
||||
|
@ -382,6 +384,7 @@ def check_mnist_cifar_dataset(method):
|
|||
|
||||
def check_manifestdataset(method):
|
||||
"""A wrapper that wrap a parameter checker to the original Dataset(ManifestDataset)."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(*args, **kwargs):
|
||||
param_dict = make_param_dict(method, args, kwargs)
|
||||
|
@ -414,6 +417,7 @@ def check_manifestdataset(method):
|
|||
|
||||
def check_tfrecorddataset(method):
|
||||
"""A wrapper that wrap a parameter checker to the original Dataset(TFRecordDataset)."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(*args, **kwargs):
|
||||
param_dict = make_param_dict(method, args, kwargs)
|
||||
|
@ -444,6 +448,7 @@ def check_tfrecorddataset(method):
|
|||
|
||||
def check_vocdataset(method):
|
||||
"""A wrapper that wrap a parameter checker to the original Dataset(VOCDataset)."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(*args, **kwargs):
|
||||
param_dict = make_param_dict(method, args, kwargs)
|
||||
|
@ -470,6 +475,7 @@ def check_vocdataset(method):
|
|||
|
||||
def check_celebadataset(method):
|
||||
"""A wrapper that wrap a parameter checker to the original Dataset(CelebADataset)."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(*args, **kwargs):
|
||||
param_dict = make_param_dict(method, args, kwargs)
|
||||
|
@ -510,6 +516,7 @@ def check_celebadataset(method):
|
|||
|
||||
def check_minddataset(method):
|
||||
"""A wrapper that wrap a parameter checker to the original Dataset(MindDataset)."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(*args, **kwargs):
|
||||
param_dict = make_param_dict(method, args, kwargs)
|
||||
|
@ -541,6 +548,7 @@ def check_minddataset(method):
|
|||
|
||||
def check_generatordataset(method):
|
||||
"""A wrapper that wrap a parameter checker to the original Dataset(GeneratorDataset)."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(*args, **kwargs):
|
||||
param_dict = make_param_dict(method, args, kwargs)
|
||||
|
@ -628,8 +636,25 @@ def check_columns(columns, name):
|
|||
raise TypeError("{} should be either a list of strings or a single string.".format(name))
|
||||
|
||||
|
||||
def check_pad_info(key, val):
|
||||
"""check the key and value pair of pad_info in batch"""
|
||||
check_type(key, "key in pad_info", str)
|
||||
if val is not None:
|
||||
assert len(val) == 2, "value of pad_info should be a tuple of size 2"
|
||||
check_type(val, "value in pad_info", tuple)
|
||||
if val[0] is not None:
|
||||
check_type(val[0], "pad_shape", list)
|
||||
for dim in val[0]:
|
||||
if dim is not None:
|
||||
check_type(dim, "dim in pad_shape", int)
|
||||
assert dim > 0, "pad shape should be positive integers"
|
||||
if val[1] is not None:
|
||||
check_type(val[1], "pad_value", (int, float))
|
||||
|
||||
|
||||
def check_batch(method):
|
||||
"""check the input arguments of batch."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(*args, **kwargs):
|
||||
param_dict = make_param_dict(method, args, kwargs)
|
||||
|
@ -648,6 +673,14 @@ def check_batch(method):
|
|||
|
||||
check_param_type(nreq_param_bool, param_dict, bool)
|
||||
|
||||
if (param_dict.get('pad_info') is not None) and (param_dict.get('per_batch_map') is not None):
|
||||
raise ValueError("pad_info and per_batch_map can't both be set")
|
||||
|
||||
if param_dict.get('pad_info') is not None:
|
||||
check_type(param_dict["pad_info"], "pad_info", dict)
|
||||
for k, v in param_dict.get('pad_info').items():
|
||||
check_pad_info(k, v)
|
||||
|
||||
for param_name in nreq_param_columns:
|
||||
param = param_dict.get(param_name)
|
||||
if param is not None:
|
||||
|
@ -687,6 +720,7 @@ def check_sync_wait(method):
|
|||
|
||||
def check_shuffle(method):
|
||||
"""check the input arguments of shuffle."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(*args, **kwargs):
|
||||
param_dict = make_param_dict(method, args, kwargs)
|
||||
|
@ -705,6 +739,7 @@ def check_shuffle(method):
|
|||
|
||||
def check_map(method):
|
||||
"""check the input arguments of map."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(*args, **kwargs):
|
||||
param_dict = make_param_dict(method, args, kwargs)
|
||||
|
@ -729,6 +764,7 @@ def check_map(method):
|
|||
|
||||
def check_filter(method):
|
||||
""""check the input arguments of filter."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(*args, **kwargs):
|
||||
param_dict = make_param_dict(method, args, kwargs)
|
||||
|
@ -749,6 +785,7 @@ def check_filter(method):
|
|||
|
||||
def check_repeat(method):
|
||||
"""check the input arguments of repeat."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(*args, **kwargs):
|
||||
param_dict = make_param_dict(method, args, kwargs)
|
||||
|
@ -764,6 +801,7 @@ def check_repeat(method):
|
|||
|
||||
def check_skip(method):
|
||||
"""check the input arguments of skip."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(*args, **kwargs):
|
||||
param_dict = make_param_dict(method, args, kwargs)
|
||||
|
@ -780,6 +818,7 @@ def check_skip(method):
|
|||
|
||||
def check_take(method):
|
||||
"""check the input arguments of take."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(*args, **kwargs):
|
||||
param_dict = make_param_dict(method, args, kwargs)
|
||||
|
@ -794,6 +833,7 @@ def check_take(method):
|
|||
|
||||
def check_zip(method):
|
||||
"""check the input arguments of zip."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(*args, **kwargs):
|
||||
param_dict = make_param_dict(method, args, kwargs)
|
||||
|
@ -811,6 +851,7 @@ def check_zip(method):
|
|||
|
||||
def check_zip_dataset(method):
|
||||
"""check the input arguments of zip method in `Dataset`."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(*args, **kwargs):
|
||||
param_dict = make_param_dict(method, args, kwargs)
|
||||
|
@ -830,6 +871,7 @@ def check_zip_dataset(method):
|
|||
|
||||
def check_rename(method):
|
||||
"""check the input arguments of rename."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(*args, **kwargs):
|
||||
param_dict = make_param_dict(method, args, kwargs)
|
||||
|
@ -849,6 +891,7 @@ def check_rename(method):
|
|||
|
||||
def check_project(method):
|
||||
"""check the input arguments of project."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(*args, **kwargs):
|
||||
param_dict = make_param_dict(method, args, kwargs)
|
||||
|
@ -876,6 +919,7 @@ def check_shape(shape, name):
|
|||
|
||||
def check_add_column(method):
|
||||
"""check the input arguments of add_column."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(*args, **kwargs):
|
||||
param_dict = make_param_dict(method, args, kwargs)
|
||||
|
@ -905,6 +949,7 @@ def check_add_column(method):
|
|||
|
||||
def check_textfiledataset(method):
|
||||
"""A wrapper that wrap a parameter checker to the original Dataset(TextFileDataset)."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(*args, **kwargs):
|
||||
param_dict = make_param_dict(method, args, kwargs)
|
||||
|
|
|
@ -30,16 +30,14 @@ namespace common = mindspore::common;
|
|||
namespace de = mindspore::dataset;
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
using mindspore::MsLogLevel::ERROR;
|
||||
using mindspore::ExceptionType::NoExceptionType;
|
||||
using mindspore::LogStream;
|
||||
using mindspore::ExceptionType::NoExceptionType;
|
||||
using mindspore::MsLogLevel::ERROR;
|
||||
|
||||
class MindDataTestBatchOp : public UT::DatasetOpTesting {
|
||||
protected:
|
||||
|
||||
};
|
||||
|
||||
|
||||
std::shared_ptr<de::BatchOp> Batch(int32_t batch_size = 1, bool drop = false, int rows_per_buf = 2) {
|
||||
Status rc;
|
||||
std::shared_ptr<de::BatchOp> op;
|
||||
|
@ -93,10 +91,8 @@ TEST_F(MindDataTestBatchOp, TestSimpleBatch) {
|
|||
rc = di.GetNextAsMap(&tensor_map);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
std::shared_ptr<de::Tensor> t;
|
||||
rc = de::Tensor::CreateTensor(&t,
|
||||
TensorImpl::kFlexible, de::TensorShape({12, 1}),
|
||||
de::DataType(DataType::DE_INT64),
|
||||
(unsigned char *) payload);
|
||||
rc = de::Tensor::CreateTensor(&t, TensorImpl::kFlexible, de::TensorShape({12, 1}), de::DataType(DataType::DE_INT64),
|
||||
(unsigned char *)payload);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
// verify the actual data in Tensor is correct
|
||||
EXPECT_EQ(*t == *tensor_map["col_sint64"], true);
|
||||
|
@ -111,7 +107,6 @@ TEST_F(MindDataTestBatchOp, TestSimpleBatch) {
|
|||
EXPECT_EQ(success, true);
|
||||
}
|
||||
|
||||
|
||||
TEST_F(MindDataTestBatchOp, TestRepeatBatchDropTrue) {
|
||||
std::string schema_file = datasets_root_path_ + "/testBatchDataset";
|
||||
bool success = false;
|
||||
|
@ -125,20 +120,14 @@ TEST_F(MindDataTestBatchOp, TestRepeatBatchDropTrue) {
|
|||
-9223372036854775807 - 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9223372036854775807};
|
||||
de::DatasetIterator di(tree);
|
||||
std::shared_ptr<de::Tensor> t1, t2, t3;
|
||||
rc = de::Tensor::CreateTensor(&t1,
|
||||
TensorImpl::kFlexible, de::TensorShape({7, 1}),
|
||||
de::DataType(DataType::DE_INT64),
|
||||
(unsigned char *) payload);
|
||||
rc = de::Tensor::CreateTensor(&t1, TensorImpl::kFlexible, de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64),
|
||||
(unsigned char *)payload);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = de::Tensor::CreateTensor(&t2,
|
||||
TensorImpl::kFlexible, de::TensorShape({7, 1}),
|
||||
de::DataType(DataType::DE_INT64),
|
||||
(unsigned char *) (payload + 7));
|
||||
rc = de::Tensor::CreateTensor(&t2, TensorImpl::kFlexible, de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64),
|
||||
(unsigned char *)(payload + 7));
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = de::Tensor::CreateTensor(&t3,
|
||||
TensorImpl::kFlexible, de::TensorShape({7, 1}),
|
||||
de::DataType(DataType::DE_INT64),
|
||||
(unsigned char *) (payload + 2));
|
||||
rc = de::Tensor::CreateTensor(&t3, TensorImpl::kFlexible, de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64),
|
||||
(unsigned char *)(payload + 2));
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
TensorMap tensor_map;
|
||||
|
@ -163,7 +152,6 @@ TEST_F(MindDataTestBatchOp, TestRepeatBatchDropTrue) {
|
|||
EXPECT_EQ(success, true);
|
||||
}
|
||||
|
||||
|
||||
TEST_F(MindDataTestBatchOp, TestRepeatBatchDropFalse) {
|
||||
std::string schema_file = datasets_root_path_ + "/testBatchDataset";
|
||||
bool success = false;
|
||||
|
@ -177,25 +165,17 @@ TEST_F(MindDataTestBatchOp, TestRepeatBatchDropFalse) {
|
|||
-9223372036854775807 - 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9223372036854775807};
|
||||
de::DatasetIterator di(tree);
|
||||
std::shared_ptr<de::Tensor> t1, t2, t3, t4;
|
||||
rc = de::Tensor::CreateTensor(&t1,
|
||||
TensorImpl::kFlexible, de::TensorShape({7, 1}),
|
||||
de::DataType(DataType::DE_INT64),
|
||||
(unsigned char *) payload);
|
||||
rc = de::Tensor::CreateTensor(&t1, TensorImpl::kFlexible, de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64),
|
||||
(unsigned char *)payload);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = de::Tensor::CreateTensor(&t2,
|
||||
TensorImpl::kFlexible, de::TensorShape({7, 1}),
|
||||
de::DataType(DataType::DE_INT64),
|
||||
(unsigned char *) (payload + 7));
|
||||
rc = de::Tensor::CreateTensor(&t2, TensorImpl::kFlexible, de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64),
|
||||
(unsigned char *)(payload + 7));
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = de::Tensor::CreateTensor(&t3,
|
||||
TensorImpl::kFlexible, de::TensorShape({7, 1}),
|
||||
de::DataType(DataType::DE_INT64),
|
||||
(unsigned char *) (payload + 2));
|
||||
rc = de::Tensor::CreateTensor(&t3, TensorImpl::kFlexible, de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64),
|
||||
(unsigned char *)(payload + 2));
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = de::Tensor::CreateTensor(&t4,
|
||||
TensorImpl::kFlexible, de::TensorShape({3, 1}),
|
||||
de::DataType(DataType::DE_INT64),
|
||||
(unsigned char *) (payload + 9));
|
||||
rc = de::Tensor::CreateTensor(&t4, TensorImpl::kFlexible, de::TensorShape({3, 1}), de::DataType(DataType::DE_INT64),
|
||||
(unsigned char *)(payload + 9));
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
TensorMap tensor_map;
|
||||
|
@ -224,7 +204,6 @@ TEST_F(MindDataTestBatchOp, TestRepeatBatchDropFalse) {
|
|||
EXPECT_EQ(success, true);
|
||||
}
|
||||
|
||||
|
||||
TEST_F(MindDataTestBatchOp, TestBatchDropFalseRepeat) {
|
||||
std::string schema_file = datasets_root_path_ + "/testBatchDataset";
|
||||
bool success = false;
|
||||
|
@ -238,15 +217,11 @@ TEST_F(MindDataTestBatchOp, TestBatchDropFalseRepeat) {
|
|||
-9223372036854775807 - 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9223372036854775807};
|
||||
de::DatasetIterator di(tree);
|
||||
std::shared_ptr<de::Tensor> t1, t2;
|
||||
rc = de::Tensor::CreateTensor(&t1,
|
||||
TensorImpl::kFlexible, de::TensorShape({7, 1}),
|
||||
de::DataType(DataType::DE_INT64),
|
||||
(unsigned char *) payload);
|
||||
rc = de::Tensor::CreateTensor(&t1, TensorImpl::kFlexible, de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64),
|
||||
(unsigned char *)payload);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = de::Tensor::CreateTensor(&t2,
|
||||
TensorImpl::kFlexible, de::TensorShape({5, 1}),
|
||||
de::DataType(DataType::DE_INT64),
|
||||
(unsigned char *) (payload + 7));
|
||||
rc = de::Tensor::CreateTensor(&t2, TensorImpl::kFlexible, de::TensorShape({5, 1}), de::DataType(DataType::DE_INT64),
|
||||
(unsigned char *)(payload + 7));
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
TensorMap tensor_map;
|
||||
|
@ -275,7 +250,6 @@ TEST_F(MindDataTestBatchOp, TestBatchDropFalseRepeat) {
|
|||
EXPECT_EQ(success, true);
|
||||
}
|
||||
|
||||
|
||||
TEST_F(MindDataTestBatchOp, TestBatchDropTrueRepeat) {
|
||||
std::string schema_file = datasets_root_path_ + "/testBatchDataset";
|
||||
bool success = false;
|
||||
|
@ -289,15 +263,11 @@ TEST_F(MindDataTestBatchOp, TestBatchDropTrueRepeat) {
|
|||
-9223372036854775807 - 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9223372036854775807};
|
||||
de::DatasetIterator di(tree);
|
||||
std::shared_ptr<de::Tensor> t1, t2;
|
||||
rc = de::Tensor::CreateTensor(&t1,
|
||||
TensorImpl::kFlexible, de::TensorShape({5, 1}),
|
||||
de::DataType(DataType::DE_INT64),
|
||||
(unsigned char *) payload);
|
||||
rc = de::Tensor::CreateTensor(&t1, TensorImpl::kFlexible, de::TensorShape({5, 1}), de::DataType(DataType::DE_INT64),
|
||||
(unsigned char *)payload);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = de::Tensor::CreateTensor(&t2,
|
||||
TensorImpl::kFlexible, de::TensorShape({5, 1}),
|
||||
de::DataType(DataType::DE_INT64),
|
||||
(unsigned char *) (payload + 5));
|
||||
rc = de::Tensor::CreateTensor(&t2, TensorImpl::kFlexible, de::TensorShape({5, 1}), de::DataType(DataType::DE_INT64),
|
||||
(unsigned char *)(payload + 5));
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
TensorMap tensor_map;
|
||||
|
@ -325,3 +295,31 @@ TEST_F(MindDataTestBatchOp, TestBatchDropTrueRepeat) {
|
|||
}
|
||||
EXPECT_EQ(success, true);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestBatchOp, TestSimpleBatchPadding) {
|
||||
std::string schema_file = datasets_root_path_ + "/testBatchDataset";
|
||||
std::shared_ptr<BatchOp> op;
|
||||
std::map<std::string, std::pair<TensorShape, float>> m;
|
||||
m.insert({"col_1d", std::make_pair(TensorShape({4}), -1)});
|
||||
de::BatchOp::Builder(12).SetDrop(false).SetPaddingMap(m, true).Build(&op);
|
||||
auto tree = Build({Storage(schema_file), op});
|
||||
tree->Prepare();
|
||||
Status rc = tree->Launch();
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "Return code error detected during tree launch: " << rc.ToString() << ".";
|
||||
} else {
|
||||
int64_t payload[] = {-9223372036854775807 - 1, 1, -1, -1, 2, 3, -1, -1, 4, 5, -1, -1, 6, 7, -1, -1,
|
||||
8, 9, -1, -1, 10, 11, -1, -1, 12, 13, -1, -1, 14, 15, -1, -1,
|
||||
16, 17, -1, -1, 18, 19, -1, -1, 20, 21, -1, -1, 22, 23, -1, -1};
|
||||
std::shared_ptr<de::Tensor> t;
|
||||
rc = de::Tensor::CreateTensor(&t, TensorImpl::kFlexible, de::TensorShape({12, 4}), de::DataType(DataType::DE_INT64),
|
||||
(unsigned char *)payload);
|
||||
de::DatasetIterator di(tree);
|
||||
TensorMap tensor_map;
|
||||
rc = di.GetNextAsMap(&tensor_map);
|
||||
EXPECT_TRUE((*t) == (*(tensor_map["col_1d"])));
|
||||
rc = di.GetNextAsMap(&tensor_map);
|
||||
EXPECT_TRUE(tensor_map.size() == 0);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,213 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import numpy as np
|
||||
import time
|
||||
|
||||
|
||||
# This UT test tests the following cases
|
||||
|
||||
# 1. padding: input_shape=[x] output_shape=[y] where y > x
|
||||
# 2. padding in one dimension and truncate in the other. input_shape=[x1,x2] output_shape=[y1,y2] y1>x1 and y2<x2
|
||||
# 3. automatic padding for a specific column
|
||||
# 4. default setting for all columns
|
||||
# 5. test None in different places
|
||||
|
||||
# this generator function yield two columns
|
||||
# col1d: [0],[1], [2], [3]
|
||||
# col2d: [[100],[200]], [[101],[201]], [102],[202]], [103],[203]]
|
||||
def gen_2cols(num):
|
||||
for i in range(num):
|
||||
yield (np.array([i]), np.array([[i + 100], [i + 200]]))
|
||||
|
||||
|
||||
# this generator function yield one column of variable shapes
|
||||
# col: [0], [0,1], [0,1,2], [0,1,2,3]
|
||||
def gen_var_col(num):
|
||||
for i in range(num):
|
||||
yield (np.array([j for j in range(i + 1)]),)
|
||||
|
||||
|
||||
# this generator function yield two columns of variable shapes
|
||||
# col1: [0], [0,1], [0,1,2], [0,1,2,3]
|
||||
# col2: [100], [100,101], [100,101,102], [100,110,102,103]
|
||||
def gen_var_cols(num):
|
||||
for i in range(num):
|
||||
yield (np.array([j for j in range(i + 1)]), np.array([100 + j for j in range(i + 1)]))
|
||||
|
||||
|
||||
# this generator function yield two columns of variable shapes
|
||||
# col1: [[0]], [[0,1]], [[0,1,2]], [[0,1,2,3]]
|
||||
# col2: [[100]], [[100,101]], [[100,101,102]], [[100,110,102,103]]
|
||||
def gen_var_cols_2d(num):
|
||||
for i in range(num):
|
||||
yield (np.array([[j for j in range(i + 1)]]), np.array([[100 + j for j in range(i + 1)]]))
|
||||
|
||||
|
||||
def test_batch_padding_01():
|
||||
data1 = ds.GeneratorDataset((lambda: gen_2cols(2)), ["col1d", "col2d"])
|
||||
data1 = data1.batch(batch_size=2, drop_remainder=False, pad_info={"col2d": ([2, 2], -2), "col1d": ([2], -1)})
|
||||
data1 = data1.repeat(2)
|
||||
for data in data1.create_dict_iterator():
|
||||
assert np.array_equal([[0, -1], [1, -1]], data["col1d"])
|
||||
assert np.array_equal([[[100, -2], [200, -2]], [[101, -2], [201, -2]]], data["col2d"])
|
||||
|
||||
|
||||
def test_batch_padding_02():
|
||||
data1 = ds.GeneratorDataset((lambda: gen_2cols(2)), ["col1d", "col2d"])
|
||||
data1 = data1.batch(batch_size=2, drop_remainder=False, pad_info={"col2d": ([1, 2], -2)})
|
||||
data1 = data1.repeat(2)
|
||||
for data in data1.create_dict_iterator():
|
||||
assert np.array_equal([[0], [1]], data["col1d"])
|
||||
assert np.array_equal([[[100, -2]], [[101, -2]]], data["col2d"])
|
||||
|
||||
|
||||
def test_batch_padding_03():
|
||||
data1 = ds.GeneratorDataset((lambda: gen_var_col(4)), ["col"])
|
||||
data1 = data1.batch(batch_size=2, drop_remainder=False, pad_info={"col": (None, -1)}) # pad automatically
|
||||
data1 = data1.repeat(2)
|
||||
res = dict()
|
||||
for ind, data in enumerate(data1.create_dict_iterator()):
|
||||
res[ind] = data["col"].copy()
|
||||
assert np.array_equal(res[0], [[0, -1], [0, 1]])
|
||||
assert np.array_equal(res[1], [[0, 1, 2, -1], [0, 1, 2, 3]])
|
||||
assert np.array_equal(res[2], [[0, -1], [0, 1]])
|
||||
assert np.array_equal(res[3], [[0, 1, 2, -1], [0, 1, 2, 3]])
|
||||
|
||||
|
||||
def test_batch_padding_04():
|
||||
data1 = ds.GeneratorDataset((lambda: gen_var_cols(2)), ["col1", "col2"])
|
||||
data1 = data1.batch(batch_size=2, drop_remainder=False, pad_info={}) # pad automatically
|
||||
data1 = data1.repeat(2)
|
||||
for data in data1.create_dict_iterator():
|
||||
assert np.array_equal(data["col1"], [[0, 0], [0, 1]])
|
||||
assert np.array_equal(data["col2"], [[100, 0], [100, 101]])
|
||||
|
||||
|
||||
def test_batch_padding_05():
|
||||
data1 = ds.GeneratorDataset((lambda: gen_var_cols_2d(3)), ["col1", "col2"])
|
||||
data1 = data1.batch(batch_size=3, drop_remainder=False,
|
||||
pad_info={"col2": ([2, None], -2), "col1": (None, -1)}) # pad automatically
|
||||
for data in data1.create_dict_iterator():
|
||||
assert np.array_equal(data["col1"], [[[0, -1, -1]], [[0, 1, -1]], [[0, 1, 2]]])
|
||||
assert np.array_equal(data["col2"], [[[100, -2, -2], [-2, -2, -2]], [[100, 101, -2], [-2, -2, -2]],
|
||||
[[100, 101, 102], [-2, -2, -2]]])
|
||||
|
||||
|
||||
def batch_padding_performance_3d():
|
||||
cifar10_dir = "../data/dataset/testCifar10Data"
|
||||
data1 = ds.Cifar10Dataset(cifar10_dir, shuffle=False) # shape = [32,32,3]
|
||||
data1 = data1.repeat(24)
|
||||
pad_info = {"image": ([36, 36, 3], 0)}
|
||||
# pad_info = None
|
||||
data1 = data1.batch(batch_size=24, drop_remainder=True, pad_info=pad_info)
|
||||
start_time = time.time()
|
||||
num_batches = 0
|
||||
ret = []
|
||||
for data in data1.create_dict_iterator():
|
||||
num_batches += 1
|
||||
res = "total number of batch:" + str(num_batches) + " time elapsed:" + str(time.time() - start_time)
|
||||
# print(res)
|
||||
|
||||
|
||||
def batch_padding_performance_1d():
|
||||
cifar10_dir = "../data/dataset/testCifar10Data"
|
||||
data1 = ds.Cifar10Dataset(cifar10_dir, shuffle=False) # shape = [32,32,3]
|
||||
data1 = data1.repeat(24)
|
||||
data1 = data1.map(input_columns="image", operations=(lambda x: x.reshape(-1)))
|
||||
pad_info = {"image": ([3888], 0)} # 3888 =36*36*3
|
||||
# pad_info = None
|
||||
data1 = data1.batch(batch_size=24, drop_remainder=True, pad_info=pad_info)
|
||||
start_time = time.time()
|
||||
num_batches = 0
|
||||
for data in data1.create_dict_iterator():
|
||||
num_batches += 1
|
||||
res = "total number of batch:" + str(num_batches) + " time elapsed:" + str(time.time() - start_time)
|
||||
# print(res)
|
||||
|
||||
|
||||
def batch_pyfunc_padding_3d():
|
||||
cifar10_dir = "../data/dataset/testCifar10Data"
|
||||
data1 = ds.Cifar10Dataset(cifar10_dir, shuffle=False) # shape = [32,32,3]
|
||||
data1 = data1.repeat(24)
|
||||
# pad_info = {"image": ([36, 36, 3], 0)}
|
||||
data1 = data1.map(input_columns="image", operations=(lambda x: np.pad(x, ((0, 4), (0, 4), (0, 0)))),
|
||||
python_multiprocessing=False)
|
||||
data1 = data1.batch(batch_size=24, drop_remainder=True)
|
||||
start_time = time.time()
|
||||
num_batches = 0
|
||||
for data in data1.create_dict_iterator():
|
||||
num_batches += 1
|
||||
res = "total number of batch:" + str(num_batches) + " time elapsed:" + str(time.time() - start_time)
|
||||
# print(res)
|
||||
|
||||
|
||||
def batch_pyfunc_padding_1d():
|
||||
cifar10_dir = "../data/dataset/testCifar10Data"
|
||||
data1 = ds.Cifar10Dataset(cifar10_dir, shuffle=False) # shape = [32,32,3]
|
||||
data1 = data1.repeat(24)
|
||||
data1 = data1.map(input_columns="image", operations=(lambda x: x.reshape(-1)))
|
||||
data1 = data1.map(input_columns="image", operations=(lambda x: np.pad(x, (0, 816))), python_multiprocessing=False)
|
||||
data1 = data1.batch(batch_size=24, drop_remainder=True)
|
||||
start_time = time.time()
|
||||
num_batches = 0
|
||||
for data in data1.create_dict_iterator():
|
||||
num_batches += 1
|
||||
res = "total number of batch:" + str(num_batches) + " time elapsed:" + str(time.time() - start_time)
|
||||
# print(res)
|
||||
|
||||
|
||||
# this function runs pad_batch and numpy.pad then compare the results
|
||||
def test_pad_via_map():
|
||||
cifar10_dir = "../data/dataset/testCifar10Data"
|
||||
|
||||
def pad_map_config():
|
||||
data1 = ds.Cifar10Dataset(cifar10_dir, shuffle=False, num_samples=1000) # shape = [32,32,3]
|
||||
data1 = data1.map(input_columns="image", operations=(lambda x: x.reshape(-1))) # reshape to 1d
|
||||
data1 = data1.map(input_columns="image", operations=(lambda x: np.pad(x, (0, 816))))
|
||||
data1 = data1.batch(batch_size=25, drop_remainder=True)
|
||||
res = []
|
||||
for data in data1.create_dict_iterator():
|
||||
res.append(data["image"])
|
||||
return res
|
||||
|
||||
def pad_batch_config():
|
||||
data2 = ds.Cifar10Dataset(cifar10_dir, shuffle=False, num_samples=1000) # shape = [32,32,3]
|
||||
data2 = data2.map(input_columns="image", operations=(lambda x: x.reshape(-1))) # reshape to 1d
|
||||
data2 = data2.batch(batch_size=25, drop_remainder=True, pad_info={"image": ([3888], 0)})
|
||||
res = []
|
||||
for data in data2.create_dict_iterator():
|
||||
res.append(data["image"])
|
||||
return res
|
||||
|
||||
res_from_map = pad_map_config()
|
||||
res_from_batch = pad_batch_config()
|
||||
assert len(res_from_batch) == len(res_from_batch)
|
||||
for i in range(len(res_from_map)):
|
||||
assert np.array_equal(res_from_map[i], res_from_batch[i])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_batch_padding_01()
|
||||
test_batch_padding_02()
|
||||
test_batch_padding_03()
|
||||
test_batch_padding_04()
|
||||
test_batch_padding_05()
|
||||
# batch_padding_performance_3d()
|
||||
# batch_padding_performance_1d()
|
||||
# batch_pyfunc_padding_3d()
|
||||
# batch_pyfunc_padding_1d()
|
||||
test_pad_via_map()
|
Loading…
Reference in New Issue