From 69b269c6cfaed92710d6c8cf972744c250d01aa7 Mon Sep 17 00:00:00 2001
From: Cathy Wong
Date: Thu, 3 Dec 2020 17:26:50 -0500
Subject: [PATCH] dataset Python Pushdown - more minor code rework; CSVDataset
 field_delim bug fix

---
 .../minddata/dataset/engine/tree_adapter.cc   |  2 -
 .../ccsrc/minddata/dataset/include/datasets.h | 55 ++++++++++---------
 mindspore/dataset/engine/datasets.py          |  2 +-
 mindspore/dataset/engine/validators.py        |  7 ++-
 tests/ut/python/dataset/test_datasets_csv.py  | 22 ++++++++
 5 files changed, 57 insertions(+), 31 deletions(-)

diff --git a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc
index f1713134970..45464f88d0c 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc
@@ -191,8 +191,6 @@ Status TreeAdapter::GetNext(TensorRow *row) {
     Status s = tree_->GetProfilingManager()->GetTracingNode(kDatasetIteratorTracingName, &node);
     if (s.IsOk()) {
       tracing_ = std::dynamic_pointer_cast<DatasetIteratorTracing>(node);
-    }
-    if (tracing_ != nullptr) {
       cur_connector_size_ = tree_->root()->ConnectorSize();
       cur_connector_capacity_ = tree_->root()->ConnectorCapacity();
     }
diff --git a/mindspore/ccsrc/minddata/dataset/include/datasets.h b/mindspore/ccsrc/minddata/dataset/include/datasets.h
index a8b96b93938..f1d334934c5 100644
--- a/mindspore/ccsrc/minddata/dataset/include/datasets.h
+++ b/mindspore/ccsrc/minddata/dataset/include/datasets.h
@@ -382,30 +382,29 @@ class SchemaObj {
   Status Init();
 
   /// \brief Add new column to the schema with unknown shape of rank 1
-  /// \param[in] name name of the column.
-  /// \param[in] de_type data type of the column(TypeId).
-  /// \return bool true if schema init success
+  /// \param[in] name Name of the column.
+  /// \param[in] de_type Data type of the column (TypeId).
+  /// \return Status code
   Status add_column(const std::string &name, TypeId de_type);
 
   /// \brief Add new column to the schema with unknown shape of rank 1
-  /// \param[in] name name of the column.
-  /// \param[in] de_type data type of the column(std::string).
-  /// \param[in] shape shape of the column.
-  /// \return bool true if schema init success
+  /// \param[in] name Name of the column.
+  /// \param[in] de_type Data type of the column (std::string).
+  /// \return Status code
   Status add_column(const std::string &name, const std::string &de_type);
 
   /// \brief Add new column to the schema
-  /// \param[in] name name of the column.
-  /// \param[in] de_type data type of the column(TypeId).
-  /// \param[in] shape shape of the column.
-  /// \return bool true if schema init success
+  /// \param[in] name Name of the column.
+  /// \param[in] de_type Data type of the column (TypeId).
+  /// \param[in] shape Shape of the column.
+  /// \return Status code
   Status add_column(const std::string &name, TypeId de_type, const std::vector<int32_t> &shape);
 
   /// \brief Add new column to the schema
-  /// \param[in] name name of the column.
-  /// \param[in] de_type data type of the column(std::string).
-  /// \param[in] shape shape of the column.
-  /// \return bool true if schema init success
+  /// \param[in] name Name of the column.
+  /// \param[in] de_type Data type of the column (std::string).
+  /// \param[in] shape Shape of the column.
+  /// \return Status code
   Status add_column(const std::string &name, const std::string &de_type, const std::vector<int32_t> &shape);
 
   /// \brief Get a JSON string of the schema
@@ -415,29 +414,35 @@ class SchemaObj {
   /// \brief Get a JSON string of the schema
   std::string to_string() { return to_json(); }
 
-  /// \brief set a new value to dataset_type
+  /// \brief Set a new value to dataset_type
   inline void set_dataset_type(std::string dataset_type) { dataset_type_ = std::move(dataset_type); }
 
-  /// \brief set a new value to num_rows
+  /// \brief Set a new value to num_rows
   inline void set_num_rows(int32_t num_rows) { num_rows_ = num_rows; }
 
-  /// \brief get the current num_rows
+  /// \brief Get the current num_rows
   inline int32_t get_num_rows() const { return num_rows_; }
 
+  /// \brief Parse the schema from a JSON string
+  /// \param[in] json_string JSON string of the schema to be parsed.
+  /// \return Status code
   Status FromJSONString(const std::string &json_string);
 
+  /// \brief Parse and add column information
+  /// \param[in] json_string JSON string of column attribute information, decoded from the schema file.
+  /// \return Status code
   Status ParseColumnString(const std::string &json_string);
 
  private:
-  /// \brief Parse the columns and add it to columns
-  /// \param[in] columns dataset attribution information, decoded from schema file.
-  ///     support both nlohmann::json::value_t::array and nlohmann::json::value_t::onject.
-  /// \return JSON string of the schema
+  /// \brief Parse the columns and add them to columns
+  /// \param[in] columns Dataset attribute information, decoded from the schema file.
+  ///     Supports both nlohmann::json::value_t::array and nlohmann::json::value_t::object.
+  /// \return Status code
   Status parse_column(nlohmann::json columns);
 
-  /// \brief Get schema file from json file
-  /// \param[in] json_obj object of json parsed.
-  /// \return bool true if json dump success
+  /// \brief Get schema from a parsed JSON object
+  /// \param[in] json_obj Parsed JSON object of the schema file.
+  /// \return Status code
   Status from_json(nlohmann::json json_obj);
 
   int32_t num_rows_;
diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py
index b5725cf709d..28ba11ce327 100644
--- a/mindspore/dataset/engine/datasets.py
+++ b/mindspore/dataset/engine/datasets.py
@@ -5357,7 +5357,7 @@ class CSVDataset(SourceDataset):
         super().__init__(num_parallel_workers=num_parallel_workers)
         self.dataset_files = self._find_files(dataset_files)
         self.dataset_files.sort()
-        self.field_delim = replace_none(field_delim, '')
+        self.field_delim = replace_none(field_delim, ',')
         self.column_defaults = replace_none(column_defaults, [])
         self.column_names = replace_none(column_names, [])
         self.num_samples = num_samples
diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py
index cf1baccacb9..34b4f2f89d9 100644
--- a/mindspore/dataset/engine/validators.py
+++ b/mindspore/dataset/engine/validators.py
@@ -924,9 +924,10 @@ def check_csvdataset(method):
 
         # check field_delim
         field_delim = param_dict.get('field_delim')
-        type_check(field_delim, (str,), 'field delim')
-        if field_delim in ['"', '\r', '\n'] or len(field_delim) > 1:
-            raise ValueError("field_delim is invalid.")
+        if field_delim is not None:
+            type_check(field_delim, (str,), 'field delim')
+            if field_delim in ['"', '\r', '\n'] or len(field_delim) > 1:
+                raise ValueError("field_delim is invalid.")
 
         # check column_defaults
         column_defaults = param_dict.get('column_defaults')
diff --git a/tests/ut/python/dataset/test_datasets_csv.py b/tests/ut/python/dataset/test_datasets_csv.py
index cc1a0bbec41..c19f9a52385 100644
--- a/tests/ut/python/dataset/test_datasets_csv.py
+++ b/tests/ut/python/dataset/test_datasets_csv.py
@@ -28,6 +28,7 @@ def test_csv_dataset_basic():
     buffer = []
     data = ds.CSVDataset(
         TRAIN_FILE,
+        field_delim=',',
         column_defaults=["0", 0, 0.0, "0"],
         column_names=['1', '2', '3', '4'],
         shuffle=False)
@@ -185,6 +186,26 @@ def test_csv_dataset_number():
     assert np.allclose(buffer, [3.0, 0.3, 4, 55.5])
 
 
+def test_csv_dataset_field_delim_none():
+    """
+    Test CSV with field_delim=None
+    """
+    TRAIN_FILE = '../data/dataset/testCSV/1.csv'
+
+    buffer = []
+    data = ds.CSVDataset(
+        TRAIN_FILE,
+        field_delim=None,
+        column_defaults=["0", 0, 0.0, "0"],
+        column_names=['1', '2', '3', '4'],
+        shuffle=False)
+    data = data.repeat(2)
+    data = data.skip(2)
+    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
+        buffer.append(d)
+    assert len(buffer) == 4
+
+
 def test_csv_dataset_size():
     TEST_FILE = '../data/dataset/testCSV/size.csv'
     data = ds.CSVDataset(
@@ -245,6 +266,7 @@ if __name__ == "__main__":
     test_csv_dataset_chinese()
     test_csv_dataset_header()
     test_csv_dataset_number()
+    test_csv_dataset_field_delim_none()
     test_csv_dataset_size()
     test_csv_dataset_type_error()
     test_csv_dataset_exception()
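
Usage note: with this patch applied, passing field_delim=None to CSVDataset (or
omitting the argument) falls back to the documented ',' default instead of the
empty string, and check_csvdataset only type-checks field_delim when a value is
actually supplied. A minimal sketch of the fixed behavior, mirroring
test_csv_dataset_field_delim_none above (the data path and column names are
illustrative, taken from the unit test, not a general requirement):

    import mindspore.dataset as ds

    # field_delim=None is now replaced with the ',' default by
    # replace_none(field_delim, ','); previously it became '' and
    # the validator rejected None outright with a type error.
    data = ds.CSVDataset(
        '../data/dataset/testCSV/1.csv',  # illustrative test file
        field_delim=None,
        column_defaults=["0", 0, 0.0, "0"],
        column_names=['1', '2', '3', '4'],
        shuffle=False)

    # Rows parse exactly as if field_delim=',' had been passed.
    for row in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        print(row)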