forked from mindspore-Ecosystem/mindspore
!9306 fix ndarray field without type in mindrecord
From: @jonyguo Reviewed-by: @pandoublefeng,@liucunwei Signed-off-by: @liucunwei
This commit is contained in:
commit
77ef72be30
|
@ -196,10 +196,10 @@ class ShardWriter:
|
||||||
def int_to_bytes(x: int) -> bytes:
|
def int_to_bytes(x: int) -> bytes:
|
||||||
return x.to_bytes(8, 'big')
|
return x.to_bytes(8, 'big')
|
||||||
merged = bytes()
|
merged = bytes()
|
||||||
for _, v in blob_data.items():
|
for field, v in blob_data.items():
|
||||||
# convert ndarray to bytes
|
# convert ndarray to bytes
|
||||||
if isinstance(v, np.ndarray):
|
if isinstance(v, np.ndarray):
|
||||||
v = v.tobytes()
|
v = v.astype(self._header.schema[field]["type"]).tobytes()
|
||||||
merged += int_to_bytes(len(v))
|
merged += int_to_bytes(len(v))
|
||||||
merged += v
|
merged += v
|
||||||
return merged
|
return merged
|
||||||
|
|
|
@ -964,3 +964,39 @@ def test_write_read_process_with_multi_bytes_and_array():
|
||||||
|
|
||||||
os.remove("{}".format(mindrecord_file_name))
|
os.remove("{}".format(mindrecord_file_name))
|
||||||
os.remove("{}.db".format(mindrecord_file_name))
|
os.remove("{}.db".format(mindrecord_file_name))
|
||||||
|
|
||||||
|
def test_write_read_process_without_ndarray_type():
    """Round-trip a record whose ndarray field has no explicit dtype.

    The "mask" array is built without a dtype, so NumPy picks its platform
    default integer type (int64 on most 64-bit platforms), while the schema
    declares the field as int32. The test writes one record and checks that
    every field read back compares equal to the value that was written.

    The reader and the generated files are released in ``finally`` blocks so
    that a failed assertion does not leave "test.mindrecord" or its ".db"
    companion behind for later tests that reuse the same file name.
    """
    mindrecord_file_name = "test.mindrecord"
    # field: mask's inferred dtype (typically int64) differs from the schema type int32
    data = [{"file_name": "001.jpg", "label": 43, "score": 0.8, "mask": np.array([3, 6, 9]),
             "segments": np.array([[5.0, 1.6], [65.2, 8.3]], dtype=np.float32),
             "data": bytes("image bytes abc", encoding='UTF-8')}
            ]
    schema = {"file_name": {"type": "string"},
              "label": {"type": "int32"},
              "score": {"type": "float64"},
              "mask": {"type": "int32", "shape": [-1]},
              "segments": {"type": "float32", "shape": [2, 2]},
              "data": {"type": "bytes"}}

    try:
        writer = FileWriter(mindrecord_file_name)
        writer.add_schema(schema, "data is so cool")
        writer.write_raw_data(data)
        writer.commit()

        reader = FileReader(mindrecord_file_name)
        try:
            count = 0
            for index, x in enumerate(reader.get_next()):
                assert len(x) == 6
                # Records are expected back in write order, so the enumerate
                # index selects the matching input record.
                expected = data[index]
                for field in x:
                    if isinstance(x[field], np.ndarray):
                        print("output: {}, input: {}".format(x[field], expected[field]))
                        assert (x[field] == expected[field]).all()
                    else:
                        assert x[field] == expected[field]
                count = count + 1
                logger.info("#item{}: {}".format(index, x))
            assert count == 1
        finally:
            # Close the reader even when an assertion above fails.
            reader.close()
    finally:
        # Remove whatever was created; the existence check keeps a failure
        # during writing from being masked by a FileNotFoundError here.
        for path in ("{}".format(mindrecord_file_name),
                     "{}.db".format(mindrecord_file_name)):
            if os.path.exists(path):
                os.remove(path)
|
||||||
|
|
Loading…
Reference in New Issue