forked from mindspore-Ecosystem/mindspore
!3124 [MD] error occur when using numpy types
Merge pull request !3124 from liyong126/fix_numpy_generic
This commit is contained in:
commit
ad651f38bf
|
@ -29,6 +29,7 @@ class ShardWriter:
|
|||
|
||||
The class would write MindRecord File series.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._writer = ms.ShardWriter()
|
||||
self._header = None
|
||||
|
@ -161,7 +162,7 @@ class ShardWriter:
|
|||
if row_blob:
|
||||
blob_data.append(list(row_blob))
|
||||
# filter raw data according to schema
|
||||
row_raw = {field: item[field]
|
||||
row_raw = {field: self._convert_np_types(item[field])
|
||||
for field in self._header.schema.keys() - self._header.blob_fields if field in item}
|
||||
if row_raw:
|
||||
raw_data.append(row_raw)
|
||||
|
@ -172,6 +173,12 @@ class ShardWriter:
|
|||
raise MRMWriteDatasetError
|
||||
return ret
|
||||
|
||||
def _convert_np_types(self, val):
|
||||
"""convert numpy type to python primitive type"""
|
||||
if isinstance(val, (np.int32, np.int64, np.float32, np.float64)):
|
||||
return val.item()
|
||||
return val
|
||||
|
||||
def _merge_blob(self, blob_data):
|
||||
"""
|
||||
Merge multiple blob data whose type is bytes or ndarray
|
||||
|
|
|
@ -1853,3 +1853,42 @@ def test_write_with_float32_float64_float32_array_float64_array_and_MindDataset(
|
|||
|
||||
os.remove("{}".format(mindrecord_file_name))
|
||||
os.remove("{}.db".format(mindrecord_file_name))
|
||||
|
||||
def test_numpy_generic():
    """
    Round-trip NumPy scalar field values through FileWriter and MindDataset.

    Ten rows holding np.int32 / np.int64 / np.float32 / np.float64 values are
    written, read back, and compared against the values that were written.
    """
    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
             for x in range(FILES_NUM)]
    # Remove leftovers from an earlier run so the writer starts clean.
    for x in paths:
        if os.path.exists("{}".format(x)):
            os.remove("{}".format(x))
        if os.path.exists("{}.db".format(x)):
            os.remove("{}.db".format(x))
    writer = FileWriter(CV_FILE_NAME, FILES_NUM)
    cv_schema_json = {"label1": {"type": "int32"}, "label2": {"type": "int64"},
                      "label3": {"type": "float32"}, "label4": {"type": "float64"}}
    data = []
    for idx in range(10):
        row = {}
        row['label1'] = np.int32(idx)
        row['label2'] = np.int64(idx*10)
        row['label3'] = np.float32(idx+0.12345)
        row['label4'] = np.float64(idx+0.12345789)
        data.append(row)
    writer.add_schema(cv_schema_json, "img_schema")
    writer.write_raw_data(data)
    writer.commit()

    num_readers = 4
    data_set = ds.MindDataset(CV_FILE_NAME + "0", None, num_readers, shuffle=False)
    assert data_set.get_dataset_size() == 10
    idx = 0
    seen = set()
    for item in data_set.create_dict_iterator():
        # label1 was written as the row index, so it identifies the expected
        # row without relying on the order in which the shard files are read.
        key = int(np.asarray(item['label1']).item())
        assert 0 <= key < len(data)
        assert key not in seen
        seen.add(key)
        expected = data[key]
        assert item['label1'] == expected['label1']
        assert item['label2'] == expected['label2']
        assert item['label3'] == expected['label3']
        assert item['label4'] == expected['label4']
        idx += 1
    assert idx == 10
    for x in paths:
        os.remove("{}".format(x))
        os.remove("{}.db".format(x))
|
||||
|
|
Loading…
Reference in New Issue