forked from mindspore-Ecosystem/mindspore
!2262 fix fast_rcnn eval failed
Merge pull request !2262 from yanghaitao/yht_fasterrcn
This commit is contained in:
commit
b106c2204a
|
@ -464,27 +464,23 @@ def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, repeat_num=12, devi
|
||||||
num_parallel_workers=num_parallel_workers)
|
num_parallel_workers=num_parallel_workers)
|
||||||
ds = ds.map(input_columns=["image", "image_shape", "box", "label", "valid_num"],
|
ds = ds.map(input_columns=["image", "image_shape", "box", "label", "valid_num"],
|
||||||
operations=flipped_generation, num_parallel_workers=4)
|
operations=flipped_generation, num_parallel_workers=4)
|
||||||
|
|
||||||
# transpose_column from python to c
|
|
||||||
ds = ds.map(input_columns=["image"], operations=[hwc_to_chw, type_cast1])
|
|
||||||
ds = ds.map(input_columns=["image_shape"], operations=[type_cast1])
|
|
||||||
ds = ds.map(input_columns=["box"], operations=[type_cast1])
|
|
||||||
ds = ds.map(input_columns=["label"], operations=[type_cast2])
|
|
||||||
ds = ds.map(input_columns=["valid_num"], operations=[type_cast3])
|
|
||||||
ds = ds.batch(batch_size, drop_remainder=True)
|
|
||||||
ds = ds.repeat(repeat_num)
|
|
||||||
else:
|
else:
|
||||||
ds = ds.map(input_columns=["image", "annotation"],
|
ds = ds.map(input_columns=["image", "annotation"],
|
||||||
output_columns=["image", "image_shape", "box", "label", "valid_num"],
|
output_columns=["image", "image_shape", "box", "label", "valid_num"],
|
||||||
columns_order=["image", "image_shape", "box", "label", "valid_num"],
|
columns_order=["image", "image_shape", "box", "label", "valid_num"],
|
||||||
operations=compose_map_func,
|
operations=compose_map_func,
|
||||||
num_parallel_workers=num_parallel_workers)
|
num_parallel_workers=num_parallel_workers)
|
||||||
# transpose_column from python to c
|
|
||||||
ds = ds.map(input_columns=["image"], operations=[hwc_to_chw, type_cast1])
|
ds = ds.map(input_columns=["image"], operations=[normalize_op, type_cast0],
|
||||||
ds = ds.map(input_columns=["image_shape"], operations=[type_cast1])
|
num_parallel_workers=num_parallel_workers)
|
||||||
ds = ds.map(input_columns=["box"], operations=[type_cast1])
|
|
||||||
ds = ds.map(input_columns=["label"], operations=[type_cast2])
|
# transpose_column from python to c
|
||||||
ds = ds.map(input_columns=["valid_num"], operations=[type_cast3])
|
ds = ds.map(input_columns=["image"], operations=[hwc_to_chw, type_cast1])
|
||||||
ds = ds.batch(batch_size, drop_remainder=True)
|
ds = ds.map(input_columns=["image_shape"], operations=[type_cast1])
|
||||||
ds = ds.repeat(repeat_num)
|
ds = ds.map(input_columns=["box"], operations=[type_cast1])
|
||||||
|
ds = ds.map(input_columns=["label"], operations=[type_cast2])
|
||||||
|
ds = ds.map(input_columns=["valid_num"], operations=[type_cast3])
|
||||||
|
ds = ds.batch(batch_size, drop_remainder=True)
|
||||||
|
ds = ds.repeat(repeat_num)
|
||||||
|
|
||||||
return ds
|
return ds
|
||||||
|
|
Loading…
Reference in New Issue