From 7369950a939ea4550990d6312191ea84469b4407 Mon Sep 17 00:00:00 2001 From: liyong Date: Wed, 17 Jun 2020 15:47:04 +0800 Subject: [PATCH] convert csv to mindrecord --- mindspore/mindrecord/__init__.py | 3 +- mindspore/mindrecord/tools/csv_to_mr.py | 168 ++++++++++++++++++ mindspore/mindrecord/tools/tfrecord_to_mr.py | 6 +- tests/ut/data/mindrecord/testCsv/data.csv | 7 + .../mindrecord/test_csv_to_mindrecord.py | 143 +++++++++++++++ 5 files changed, 322 insertions(+), 5 deletions(-) create mode 100644 mindspore/mindrecord/tools/csv_to_mr.py create mode 100644 tests/ut/data/mindrecord/testCsv/data.csv create mode 100644 tests/ut/python/mindrecord/test_csv_to_mindrecord.py diff --git a/mindspore/mindrecord/__init__.py b/mindspore/mindrecord/__init__.py index ba686c6c183..ee23b68cb66 100644 --- a/mindspore/mindrecord/__init__.py +++ b/mindspore/mindrecord/__init__.py @@ -29,10 +29,11 @@ from .common.exceptions import * from .shardutils import SUCCESS, FAILED from .tools.cifar10_to_mr import Cifar10ToMR from .tools.cifar100_to_mr import Cifar100ToMR +from .tools.csv_to_mr import CsvToMR from .tools.imagenet_to_mr import ImageNetToMR from .tools.mnist_to_mr import MnistToMR from .tools.tfrecord_to_mr import TFRecordToMR __all__ = ['FileWriter', 'FileReader', 'MindPage', - 'Cifar10ToMR', 'Cifar100ToMR', 'ImageNetToMR', 'MnistToMR', 'TFRecordToMR', + 'Cifar10ToMR', 'Cifar100ToMR', 'CsvToMR', 'ImageNetToMR', 'MnistToMR', 'TFRecordToMR', 'SUCCESS', 'FAILED'] diff --git a/mindspore/mindrecord/tools/csv_to_mr.py b/mindspore/mindrecord/tools/csv_to_mr.py new file mode 100644 index 00000000000..4bc8f37b476 --- /dev/null +++ b/mindspore/mindrecord/tools/csv_to_mr.py @@ -0,0 +1,168 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Csv format convert tool for MindRecord. +""" +from importlib import import_module +import os + +from mindspore import log as logger +from ..filewriter import FileWriter +from ..shardutils import check_filename + +try: + pd = import_module("pandas") +except ModuleNotFoundError: + pd = None + +__all__ = ['CsvToMR'] + +class CsvToMR: + """ + Class is for transformation from csv to MindRecord. + + Args: + source (str): the file path of csv. + destination (str): the MindRecord file path to transform into. + columns_list(list[str], optional): List of columns to be read(default=None). + partition_number (int, optional): partition size (default=1). + + Raises: + ValueError: If source, destination, partition_number is invalid. + RuntimeError: If columns_list is invalid. + """ + + def __init__(self, source, destination, columns_list=None, partition_number=1): + if not pd: + raise Exception("Module pandas is not found, please use pip install it.") + if isinstance(source, str): + check_filename(source) + self.source = source + else: + raise ValueError("The parameter source must be str.") + + self._check_columns(columns_list, "columns_list") + self.columns_list = columns_list + + if isinstance(destination, str): + check_filename(destination) + self.destination = destination + else: + raise ValueError("The parameter destination must be str.") + + if partition_number is not None: + if not isinstance(partition_number, int): + raise ValueError("The parameter partition_number must be int") + self.partition_number = partition_number + else: + raise ValueError("The parameter partition_number must be int") + + self.writer = FileWriter(self.destination, self.partition_number) + + def _check_columns(self, columns, columns_name): + if columns: + if isinstance(columns, list): + for col in columns: + if not isinstance(col, str): + raise ValueError("The parameter {} must be list of str.".format(columns_name)) + else: + raise ValueError("The parameter {} must be list of str.".format(columns_name)) + + def _get_schema(self, df): + """ + Construct schema from df columns + """ + if self.columns_list: + for col in self.columns_list: + if col not in df.columns: + raise RuntimeError("The parameter columns_list is illegal, column {} does not exist.".format(col)) + else: + self.columns_list = df.columns + + schema = {} + for col in self.columns_list: + if str(df[col].dtype) == 'int64': + schema[col] = {"type": "int64"} + elif str(df[col].dtype) == 'float64': + schema[col] = {"type": "float64"} + elif str(df[col].dtype) == 'bool': + schema[col] = {"type": "int32"} + else: + schema[col] = {"type": "string"} + if not schema: + raise RuntimeError("Failed to generate schema from csv file.") + return schema + + def _get_row_of_csv(self, df): + """Get row data from csv file.""" + for _, r in df.iterrows(): + row = {} + for col in self.columns_list: + if str(df[col].dtype) == 'bool': + row[col] = int(r[col]) + else: + row[col] = r[col] + yield row + + def transform(self): + """ + Executes transformation from csv to MindRecord. + + Returns: + SUCCESS/FAILED, whether successfully written into MindRecord. + """ + if not os.path.exists(self.source): + raise IOError("Csv file {} do not exist.".format(self.source)) + + pd.set_option('display.max_columns', None) + df = pd.read_csv(self.source) + + csv_schema = self._get_schema(df) + + logger.info("transformed MindRecord schema is: {}".format(csv_schema)) + + # set the header size + self.writer.set_header_size(1 << 24) + + # set the page size + self.writer.set_page_size(1 << 26) + + # create the schema + self.writer.add_schema(csv_schema, "csv_schema") + + # add the index + self.writer.add_index(list(self.columns_list)) + + csv_iter = self._get_row_of_csv(df) + batch_size = 256 + transform_count = 0 + while True: + data_list = [] + try: + for _ in range(batch_size): + data_list.append(csv_iter.__next__()) + transform_count += 1 + self.writer.write_raw_data(data_list) + logger.info("transformed {} record...".format(transform_count)) + except StopIteration: + if data_list: + self.writer.write_raw_data(data_list) + logger.info( + "transformed {} record...".format(transform_count)) + break + + ret = self.writer.commit() + + return ret diff --git a/mindspore/mindrecord/tools/tfrecord_to_mr.py b/mindspore/mindrecord/tools/tfrecord_to_mr.py index 8ae5aa12443..e8c52001fdd 100644 --- a/mindspore/mindrecord/tools/tfrecord_to_mr.py +++ b/mindspore/mindrecord/tools/tfrecord_to_mr.py @@ -115,10 +115,8 @@ class TFRecordToMR: "sequence": {"zzzz": tf.io.FixedLenSequenceFeature([], tf.float32)}} bytes_fields (list): the bytes fields which are in feature_dict. - Rasies: - ValueError: the following condition will cause ValueError, 1) parameter TFRecord is not string, 2) parameter - MindRecord is not string, 3) feature_dict is not FixedLenFeature, 4) parameter bytes_field is not list(str) - or not in feature_dict. + Raises: + ValueError: If parameter is invalid. Exception: when tensorflow module not found or version is not correct. """ def __init__(self, source, destination, feature_dict, bytes_fields=None): diff --git a/tests/ut/data/mindrecord/testCsv/data.csv b/tests/ut/data/mindrecord/testCsv/data.csv new file mode 100644 index 00000000000..8dad64b3c9f --- /dev/null +++ b/tests/ut/data/mindrecord/testCsv/data.csv @@ -0,0 +1,7 @@ +Age,EmployNumber,Name,Sales,Over18 +21, 10023,john, 123.45,True +41, 10223,tom, 12111,True +51, 10231,bob, 8779.0,True +86, 10053,alice, 7777,True +26, 1053,carol, 12345.8,False + diff --git a/tests/ut/python/mindrecord/test_csv_to_mindrecord.py b/tests/ut/python/mindrecord/test_csv_to_mindrecord.py new file mode 100644 index 00000000000..02c19359f22 --- /dev/null +++ b/tests/ut/python/mindrecord/test_csv_to_mindrecord.py @@ -0,0 +1,143 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test csv to mindrecord tool""" +import os +from importlib import import_module +import pytest + +from mindspore import log as logger +from mindspore.mindrecord import FileReader +from mindspore.mindrecord import CsvToMR + +try: + pd = import_module('pandas') +except ModuleNotFoundError: + pd = None + +CSV_FILE = "../data/mindrecord/testCsv/data.csv" +MINDRECORD_FILE = "../data/mindrecord/testCsv/csv.mindrecord" +PARTITION_NUMBER = 4 + +@pytest.fixture(name="remove_mindrecord_file") +def fixture_remove(): + """add/remove file""" + def remove_one_file(x): + if os.path.exists(x): + os.remove(x) + def remove_file(): + x = MINDRECORD_FILE + remove_one_file(x) + x = MINDRECORD_FILE + ".db" + remove_one_file(x) + for i in range(PARTITION_NUMBER): + x = MINDRECORD_FILE + str(i) + remove_one_file(x) + x = MINDRECORD_FILE + str(i) + ".db" + remove_one_file(x) + + remove_file() + yield "yield_fixture_data" + remove_file() + +def read(filename, columns, row_num): + """test file reade""" + if not pd: + raise Exception("Module pandas is not found, please use pip install it.") + df = pd.read_csv(CSV_FILE) + count = 0 + reader = FileReader(filename) + for _, x in enumerate(reader.get_next()): + for col in columns: + assert x[col] == df[col].iloc[count] + assert len(x) == len(columns) + count = count + 1 + if count == 1: + logger.info("data: {}".format(x)) + assert count == row_num + reader.close() + +def test_csv_to_mindrecord(remove_mindrecord_file): + """test transform csv to mindrecord.""" + csv_trans = CsvToMR(CSV_FILE, MINDRECORD_FILE, partition_number=PARTITION_NUMBER) + csv_trans.transform() + for i in range(PARTITION_NUMBER): + assert os.path.exists(MINDRECORD_FILE + str(i)) + assert os.path.exists(MINDRECORD_FILE + str(i) + ".db") + read(MINDRECORD_FILE + "0", ["Age", "EmployNumber", "Name", "Sales", "Over18"], 5) + +def test_csv_to_mindrecord_with_columns(remove_mindrecord_file): + """test transform csv to mindrecord.""" + csv_trans = CsvToMR(CSV_FILE, MINDRECORD_FILE, columns_list=['Age', 'Sales'], partition_number=PARTITION_NUMBER) + csv_trans.transform() + for i in range(PARTITION_NUMBER): + assert os.path.exists(MINDRECORD_FILE + str(i)) + assert os.path.exists(MINDRECORD_FILE + str(i) + ".db") + read(MINDRECORD_FILE + "0", ["Age", "Sales"], 5) + +def test_csv_to_mindrecord_with_no_exist_columns(remove_mindrecord_file): + """test transform csv to mindrecord.""" + with pytest.raises(Exception, match="The parameter columns_list is illegal, column ssales does not exist."): + csv_trans = CsvToMR(CSV_FILE, MINDRECORD_FILE, columns_list=['Age', 'ssales'], + partition_number=PARTITION_NUMBER) + csv_trans.transform() + +def test_csv_partition_number_with_illegal_columns(remove_mindrecord_file): + """ + test transform csv to mindrecord + """ + with pytest.raises(Exception, match="The parameter columns_list must be list of str."): + csv_trans = CsvToMR(CSV_FILE, MINDRECORD_FILE, ["Sales", 2]) + csv_trans.transform() + + +def test_csv_to_mindrecord_default_partition_number(remove_mindrecord_file): + """ + test transform csv to mindrecord + when partition number is default. + """ + csv_trans = CsvToMR(CSV_FILE, MINDRECORD_FILE) + csv_trans.transform() + assert os.path.exists(MINDRECORD_FILE) + assert os.path.exists(MINDRECORD_FILE + ".db") + read(MINDRECORD_FILE, ["Age", "EmployNumber", "Name", "Sales", "Over18"], 5) + +def test_csv_partition_number_0(remove_mindrecord_file): + """ + test transform csv to mindrecord + when partition number is 0. + """ + with pytest.raises(Exception, match="Invalid parameter value"): + csv_trans = CsvToMR(CSV_FILE, MINDRECORD_FILE, None, 0) + csv_trans.transform() + +def test_csv_to_mindrecord_partition_number_none(remove_mindrecord_file): + """ + test transform csv to mindrecord + when partition number is none. + """ + with pytest.raises(Exception, + match="The parameter partition_number must be int"): + csv_trans = CsvToMR(CSV_FILE, MINDRECORD_FILE, None, None) + csv_trans.transform() + +def test_csv_to_mindrecord_illegal_filename(remove_mindrecord_file): + """ + test transform csv to mindrecord + when file name contains illegal character. + """ + filename = "not_*ok" + with pytest.raises(Exception, match="File name should not contains"): + csv_trans = CsvToMR(CSV_FILE, filename) + csv_trans.transform()