forked from mindspore-Ecosystem/mindspore

dataset Python Pushdown - more minor code rework; CSVDataset delim_field bug fix

commit 69b269c6cf
parent 73c91e05b1
@@ -191,8 +191,6 @@ Status TreeAdapter::GetNext(TensorRow *row) {
   Status s = tree_->GetProfilingManager()->GetTracingNode(kDatasetIteratorTracingName, &node);
   if (s.IsOk()) {
     tracing_ = std::dynamic_pointer_cast<DatasetIteratorTracing>(node);
-  }
-  if (tracing_ != nullptr) {
     cur_connector_size_ = tree_->root()->ConnectorSize();
     cur_connector_capacity_ = tree_->root()->ConnectorCapacity();
   }
@@ -382,30 +382,30 @@ class SchemaObj {
   Status Init();
 
   /// \brief Add new column to the schema with unknown shape of rank 1
-  /// \param[in] name name of the column.
-  /// \param[in] de_type data type of the column(TypeId).
-  /// \return bool true if schema init success
+  /// \param[in] name Name of the column.
+  /// \param[in] de_type Data type of the column(TypeId).
+  /// \return Status code
   Status add_column(const std::string &name, TypeId de_type);
 
   /// \brief Add new column to the schema with unknown shape of rank 1
-  /// \param[in] name name of the column.
-  /// \param[in] de_type data type of the column(std::string).
-  /// \param[in] shape shape of the column.
-  /// \return bool true if schema init success
+  /// \param[in] name Name of the column.
+  /// \param[in] de_type Data type of the column(std::string).
+  /// \param[in] shape Shape of the column.
+  /// \return Status code
   Status add_column(const std::string &name, const std::string &de_type);
 
   /// \brief Add new column to the schema
-  /// \param[in] name name of the column.
-  /// \param[in] de_type data type of the column(TypeId).
-  /// \param[in] shape shape of the column.
-  /// \return bool true if schema init success
+  /// \param[in] name Name of the column.
+  /// \param[in] de_type Data type of the column(TypeId).
+  /// \param[in] shape Shape of the column.
+  /// \return Status code
   Status add_column(const std::string &name, TypeId de_type, const std::vector<int32_t> &shape);
 
   /// \brief Add new column to the schema
-  /// \param[in] name name of the column.
-  /// \param[in] de_type data type of the column(std::string).
-  /// \param[in] shape shape of the column.
-  /// \return bool true if schema init success
+  /// \param[in] name Name of the column.
+  /// \param[in] de_type Data type of the column(std::string).
+  /// \param[in] shape Shape of the column.
+  /// \return Status code
   Status add_column(const std::string &name, const std::string &de_type, const std::vector<int32_t> &shape);
 
   /// \brief Get a JSON string of the schema
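These add_column overloads back the Python-side mindspore.dataset.Schema wrapper. As a rough illustration of how they surface in Python, here is a minimal sketch assuming the released Schema API with add_column(name, de_type, shape=None) and to_json(); the column names and shapes are made up for illustration:

    import mindspore.dataset as ds

    schema = ds.Schema()
    # No shape given: a rank-1 column with unknown shape, matching the C++ doc.
    schema.add_column("label", de_type="int32")
    # Explicit shape: -1 marks an unknown dimension.
    schema.add_column("image", de_type="uint8", shape=[-1])
    print(schema.to_json())  # JSON string of the schema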
@@ -415,29 +415,35 @@ class SchemaObj {
   /// \brief Get a JSON string of the schema
   std::string to_string() { return to_json(); }
 
-  /// \brief set a new value to dataset_type
+  /// \brief Set a new value to dataset_type
   inline void set_dataset_type(std::string dataset_type) { dataset_type_ = std::move(dataset_type); }
 
-  /// \brief set a new value to num_rows
+  /// \brief Set a new value to num_rows
   inline void set_num_rows(int32_t num_rows) { num_rows_ = num_rows; }
 
-  /// \brief get the current num_rows
+  /// \brief Get the current num_rows
   inline int32_t get_num_rows() const { return num_rows_; }
 
+  /// \brief Get schema file from JSON file
+  /// \param[in] json_string Name of JSON file to be parsed.
+  /// \return Status code
+  Status FromJSONString(const std::string &json_string);
+
+  /// \brief Parse and add column information
+  /// \param[in] json_string Name of JSON string for column dataset attribute information, decoded from schema file.
+  /// \return Status code
+  Status ParseColumnString(const std::string &json_string);
+
  private:
-  /// \brief Parse the columns and add it to columns
-  /// \param[in] columns dataset attribution information, decoded from schema file.
-  /// support both nlohmann::json::value_t::array and nlohmann::json::value_t::onject.
-  /// \return JSON string of the schema
+  /// \brief Parse the columns and add them to columns
+  /// \param[in] columns Dataset attribution information, decoded from schema file.
+  /// Support both nlohmann::json::value_t::array and nlohmann::json::value_t::object.
+  /// \return Status code
   Status parse_column(nlohmann::json columns);
 
-  /// \brief Get schema file from json file
-  /// \param[in] json_obj object of json parsed.
-  /// \return bool true if json dump success
+  /// \brief Get schema file from JSON file
+  /// \param[in] json_obj Object of JSON parsed.
+  /// \return Status code
   Status from_json(nlohmann::json json_obj);
 
   int32_t num_rows_;
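The new FromJSONString and ParseColumnString entry points let the Python layer push a schema down as one JSON string instead of rebuilding it column by column across the C++ boundary. A sketch of the kind of payload this implies follows; the key names mirror the usual dataset schema file format, and the exact fields are an assumption, not taken from this commit:

    import json

    # "columns" may be a JSON array or a JSON object, per parse_column's doc.
    schema_json = json.dumps({
        "datasetType": "TF",
        "numRows": 3,
        "columns": {
            "image": {"type": "uint8", "rank": 1},
            "label": {"type": "int32", "rank": 1},
        },
    })
    # Conceptually: the Python layer hands schema_json to SchemaObj::FromJSONString,
    # which parses it and feeds the column section to ParseColumnString.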
@@ -5357,7 +5357,7 @@ class CSVDataset(SourceDataset):
         super().__init__(num_parallel_workers=num_parallel_workers)
         self.dataset_files = self._find_files(dataset_files)
         self.dataset_files.sort()
-        self.field_delim = replace_none(field_delim, '')
+        self.field_delim = replace_none(field_delim, ',')
         self.column_defaults = replace_none(column_defaults, [])
         self.column_names = replace_none(column_names, [])
         self.num_samples = num_samples
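This one-line change is the delim_field bug fix named in the commit title: with field_delim=None, the old code pushed an empty delimiter down to the CSV op, whereas ',' is the documented default. replace_none lives in the dataset validator helpers; the body below is a sketch of its semantics, assumed from its name and its use here:

    def replace_none(value, default):
        # Keep the caller's value; fall back to the default only on None.
        return value if value is not None else default

    # Before the fix: replace_none(None, '')  -> ''   (invalid delimiter)
    # After the fix:  replace_none(None, ',') -> ','  (documented default)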
@@ -924,9 +924,10 @@ def check_csvdataset(method):
 
         # check field_delim
         field_delim = param_dict.get('field_delim')
-        type_check(field_delim, (str,), 'field delim')
-        if field_delim in ['"', '\r', '\n'] or len(field_delim) > 1:
-            raise ValueError("field_delim is invalid.")
+        if field_delim is not None:
+            type_check(field_delim, (str,), 'field delim')
+            if field_delim in ['"', '\r', '\n'] or len(field_delim) > 1:
+                raise ValueError("field_delim is invalid.")
 
         # check column_defaults
         column_defaults = param_dict.get('column_defaults')
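The validator now matches the new default: field_delim=None is legal and simply means "use ','", so the type and value checks run only when a delimiter was actually supplied. A standalone sketch of the guarded check, with MindSpore's type_check helper replaced by a plain isinstance test for illustration:

    def check_field_delim(field_delim):
        # None is allowed: CSVDataset falls back to the default ','.
        if field_delim is not None:
            if not isinstance(field_delim, str):
                raise TypeError("field_delim must be a str.")
            # Must be a single character and must not collide with CSV
            # quoting or line endings.
            if field_delim in ['"', '\r', '\n'] or len(field_delim) > 1:
                raise ValueError("field_delim is invalid.")

    check_field_delim(None)   # ok, the default ',' applies downstream
    check_field_delim('\t')   # ok, tab-separated values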
@@ -28,6 +28,7 @@ def test_csv_dataset_basic():
     buffer = []
     data = ds.CSVDataset(
         TRAIN_FILE,
+        field_delim=',',
         column_defaults=["0", 0, 0.0, "0"],
         column_names=['1', '2', '3', '4'],
         shuffle=False)
@@ -185,6 +186,26 @@ def test_csv_dataset_number():
     assert np.allclose(buffer, [3.0, 0.3, 4, 55.5])
 
 
+def test_csv_dataset_field_delim_none():
+    """
+    Test CSV with field_delim=None
+    """
+    TRAIN_FILE = '../data/dataset/testCSV/1.csv'
+
+    buffer = []
+    data = ds.CSVDataset(
+        TRAIN_FILE,
+        field_delim=None,
+        column_defaults=["0", 0, 0.0, "0"],
+        column_names=['1', '2', '3', '4'],
+        shuffle=False)
+    data = data.repeat(2)
+    data = data.skip(2)
+    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
+        buffer.append(d)
+    assert len(buffer) == 4
+
+
 def test_csv_dataset_size():
     TEST_FILE = '../data/dataset/testCSV/size.csv'
     data = ds.CSVDataset(
@@ -245,6 +266,7 @@ if __name__ == "__main__":
     test_csv_dataset_chinese()
     test_csv_dataset_header()
     test_csv_dataset_number()
+    test_csv_dataset_field_delim_none()
     test_csv_dataset_size()
     test_csv_dataset_type_error()
     test_csv_dataset_exception()