checkou colums names for bucket_batch_by_length

This commit is contained in:
yanghaitao 2020-09-11 10:35:06 +08:00
parent c21baba879
commit 92eb17d982
2 changed files with 16 additions and 0 deletions

View File

@ -224,5 +224,19 @@ Status BucketBatchByLengthOp::Reset() {
return Status::OK();
}
// Computing the assignment of the column name map and check compute input columns.
Status BucketBatchByLengthOp::ComputeColMap() {
RETURN_IF_NOT_OK(DatasetOp::ComputeColMap());
for (const auto &inCol : length_dependent_columns_) {
bool found = column_name_id_map_.find(inCol) != column_name_id_map_.end() ? true : false;
if (!found) {
std::string err_msg = "input column name: " + inCol + " doesn't exist in the dataset columns.";
RETURN_STATUS_UNEXPECTED(err_msg);
}
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -135,6 +135,8 @@ class BucketBatchByLengthOp : public PipelineOp {
Status PadAndBatchBucket(int32_t bucket_index, int32_t batch_size);
Status ComputeColMap() override;
std::vector<std::string> length_dependent_columns_;
std::vector<int32_t> bucket_boundaries_;
std::vector<int32_t> bucket_batch_sizes_;