forked from mindspore-Ecosystem/mindspore
!384 [MD] remove validation parameter in write_raw_data
Merge pull request !384 from liyong126/mindrecord_validation
This commit is contained in:
commit
d1b452cf3a
|
@ -26,8 +26,7 @@ from .shardheader import ShardHeader
|
|||
from .shardindexgenerator import ShardIndexGenerator
|
||||
from .shardutils import MIN_SHARD_COUNT, MAX_SHARD_COUNT, VALID_ATTRIBUTES, VALID_ARRAY_ATTRIBUTES, \
|
||||
check_filename, VALUE_TYPE_MAP
|
||||
from .common.exceptions import ParamValueError, ParamTypeError, MRMInvalidSchemaError, MRMDefineIndexError, \
|
||||
MRMValidateDataError
|
||||
from .common.exceptions import ParamValueError, ParamTypeError, MRMInvalidSchemaError, MRMDefineIndexError
|
||||
|
||||
__all__ = ['FileWriter']
|
||||
|
||||
|
@ -201,52 +200,13 @@ class FileWriter:
|
|||
raw_data.pop(i)
|
||||
logger.warning(v)
|
||||
|
||||
def _verify_based_on_blob_fields(self, raw_data):
|
||||
def write_raw_data(self, raw_data):
|
||||
"""
|
||||
Verify data according to blob fields which is sub set of schema's fields.
|
||||
|
||||
Raise exception if validation failed.
|
||||
1) allowed data type contains: "int32", "int64", "float32", "float64", "string", "bytes".
|
||||
|
||||
Args:
|
||||
raw_data (list[dict]): List of raw data.
|
||||
|
||||
Raises:
|
||||
MRMValidateDataError: If data does not match blob fields.
|
||||
"""
|
||||
schema_content = self._header.schema
|
||||
for field in schema_content:
|
||||
for i, v in enumerate(raw_data):
|
||||
if field not in v:
|
||||
raise MRMValidateDataError("for schema, {} th data is wrong: "\
|
||||
"there is not '{}' object in the raw data.".format(i, field))
|
||||
if field in self._header.blob_fields:
|
||||
field_type = type(v[field]).__name__
|
||||
if field_type not in VALUE_TYPE_MAP:
|
||||
raise MRMValidateDataError("for schema, {} th data is wrong: "\
|
||||
"data type for '{}' is not matched.".format(i, field))
|
||||
if schema_content[field]["type"] not in VALUE_TYPE_MAP[field_type]:
|
||||
raise MRMValidateDataError("for schema, {} th data is wrong: "\
|
||||
"data type for '{}' is not matched.".format(i, field))
|
||||
if field_type == 'ndarray':
|
||||
if 'shape' not in schema_content[field]:
|
||||
raise MRMValidateDataError("for schema, {} th data is wrong: " \
|
||||
"data type for '{}' is not matched.".format(i, field))
|
||||
try:
|
||||
# tuple or list
|
||||
np.reshape(v[field], schema_content[field]['shape'])
|
||||
except ValueError:
|
||||
raise MRMValidateDataError("for schema, {} th data is wrong: " \
|
||||
"data type for '{}' is not matched.".format(i, field))
|
||||
|
||||
def write_raw_data(self, raw_data, validate=True):
|
||||
"""
|
||||
Write raw data and generate sequential pair of MindRecord File.
|
||||
Write raw data and generate sequential pair of MindRecord File and \
|
||||
validate data based on predefined schema by default.
|
||||
|
||||
Args:
|
||||
raw_data (list[dict]): List of raw data.
|
||||
validate (bool, optional): Validate data according schema if it equals to True,
|
||||
or validate data according to blob fields (default=True).
|
||||
|
||||
Raises:
|
||||
ParamTypeError: If index field is invalid.
|
||||
|
@ -264,11 +224,8 @@ class FileWriter:
|
|||
for each_raw in raw_data:
|
||||
if not isinstance(each_raw, dict):
|
||||
raise ParamTypeError('raw_data item', 'dict')
|
||||
if validate is True:
|
||||
self._verify_based_on_schema(raw_data)
|
||||
elif validate is False:
|
||||
self._verify_based_on_blob_fields(raw_data)
|
||||
return self._writer.write_raw_data(raw_data, validate)
|
||||
self._verify_based_on_schema(raw_data)
|
||||
return self._writer.write_raw_data(raw_data, True)
|
||||
|
||||
def set_header_size(self, header_size):
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue