diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc index 0871b3f30cd..04b1264cc85 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc @@ -100,6 +100,10 @@ Status CsvOp::Init() { int CsvOp::CsvParser::put_record(char c) { std::string s = std::string(str_buf_.begin(), str_buf_.begin() + pos_); std::shared_ptr t; + if (cur_col_ >= column_default_.size()) { + err_message_ = "Number of file columns does not match the default records"; + return -1; + } switch (column_default_[cur_col_]->type) { case CsvOp::INT: Tensor::CreateTensor(&t, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_INT32)); @@ -116,6 +120,10 @@ int CsvOp::CsvParser::put_record(char c) { Tensor::CreateTensor(&t, {s}, TensorShape::CreateScalar()); break; } + if (cur_col_ >= (*tensor_table_)[cur_row_].size()) { + err_message_ = "Number of file columns does not match the tensor table"; + return -1; + } (*tensor_table_)[cur_row_][cur_col_] = std::move(t); pos_ = 0; cur_col_++; @@ -134,7 +142,11 @@ int CsvOp::CsvParser::put_row(char c) { return 0; } - put_record(c); + int ret = put_record(c); + if (ret < 0) { + return ret; + } + total_rows_++; cur_row_++; cur_col_ = 0; @@ -265,8 +277,7 @@ Status CsvOp::CsvParser::initCsvParser() { [this](CsvParser &, char c) -> int { this->tensor_table_ = std::make_unique(); this->tensor_table_->push_back(TensorRow(column_default_.size(), nullptr)); - this->put_record(c); - return 0; + return this->put_record(c); }}}, {{State::START_OF_FILE, Message::MS_QUOTE}, {State::QUOTE, @@ -367,8 +378,7 @@ Status CsvOp::CsvParser::initCsvParser() { if (this->total_rows_ > this->start_offset_ && this->total_rows_ <= this->end_offset_) { this->tensor_table_->push_back(TensorRow(column_default_.size(), nullptr)); } - this->put_record(c); - return 0; + return this->put_record(c); }}}, {{State::END_OF_LINE, Message::MS_QUOTE}, {State::QUOTE, @@ -408,15 +418,16 @@ Status CsvOp::LoadFile(const std::string &file, const int64_t start_offset, cons while (ifs.good()) { char chr = ifs.get(); if (csv_parser.processMessage(chr) != 0) { - RETURN_STATUS_UNEXPECTED("Failed to parse CSV file " + file + ":" + std::to_string(csv_parser.total_rows_)); + RETURN_STATUS_UNEXPECTED("Failed to parse file " + file + ":" + std::to_string(csv_parser.total_rows_ + 1) + + ". error message: " + csv_parser.err_message_); } } } catch (std::invalid_argument &ia) { - std::string err_row = std::to_string(csv_parser.total_rows_); - RETURN_STATUS_UNEXPECTED(file + ":" + err_row + ", invalid argument of " + std::string(ia.what())); + std::string err_row = std::to_string(csv_parser.total_rows_ + 1); + RETURN_STATUS_UNEXPECTED(file + ":" + err_row + ", type does not match"); } catch (std::out_of_range &oor) { - std::string err_row = std::to_string(csv_parser.total_rows_); - RETURN_STATUS_UNEXPECTED(file + ":" + err_row + ", out of Range error: " + std::string(oor.what())); + std::string err_row = std::to_string(csv_parser.total_rows_ + 1); + RETURN_STATUS_UNEXPECTED(file + ":" + err_row + ", out of range"); } return Status::OK(); } @@ -763,6 +774,9 @@ Status CsvOp::ComputeColMap() { column_default_list_.push_back(std::make_shared>(CsvOp::STRING, "")); } } + if (column_default_list_.size() != column_name_id_map_.size()) { + RETURN_STATUS_UNEXPECTED("The number of column names does not match the column defaults"); + } return Status::OK(); } } // namespace dataset diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h index a456549e756..b095542756b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h @@ -76,7 +76,8 @@ class CsvOp : public ParallelOp { cur_col_(0), total_rows_(0), start_offset_(0), - end_offset_(std::numeric_limits::max()) { + end_offset_(std::numeric_limits::max()), + err_message_("unkonw") { cur_buffer_ = std::make_unique(0, DataBuffer::BufferFlags::kDeBFlagNone); initCsvParser(); } @@ -189,6 +190,7 @@ class CsvOp : public ParallelOp { std::vector str_buf_; std::unique_ptr tensor_table_; std::unique_ptr cur_buffer_; + std::string err_message_; }; class Builder { diff --git a/tests/ut/python/dataset/test_datasets_csv.py b/tests/ut/python/dataset/test_datasets_csv.py index 021bbe942fb..f998e9774db 100644 --- a/tests/ut/python/dataset/test_datasets_csv.py +++ b/tests/ut/python/dataset/test_datasets_csv.py @@ -205,7 +205,7 @@ def test_csv_dataset_exception(): with pytest.raises(Exception) as err: for _ in data.create_dict_iterator(): pass - assert "Failed to parse CSV file" in str(err.value) + assert "Failed to parse file" in str(err.value) def test_csv_dataset_type_error(): @@ -218,7 +218,7 @@ def test_csv_dataset_type_error(): with pytest.raises(Exception) as err: for _ in data.create_dict_iterator(): pass - assert "invalid argument of stoi" in str(err.value) + assert "type does not match" in str(err.value) if __name__ == "__main__":