forked from mindspore-Ecosystem/mindspore
fix number of columns not match
This commit is contained in:
parent
623d884e0c
commit
2cc6b5cb52
|
@ -100,6 +100,10 @@ Status CsvOp::Init() {
|
||||||
int CsvOp::CsvParser::put_record(char c) {
|
int CsvOp::CsvParser::put_record(char c) {
|
||||||
std::string s = std::string(str_buf_.begin(), str_buf_.begin() + pos_);
|
std::string s = std::string(str_buf_.begin(), str_buf_.begin() + pos_);
|
||||||
std::shared_ptr<Tensor> t;
|
std::shared_ptr<Tensor> t;
|
||||||
|
if (cur_col_ >= column_default_.size()) {
|
||||||
|
err_message_ = "Number of file columns does not match the default records";
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
switch (column_default_[cur_col_]->type) {
|
switch (column_default_[cur_col_]->type) {
|
||||||
case CsvOp::INT:
|
case CsvOp::INT:
|
||||||
Tensor::CreateTensor(&t, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_INT32));
|
Tensor::CreateTensor(&t, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_INT32));
|
||||||
|
@ -116,6 +120,10 @@ int CsvOp::CsvParser::put_record(char c) {
|
||||||
Tensor::CreateTensor(&t, {s}, TensorShape::CreateScalar());
|
Tensor::CreateTensor(&t, {s}, TensorShape::CreateScalar());
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
if (cur_col_ >= (*tensor_table_)[cur_row_].size()) {
|
||||||
|
err_message_ = "Number of file columns does not match the tensor table";
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
(*tensor_table_)[cur_row_][cur_col_] = std::move(t);
|
(*tensor_table_)[cur_row_][cur_col_] = std::move(t);
|
||||||
pos_ = 0;
|
pos_ = 0;
|
||||||
cur_col_++;
|
cur_col_++;
|
||||||
|
@ -134,7 +142,11 @@ int CsvOp::CsvParser::put_row(char c) {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
put_record(c);
|
int ret = put_record(c);
|
||||||
|
if (ret < 0) {
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
total_rows_++;
|
total_rows_++;
|
||||||
cur_row_++;
|
cur_row_++;
|
||||||
cur_col_ = 0;
|
cur_col_ = 0;
|
||||||
|
@ -265,8 +277,7 @@ Status CsvOp::CsvParser::initCsvParser() {
|
||||||
[this](CsvParser &, char c) -> int {
|
[this](CsvParser &, char c) -> int {
|
||||||
this->tensor_table_ = std::make_unique<TensorQTable>();
|
this->tensor_table_ = std::make_unique<TensorQTable>();
|
||||||
this->tensor_table_->push_back(TensorRow(column_default_.size(), nullptr));
|
this->tensor_table_->push_back(TensorRow(column_default_.size(), nullptr));
|
||||||
this->put_record(c);
|
return this->put_record(c);
|
||||||
return 0;
|
|
||||||
}}},
|
}}},
|
||||||
{{State::START_OF_FILE, Message::MS_QUOTE},
|
{{State::START_OF_FILE, Message::MS_QUOTE},
|
||||||
{State::QUOTE,
|
{State::QUOTE,
|
||||||
|
@ -367,8 +378,7 @@ Status CsvOp::CsvParser::initCsvParser() {
|
||||||
if (this->total_rows_ > this->start_offset_ && this->total_rows_ <= this->end_offset_) {
|
if (this->total_rows_ > this->start_offset_ && this->total_rows_ <= this->end_offset_) {
|
||||||
this->tensor_table_->push_back(TensorRow(column_default_.size(), nullptr));
|
this->tensor_table_->push_back(TensorRow(column_default_.size(), nullptr));
|
||||||
}
|
}
|
||||||
this->put_record(c);
|
return this->put_record(c);
|
||||||
return 0;
|
|
||||||
}}},
|
}}},
|
||||||
{{State::END_OF_LINE, Message::MS_QUOTE},
|
{{State::END_OF_LINE, Message::MS_QUOTE},
|
||||||
{State::QUOTE,
|
{State::QUOTE,
|
||||||
|
@ -408,15 +418,16 @@ Status CsvOp::LoadFile(const std::string &file, const int64_t start_offset, cons
|
||||||
while (ifs.good()) {
|
while (ifs.good()) {
|
||||||
char chr = ifs.get();
|
char chr = ifs.get();
|
||||||
if (csv_parser.processMessage(chr) != 0) {
|
if (csv_parser.processMessage(chr) != 0) {
|
||||||
RETURN_STATUS_UNEXPECTED("Failed to parse CSV file " + file + ":" + std::to_string(csv_parser.total_rows_));
|
RETURN_STATUS_UNEXPECTED("Failed to parse file " + file + ":" + std::to_string(csv_parser.total_rows_ + 1) +
|
||||||
|
". error message: " + csv_parser.err_message_);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} catch (std::invalid_argument &ia) {
|
} catch (std::invalid_argument &ia) {
|
||||||
std::string err_row = std::to_string(csv_parser.total_rows_);
|
std::string err_row = std::to_string(csv_parser.total_rows_ + 1);
|
||||||
RETURN_STATUS_UNEXPECTED(file + ":" + err_row + ", invalid argument of " + std::string(ia.what()));
|
RETURN_STATUS_UNEXPECTED(file + ":" + err_row + ", type does not match");
|
||||||
} catch (std::out_of_range &oor) {
|
} catch (std::out_of_range &oor) {
|
||||||
std::string err_row = std::to_string(csv_parser.total_rows_);
|
std::string err_row = std::to_string(csv_parser.total_rows_ + 1);
|
||||||
RETURN_STATUS_UNEXPECTED(file + ":" + err_row + ", out of Range error: " + std::string(oor.what()));
|
RETURN_STATUS_UNEXPECTED(file + ":" + err_row + ", out of range");
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -763,6 +774,9 @@ Status CsvOp::ComputeColMap() {
|
||||||
column_default_list_.push_back(std::make_shared<CsvOp::Record<std::string>>(CsvOp::STRING, ""));
|
column_default_list_.push_back(std::make_shared<CsvOp::Record<std::string>>(CsvOp::STRING, ""));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (column_default_list_.size() != column_name_id_map_.size()) {
|
||||||
|
RETURN_STATUS_UNEXPECTED("The number of column names does not match the column defaults");
|
||||||
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
|
|
|
@ -76,7 +76,8 @@ class CsvOp : public ParallelOp {
|
||||||
cur_col_(0),
|
cur_col_(0),
|
||||||
total_rows_(0),
|
total_rows_(0),
|
||||||
start_offset_(0),
|
start_offset_(0),
|
||||||
end_offset_(std::numeric_limits<int64_t>::max()) {
|
end_offset_(std::numeric_limits<int64_t>::max()),
|
||||||
|
err_message_("unkonw") {
|
||||||
cur_buffer_ = std::make_unique<DataBuffer>(0, DataBuffer::BufferFlags::kDeBFlagNone);
|
cur_buffer_ = std::make_unique<DataBuffer>(0, DataBuffer::BufferFlags::kDeBFlagNone);
|
||||||
initCsvParser();
|
initCsvParser();
|
||||||
}
|
}
|
||||||
|
@ -189,6 +190,7 @@ class CsvOp : public ParallelOp {
|
||||||
std::vector<char> str_buf_;
|
std::vector<char> str_buf_;
|
||||||
std::unique_ptr<TensorQTable> tensor_table_;
|
std::unique_ptr<TensorQTable> tensor_table_;
|
||||||
std::unique_ptr<DataBuffer> cur_buffer_;
|
std::unique_ptr<DataBuffer> cur_buffer_;
|
||||||
|
std::string err_message_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class Builder {
|
class Builder {
|
||||||
|
|
|
@ -205,7 +205,7 @@ def test_csv_dataset_exception():
|
||||||
with pytest.raises(Exception) as err:
|
with pytest.raises(Exception) as err:
|
||||||
for _ in data.create_dict_iterator():
|
for _ in data.create_dict_iterator():
|
||||||
pass
|
pass
|
||||||
assert "Failed to parse CSV file" in str(err.value)
|
assert "Failed to parse file" in str(err.value)
|
||||||
|
|
||||||
|
|
||||||
def test_csv_dataset_type_error():
|
def test_csv_dataset_type_error():
|
||||||
|
@ -218,7 +218,7 @@ def test_csv_dataset_type_error():
|
||||||
with pytest.raises(Exception) as err:
|
with pytest.raises(Exception) as err:
|
||||||
for _ in data.create_dict_iterator():
|
for _ in data.create_dict_iterator():
|
||||||
pass
|
pass
|
||||||
assert "invalid argument of stoi" in str(err.value)
|
assert "type does not match" in str(err.value)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
Loading…
Reference in New Issue