dataset Python Pushdown - more minor code rework; CSVDataset field_delim bug fix

This commit is contained in:
Cathy Wong 2020-12-03 17:26:50 -05:00
parent 73c91e05b1
commit 69b269c6cf
5 changed files with 58 additions and 31 deletions

View File

@@ -191,8 +191,6 @@ Status TreeAdapter::GetNext(TensorRow *row) {
  Status s = tree_->GetProfilingManager()->GetTracingNode(kDatasetIteratorTracingName, &node);
  if (s.IsOk()) {
    tracing_ = std::dynamic_pointer_cast<DatasetIteratorTracing>(node);
  }
  if (tracing_ != nullptr) {
    cur_connector_size_ = tree_->root()->ConnectorSize();
    cur_connector_capacity_ = tree_->root()->ConnectorCapacity();
  }

View File

@@ -382,30 +382,30 @@ class SchemaObj {
Status Init();
/// \brief Add new column to the schema with unknown shape of rank 1
/// \param[in] name name of the column.
/// \param[in] de_type data type of the column(TypeId).
/// \return bool true if schema init success
/// \param[in] name Name of the column.
/// \param[in] de_type Data type of the column(TypeId).
/// \return Status code
Status add_column(const std::string &name, TypeId de_type);
/// \brief Add new column to the schema with unknown shape of rank 1
/// \param[in] name name of the column.
/// \param[in] de_type data type of the column(std::string).
/// \param[in] shape shape of the column.
/// \return bool true if schema init success
/// \param[in] name Name of the column.
/// \param[in] de_type Data type of the column(std::string).
/// \param[in] shape Shape of the column.
/// \return Status code
Status add_column(const std::string &name, const std::string &de_type);
/// \brief Add new column to the schema
/// \param[in] name name of the column.
/// \param[in] de_type data type of the column(TypeId).
/// \param[in] shape shape of the column.
/// \return bool true if schema init success
/// \param[in] name Name of the column.
/// \param[in] de_type Data type of the column(TypeId).
/// \param[in] shape Shape of the column.
/// \return Status code
Status add_column(const std::string &name, TypeId de_type, const std::vector<int32_t> &shape);
/// \brief Add new column to the schema
/// \param[in] name name of the column.
/// \param[in] de_type data type of the column(std::string).
/// \param[in] shape shape of the column.
/// \return bool true if schema init success
/// \param[in] name Name of the column.
/// \param[in] de_type Data type of the column(std::string).
/// \param[in] shape Shape of the column.
/// \return Status code
Status add_column(const std::string &name, const std::string &de_type, const std::vector<int32_t> &shape);
/// \brief Get a JSON string of the schema
@@ -415,29 +415,35 @@ class SchemaObj {
/// \brief Get a JSON string of the schema
std::string to_string() { return to_json(); }
/// \brief set a new value to dataset_type
/// \brief Set a new value to dataset_type
inline void set_dataset_type(std::string dataset_type) { dataset_type_ = std::move(dataset_type); }
/// \brief set a new value to num_rows
/// \brief Set a new value to num_rows
inline void set_num_rows(int32_t num_rows) { num_rows_ = num_rows; }
/// \brief get the current num_rows
/// \brief Get the current num_rows
inline int32_t get_num_rows() const { return num_rows_; }
/// \brief Get the schema from a JSON string
/// \param[in] json_string JSON string of the schema to be parsed.
/// \return Status code
Status FromJSONString(const std::string &json_string);
/// \brief Parse and add column information
/// \param[in] json_string JSON string of column dataset attribute information, decoded from the schema file.
/// \return Status code
Status ParseColumnString(const std::string &json_string);
private:
/// \brief Parse the columns and add it to columns
/// \param[in] columns dataset attribution information, decoded from schema file.
/// support both nlohmann::json::value_t::array and nlohmann::json::value_t::onject.
/// \return JSON string of the schema
/// \brief Parse the columns and add them to columns
/// \param[in] columns Dataset attribute information, decoded from the schema file.
/// Supports both nlohmann::json::value_t::array and nlohmann::json::value_t::object.
/// \return Status code
Status parse_column(nlohmann::json columns);
/// \brief Get schema file from json file
/// \param[in] json_obj object of json parsed.
/// \return bool true if json dump success
/// \brief Get the schema from a parsed JSON object
/// \param[in] json_obj Parsed JSON object.
/// \return Status code
Status from_json(nlohmann::json json_obj);
int32_t num_rows_;

View File

@@ -5357,7 +5357,7 @@ class CSVDataset(SourceDataset):
        super().__init__(num_parallel_workers=num_parallel_workers)
        self.dataset_files = self._find_files(dataset_files)
        self.dataset_files.sort()
        self.field_delim = replace_none(field_delim, '')
        self.field_delim = replace_none(field_delim, ',')
        self.column_defaults = replace_none(column_defaults, [])
        self.column_names = replace_none(column_names, [])
        self.num_samples = num_samples
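The fix changes the fallback that replace_none applies when field_delim is left unset: None now resolves to a comma instead of an empty string. A minimal sketch of the intended behavior, assuming replace_none simply substitutes the default when the value is None (the standalone helper below is illustrative only, not the library's actual implementation):

def replace_none(value, default):
    # fall back to the supplied default only when the caller passed None
    return value if value is not None else default

print(replace_none(None, ','))   # ',' after this fix; previously resolved to ''
print(replace_none('\t', ','))   # an explicit delimiter is kept as-is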

View File

@@ -924,9 +924,10 @@ def check_csvdataset(method):
        # check field_delim
        field_delim = param_dict.get('field_delim')
        type_check(field_delim, (str,), 'field delim')
        if field_delim in ['"', '\r', '\n'] or len(field_delim) > 1:
            raise ValueError("field_delim is invalid.")
        if field_delim is not None:
            type_check(field_delim, (str,), 'field delim')
            if field_delim in ['"', '\r', '\n'] or len(field_delim) > 1:
                raise ValueError("field_delim is invalid.")
        # check column_defaults
        column_defaults = param_dict.get('column_defaults')
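With the relaxed check, field_delim=None is now accepted by the validator and falls back to the comma default, while an explicitly forbidden delimiter still raises ValueError. A hedged usage sketch (the CSV path and column layout are borrowed from the test data referenced below):

import mindspore.dataset as ds

# field_delim=None: the type/value checks are skipped and the delimiter defaults to ','
data = ds.CSVDataset('../data/dataset/testCSV/1.csv',
                     field_delim=None,
                     column_defaults=["0", 0, 0.0, "0"],
                     column_names=['1', '2', '3', '4'],
                     shuffle=False)

# a forbidden delimiter such as '"' is still rejected at construction time
try:
    ds.CSVDataset('../data/dataset/testCSV/1.csv', field_delim='"')
except ValueError as err:
    print(err)  # field_delim is invalid.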

View File

@@ -28,6 +28,7 @@ def test_csv_dataset_basic():
    buffer = []
    data = ds.CSVDataset(
        TRAIN_FILE,
        field_delim=',',
        column_defaults=["0", 0, 0.0, "0"],
        column_names=['1', '2', '3', '4'],
        shuffle=False)
@@ -185,6 +186,26 @@ def test_csv_dataset_number():
    assert np.allclose(buffer, [3.0, 0.3, 4, 55.5])


def test_csv_dataset_field_delim_none():
    """
    Test CSV with field_delim=None
    """
    TRAIN_FILE = '../data/dataset/testCSV/1.csv'
    buffer = []
    data = ds.CSVDataset(
        TRAIN_FILE,
        field_delim=None,
        column_defaults=["0", 0, 0.0, "0"],
        column_names=['1', '2', '3', '4'],
        shuffle=False)
    data = data.repeat(2)
    data = data.skip(2)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append(d)
    assert len(buffer) == 4

def test_csv_dataset_size():
    TEST_FILE = '../data/dataset/testCSV/size.csv'
    data = ds.CSVDataset(
@@ -245,6 +266,7 @@ if __name__ == "__main__":
    test_csv_dataset_chinese()
    test_csv_dataset_header()
    test_csv_dataset_number()
    test_csv_dataset_field_delim_none()
    test_csv_dataset_size()
    test_csv_dataset_type_error()
    test_csv_dataset_exception()