forked from mindspore-Ecosystem/mindspore
convert csv to mindrecord
This commit is contained in:
parent
518e955260
commit
7369950a93
|
@ -29,10 +29,11 @@ from .common.exceptions import *
|
||||||
from .shardutils import SUCCESS, FAILED
|
from .shardutils import SUCCESS, FAILED
|
||||||
from .tools.cifar10_to_mr import Cifar10ToMR
|
from .tools.cifar10_to_mr import Cifar10ToMR
|
||||||
from .tools.cifar100_to_mr import Cifar100ToMR
|
from .tools.cifar100_to_mr import Cifar100ToMR
|
||||||
|
from .tools.csv_to_mr import CsvToMR
|
||||||
from .tools.imagenet_to_mr import ImageNetToMR
|
from .tools.imagenet_to_mr import ImageNetToMR
|
||||||
from .tools.mnist_to_mr import MnistToMR
|
from .tools.mnist_to_mr import MnistToMR
|
||||||
from .tools.tfrecord_to_mr import TFRecordToMR
|
from .tools.tfrecord_to_mr import TFRecordToMR
|
||||||
|
|
||||||
__all__ = ['FileWriter', 'FileReader', 'MindPage',
|
__all__ = ['FileWriter', 'FileReader', 'MindPage',
|
||||||
'Cifar10ToMR', 'Cifar100ToMR', 'ImageNetToMR', 'MnistToMR', 'TFRecordToMR',
|
'Cifar10ToMR', 'Cifar100ToMR', 'CsvToMR', 'ImageNetToMR', 'MnistToMR', 'TFRecordToMR',
|
||||||
'SUCCESS', 'FAILED']
|
'SUCCESS', 'FAILED']
|
||||||
|
|
|
@ -0,0 +1,168 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""
|
||||||
|
Csv format convert tool for MindRecord.
|
||||||
|
"""
|
||||||
|
from importlib import import_module
|
||||||
|
import os
|
||||||
|
|
||||||
|
from mindspore import log as logger
|
||||||
|
from ..filewriter import FileWriter
|
||||||
|
from ..shardutils import check_filename
|
||||||
|
|
||||||
|
try:
|
||||||
|
pd = import_module("pandas")
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
pd = None
|
||||||
|
|
||||||
|
__all__ = ['CsvToMR']
|
||||||
|
|
||||||
|
class CsvToMR:
|
||||||
|
"""
|
||||||
|
Class is for transformation from csv to MindRecord.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source (str): the file path of csv.
|
||||||
|
destination (str): the MindRecord file path to transform into.
|
||||||
|
columns_list(list[str], optional): List of columns to be read(default=None).
|
||||||
|
partition_number (int, optional): partition size (default=1).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If source, destination, partition_number is invalid.
|
||||||
|
RuntimeError: If columns_list is invalid.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, source, destination, columns_list=None, partition_number=1):
|
||||||
|
if not pd:
|
||||||
|
raise Exception("Module pandas is not found, please use pip install it.")
|
||||||
|
if isinstance(source, str):
|
||||||
|
check_filename(source)
|
||||||
|
self.source = source
|
||||||
|
else:
|
||||||
|
raise ValueError("The parameter source must be str.")
|
||||||
|
|
||||||
|
self._check_columns(columns_list, "columns_list")
|
||||||
|
self.columns_list = columns_list
|
||||||
|
|
||||||
|
if isinstance(destination, str):
|
||||||
|
check_filename(destination)
|
||||||
|
self.destination = destination
|
||||||
|
else:
|
||||||
|
raise ValueError("The parameter destination must be str.")
|
||||||
|
|
||||||
|
if partition_number is not None:
|
||||||
|
if not isinstance(partition_number, int):
|
||||||
|
raise ValueError("The parameter partition_number must be int")
|
||||||
|
self.partition_number = partition_number
|
||||||
|
else:
|
||||||
|
raise ValueError("The parameter partition_number must be int")
|
||||||
|
|
||||||
|
self.writer = FileWriter(self.destination, self.partition_number)
|
||||||
|
|
||||||
|
def _check_columns(self, columns, columns_name):
|
||||||
|
if columns:
|
||||||
|
if isinstance(columns, list):
|
||||||
|
for col in columns:
|
||||||
|
if not isinstance(col, str):
|
||||||
|
raise ValueError("The parameter {} must be list of str.".format(columns_name))
|
||||||
|
else:
|
||||||
|
raise ValueError("The parameter {} must be list of str.".format(columns_name))
|
||||||
|
|
||||||
|
def _get_schema(self, df):
|
||||||
|
"""
|
||||||
|
Construct schema from df columns
|
||||||
|
"""
|
||||||
|
if self.columns_list:
|
||||||
|
for col in self.columns_list:
|
||||||
|
if col not in df.columns:
|
||||||
|
raise RuntimeError("The parameter columns_list is illegal, column {} does not exist.".format(col))
|
||||||
|
else:
|
||||||
|
self.columns_list = df.columns
|
||||||
|
|
||||||
|
schema = {}
|
||||||
|
for col in self.columns_list:
|
||||||
|
if str(df[col].dtype) == 'int64':
|
||||||
|
schema[col] = {"type": "int64"}
|
||||||
|
elif str(df[col].dtype) == 'float64':
|
||||||
|
schema[col] = {"type": "float64"}
|
||||||
|
elif str(df[col].dtype) == 'bool':
|
||||||
|
schema[col] = {"type": "int32"}
|
||||||
|
else:
|
||||||
|
schema[col] = {"type": "string"}
|
||||||
|
if not schema:
|
||||||
|
raise RuntimeError("Failed to generate schema from csv file.")
|
||||||
|
return schema
|
||||||
|
|
||||||
|
def _get_row_of_csv(self, df):
|
||||||
|
"""Get row data from csv file."""
|
||||||
|
for _, r in df.iterrows():
|
||||||
|
row = {}
|
||||||
|
for col in self.columns_list:
|
||||||
|
if str(df[col].dtype) == 'bool':
|
||||||
|
row[col] = int(r[col])
|
||||||
|
else:
|
||||||
|
row[col] = r[col]
|
||||||
|
yield row
|
||||||
|
|
||||||
|
def transform(self):
|
||||||
|
"""
|
||||||
|
Executes transformation from csv to MindRecord.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SUCCESS/FAILED, whether successfully written into MindRecord.
|
||||||
|
"""
|
||||||
|
if not os.path.exists(self.source):
|
||||||
|
raise IOError("Csv file {} do not exist.".format(self.source))
|
||||||
|
|
||||||
|
pd.set_option('display.max_columns', None)
|
||||||
|
df = pd.read_csv(self.source)
|
||||||
|
|
||||||
|
csv_schema = self._get_schema(df)
|
||||||
|
|
||||||
|
logger.info("transformed MindRecord schema is: {}".format(csv_schema))
|
||||||
|
|
||||||
|
# set the header size
|
||||||
|
self.writer.set_header_size(1 << 24)
|
||||||
|
|
||||||
|
# set the page size
|
||||||
|
self.writer.set_page_size(1 << 26)
|
||||||
|
|
||||||
|
# create the schema
|
||||||
|
self.writer.add_schema(csv_schema, "csv_schema")
|
||||||
|
|
||||||
|
# add the index
|
||||||
|
self.writer.add_index(list(self.columns_list))
|
||||||
|
|
||||||
|
csv_iter = self._get_row_of_csv(df)
|
||||||
|
batch_size = 256
|
||||||
|
transform_count = 0
|
||||||
|
while True:
|
||||||
|
data_list = []
|
||||||
|
try:
|
||||||
|
for _ in range(batch_size):
|
||||||
|
data_list.append(csv_iter.__next__())
|
||||||
|
transform_count += 1
|
||||||
|
self.writer.write_raw_data(data_list)
|
||||||
|
logger.info("transformed {} record...".format(transform_count))
|
||||||
|
except StopIteration:
|
||||||
|
if data_list:
|
||||||
|
self.writer.write_raw_data(data_list)
|
||||||
|
logger.info(
|
||||||
|
"transformed {} record...".format(transform_count))
|
||||||
|
break
|
||||||
|
|
||||||
|
ret = self.writer.commit()
|
||||||
|
|
||||||
|
return ret
|
|
@ -115,10 +115,8 @@ class TFRecordToMR:
|
||||||
"sequence": {"zzzz": tf.io.FixedLenSequenceFeature([], tf.float32)}}
|
"sequence": {"zzzz": tf.io.FixedLenSequenceFeature([], tf.float32)}}
|
||||||
bytes_fields (list): the bytes fields which are in feature_dict.
|
bytes_fields (list): the bytes fields which are in feature_dict.
|
||||||
|
|
||||||
Rasies:
|
Raises:
|
||||||
ValueError: the following condition will cause ValueError, 1) parameter TFRecord is not string, 2) parameter
|
ValueError: If parameter is invalid.
|
||||||
MindRecord is not string, 3) feature_dict is not FixedLenFeature, 4) parameter bytes_field is not list(str)
|
|
||||||
or not in feature_dict.
|
|
||||||
Exception: when tensorflow module not found or version is not correct.
|
Exception: when tensorflow module not found or version is not correct.
|
||||||
"""
|
"""
|
||||||
def __init__(self, source, destination, feature_dict, bytes_fields=None):
|
def __init__(self, source, destination, feature_dict, bytes_fields=None):
|
||||||
|
|
|
@ -0,0 +1,7 @@
|
||||||
|
Age,EmployNumber,Name,Sales,Over18
|
||||||
|
21, 10023,john, 123.45,True
|
||||||
|
41, 10223,tom, 12111,True
|
||||||
|
51, 10231,bob, 8779.0,True
|
||||||
|
86, 10053,alice, 7777,True
|
||||||
|
26, 1053,carol, 12345.8,False
|
||||||
|
|
|
|
@ -0,0 +1,143 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""test csv to mindrecord tool"""
|
||||||
|
import os
|
||||||
|
from importlib import import_module
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from mindspore import log as logger
|
||||||
|
from mindspore.mindrecord import FileReader
|
||||||
|
from mindspore.mindrecord import CsvToMR
|
||||||
|
|
||||||
|
try:
|
||||||
|
pd = import_module('pandas')
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
pd = None
|
||||||
|
|
||||||
|
CSV_FILE = "../data/mindrecord/testCsv/data.csv"
|
||||||
|
MINDRECORD_FILE = "../data/mindrecord/testCsv/csv.mindrecord"
|
||||||
|
PARTITION_NUMBER = 4
|
||||||
|
|
||||||
|
@pytest.fixture(name="remove_mindrecord_file")
|
||||||
|
def fixture_remove():
|
||||||
|
"""add/remove file"""
|
||||||
|
def remove_one_file(x):
|
||||||
|
if os.path.exists(x):
|
||||||
|
os.remove(x)
|
||||||
|
def remove_file():
|
||||||
|
x = MINDRECORD_FILE
|
||||||
|
remove_one_file(x)
|
||||||
|
x = MINDRECORD_FILE + ".db"
|
||||||
|
remove_one_file(x)
|
||||||
|
for i in range(PARTITION_NUMBER):
|
||||||
|
x = MINDRECORD_FILE + str(i)
|
||||||
|
remove_one_file(x)
|
||||||
|
x = MINDRECORD_FILE + str(i) + ".db"
|
||||||
|
remove_one_file(x)
|
||||||
|
|
||||||
|
remove_file()
|
||||||
|
yield "yield_fixture_data"
|
||||||
|
remove_file()
|
||||||
|
|
||||||
|
def read(filename, columns, row_num):
|
||||||
|
"""test file reade"""
|
||||||
|
if not pd:
|
||||||
|
raise Exception("Module pandas is not found, please use pip install it.")
|
||||||
|
df = pd.read_csv(CSV_FILE)
|
||||||
|
count = 0
|
||||||
|
reader = FileReader(filename)
|
||||||
|
for _, x in enumerate(reader.get_next()):
|
||||||
|
for col in columns:
|
||||||
|
assert x[col] == df[col].iloc[count]
|
||||||
|
assert len(x) == len(columns)
|
||||||
|
count = count + 1
|
||||||
|
if count == 1:
|
||||||
|
logger.info("data: {}".format(x))
|
||||||
|
assert count == row_num
|
||||||
|
reader.close()
|
||||||
|
|
||||||
|
def test_csv_to_mindrecord(remove_mindrecord_file):
|
||||||
|
"""test transform csv to mindrecord."""
|
||||||
|
csv_trans = CsvToMR(CSV_FILE, MINDRECORD_FILE, partition_number=PARTITION_NUMBER)
|
||||||
|
csv_trans.transform()
|
||||||
|
for i in range(PARTITION_NUMBER):
|
||||||
|
assert os.path.exists(MINDRECORD_FILE + str(i))
|
||||||
|
assert os.path.exists(MINDRECORD_FILE + str(i) + ".db")
|
||||||
|
read(MINDRECORD_FILE + "0", ["Age", "EmployNumber", "Name", "Sales", "Over18"], 5)
|
||||||
|
|
||||||
|
def test_csv_to_mindrecord_with_columns(remove_mindrecord_file):
|
||||||
|
"""test transform csv to mindrecord."""
|
||||||
|
csv_trans = CsvToMR(CSV_FILE, MINDRECORD_FILE, columns_list=['Age', 'Sales'], partition_number=PARTITION_NUMBER)
|
||||||
|
csv_trans.transform()
|
||||||
|
for i in range(PARTITION_NUMBER):
|
||||||
|
assert os.path.exists(MINDRECORD_FILE + str(i))
|
||||||
|
assert os.path.exists(MINDRECORD_FILE + str(i) + ".db")
|
||||||
|
read(MINDRECORD_FILE + "0", ["Age", "Sales"], 5)
|
||||||
|
|
||||||
|
def test_csv_to_mindrecord_with_no_exist_columns(remove_mindrecord_file):
|
||||||
|
"""test transform csv to mindrecord."""
|
||||||
|
with pytest.raises(Exception, match="The parameter columns_list is illegal, column ssales does not exist."):
|
||||||
|
csv_trans = CsvToMR(CSV_FILE, MINDRECORD_FILE, columns_list=['Age', 'ssales'],
|
||||||
|
partition_number=PARTITION_NUMBER)
|
||||||
|
csv_trans.transform()
|
||||||
|
|
||||||
|
def test_csv_partition_number_with_illegal_columns(remove_mindrecord_file):
|
||||||
|
"""
|
||||||
|
test transform csv to mindrecord
|
||||||
|
"""
|
||||||
|
with pytest.raises(Exception, match="The parameter columns_list must be list of str."):
|
||||||
|
csv_trans = CsvToMR(CSV_FILE, MINDRECORD_FILE, ["Sales", 2])
|
||||||
|
csv_trans.transform()
|
||||||
|
|
||||||
|
|
||||||
|
def test_csv_to_mindrecord_default_partition_number(remove_mindrecord_file):
|
||||||
|
"""
|
||||||
|
test transform csv to mindrecord
|
||||||
|
when partition number is default.
|
||||||
|
"""
|
||||||
|
csv_trans = CsvToMR(CSV_FILE, MINDRECORD_FILE)
|
||||||
|
csv_trans.transform()
|
||||||
|
assert os.path.exists(MINDRECORD_FILE)
|
||||||
|
assert os.path.exists(MINDRECORD_FILE + ".db")
|
||||||
|
read(MINDRECORD_FILE, ["Age", "EmployNumber", "Name", "Sales", "Over18"], 5)
|
||||||
|
|
||||||
|
def test_csv_partition_number_0(remove_mindrecord_file):
|
||||||
|
"""
|
||||||
|
test transform csv to mindrecord
|
||||||
|
when partition number is 0.
|
||||||
|
"""
|
||||||
|
with pytest.raises(Exception, match="Invalid parameter value"):
|
||||||
|
csv_trans = CsvToMR(CSV_FILE, MINDRECORD_FILE, None, 0)
|
||||||
|
csv_trans.transform()
|
||||||
|
|
||||||
|
def test_csv_to_mindrecord_partition_number_none(remove_mindrecord_file):
|
||||||
|
"""
|
||||||
|
test transform csv to mindrecord
|
||||||
|
when partition number is none.
|
||||||
|
"""
|
||||||
|
with pytest.raises(Exception,
|
||||||
|
match="The parameter partition_number must be int"):
|
||||||
|
csv_trans = CsvToMR(CSV_FILE, MINDRECORD_FILE, None, None)
|
||||||
|
csv_trans.transform()
|
||||||
|
|
||||||
|
def test_csv_to_mindrecord_illegal_filename(remove_mindrecord_file):
|
||||||
|
"""
|
||||||
|
test transform csv to mindrecord
|
||||||
|
when file name contains illegal character.
|
||||||
|
"""
|
||||||
|
filename = "not_*ok"
|
||||||
|
with pytest.raises(Exception, match="File name should not contains"):
|
||||||
|
csv_trans = CsvToMR(CSV_FILE, filename)
|
||||||
|
csv_trans.transform()
|
Loading…
Reference in New Issue