fix api doc

This commit is contained in:
liyong 2021-07-07 14:52:57 +08:00
parent e0adfec9a9
commit d79e74e68e
4 changed files with 22 additions and 9 deletions

View File

@ -36,7 +36,8 @@ class FileWriter:
Class to write user defined raw data into MindRecord files.
Note:
The mindrecord file may fail to be read if the file name is modified.
After the MindRecord file is generated, if the file name is changed,
the file may fail to be read.
Args:
file_name (str): File name of MindRecord file.
@ -266,7 +267,8 @@ class FileWriter:
the MindRecord file can store.
Args:
header_size (int): Size of header, between 16KB and 128MB.
header_size (int): Size of header, between 16*1024(16KB) and
128*1024*1024(128MB).
Returns:
MSRStatus, SUCCESS or FAILED.
@ -284,7 +286,8 @@ class FileWriter:
The larger a page, the more data the page can store.
Args:
page_size (int): Size of page, between 32KB and 256MB.
page_size (int): Size of page, between 32*1024(32KB) and
256*1024*1024(256MB).
Returns:
MSRStatus, SUCCESS or FAILED.

View File

@ -110,6 +110,12 @@ class Cifar100ToMR:
def transform(self, fields=None):
"""
Encapsulate the run function to exit normally
Args:
fields (list[str]): A list of index field, e.g.["fine_label", "coarse_label"].
Returns:
MSRStatus, whether cifar100 is successfully transformed to MindRecord.
"""
t = ExceptionThread(target=self.run, kwargs={'fields': fields})

View File

@ -106,6 +106,12 @@ class Cifar10ToMR:
def transform(self, fields=None):
"""
Encapsulate the run function to exit normally
Args:
fields (list[str], optional): A list of index fields. Default: None.
Returns:
MSRStatus, whether cifar10 is successfully transformed to MindRecord.
"""
t = ExceptionThread(target=self.run, kwargs={'fields': fields})

View File

@ -54,7 +54,7 @@ class CsvToMR:
else:
raise ValueError("The parameter source must be str.")
self.check_columns(columns_list, "columns_list")
self._check_columns(columns_list, "columns_list")
self.columns_list = columns_list
if isinstance(destination, str):
@ -72,8 +72,7 @@ class CsvToMR:
self.writer = FileWriter(self.destination, self.partition_number)
@staticmethod
def check_columns(columns, columns_name):
def _check_columns(self, columns, columns_name):
"""
Validate the columns of csv
"""
@ -111,8 +110,7 @@ class CsvToMR:
raise RuntimeError("Failed to generate schema from csv file.")
return schema
@staticmethod
def get_row_of_csv(df, columns_list):
def _get_row_of_csv(self, df, columns_list):
"""Get row data from csv file."""
for _, r in df.iterrows():
row = {}
@ -152,7 +150,7 @@ class CsvToMR:
# add the index
self.writer.add_index(list(self.columns_list))
csv_iter = self.get_row_of_csv(df, self.columns_list)
csv_iter = self._get_row_of_csv(df, self.columns_list)
batch_size = 256
transform_count = 0
while True: