From 92eb17d9827bfa6919bb0cd542a6492cbeb1c5b0 Mon Sep 17 00:00:00 2001 From: yanghaitao Date: Fri, 11 Sep 2020 10:35:06 +0800 Subject: [PATCH] checkou colums names for bucket_batch_by_length --- .../engine/datasetops/bucket_batch_by_length_op.cc | 14 ++++++++++++++ .../engine/datasetops/bucket_batch_by_length_op.h | 2 ++ 2 files changed, 16 insertions(+) diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.cc index 30da571089e..06bcdd09e73 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.cc @@ -224,5 +224,19 @@ Status BucketBatchByLengthOp::Reset() { return Status::OK(); } + +// Computing the assignment of the column name map and check compute input columns. +Status BucketBatchByLengthOp::ComputeColMap() { + RETURN_IF_NOT_OK(DatasetOp::ComputeColMap()); + + for (const auto &inCol : length_dependent_columns_) { + bool found = column_name_id_map_.find(inCol) != column_name_id_map_.end() ? true : false; + if (!found) { + std::string err_msg = "input column name: " + inCol + " doesn't exist in the dataset columns."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + } + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h index fac40a79552..21fc55e2635 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h @@ -135,6 +135,8 @@ class BucketBatchByLengthOp : public PipelineOp { Status PadAndBatchBucket(int32_t bucket_index, int32_t batch_size); + Status ComputeColMap() override; + std::vector length_dependent_columns_; std::vector bucket_boundaries_; std::vector bucket_batch_sizes_;