forked from OSSInnovation/mindspore
!2262 fix fast_rcnn eval failed
Merge pull request !2262 from yanghaitao/yht_fasterrcn
This commit is contained in:
commit
b106c2204a
|
@ -464,27 +464,23 @@ def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, repeat_num=12, devi
|
|||
num_parallel_workers=num_parallel_workers)
|
||||
ds = ds.map(input_columns=["image", "image_shape", "box", "label", "valid_num"],
|
||||
operations=flipped_generation, num_parallel_workers=4)
|
||||
|
||||
# transpose_column from python to c
|
||||
ds = ds.map(input_columns=["image"], operations=[hwc_to_chw, type_cast1])
|
||||
ds = ds.map(input_columns=["image_shape"], operations=[type_cast1])
|
||||
ds = ds.map(input_columns=["box"], operations=[type_cast1])
|
||||
ds = ds.map(input_columns=["label"], operations=[type_cast2])
|
||||
ds = ds.map(input_columns=["valid_num"], operations=[type_cast3])
|
||||
ds = ds.batch(batch_size, drop_remainder=True)
|
||||
ds = ds.repeat(repeat_num)
|
||||
else:
|
||||
ds = ds.map(input_columns=["image", "annotation"],
|
||||
output_columns=["image", "image_shape", "box", "label", "valid_num"],
|
||||
columns_order=["image", "image_shape", "box", "label", "valid_num"],
|
||||
operations=compose_map_func,
|
||||
num_parallel_workers=num_parallel_workers)
|
||||
# transpose_column from python to c
|
||||
ds = ds.map(input_columns=["image"], operations=[hwc_to_chw, type_cast1])
|
||||
ds = ds.map(input_columns=["image_shape"], operations=[type_cast1])
|
||||
ds = ds.map(input_columns=["box"], operations=[type_cast1])
|
||||
ds = ds.map(input_columns=["label"], operations=[type_cast2])
|
||||
ds = ds.map(input_columns=["valid_num"], operations=[type_cast3])
|
||||
ds = ds.batch(batch_size, drop_remainder=True)
|
||||
ds = ds.repeat(repeat_num)
|
||||
|
||||
ds = ds.map(input_columns=["image"], operations=[normalize_op, type_cast0],
|
||||
num_parallel_workers=num_parallel_workers)
|
||||
|
||||
# transpose_column from python to c
|
||||
ds = ds.map(input_columns=["image"], operations=[hwc_to_chw, type_cast1])
|
||||
ds = ds.map(input_columns=["image_shape"], operations=[type_cast1])
|
||||
ds = ds.map(input_columns=["box"], operations=[type_cast1])
|
||||
ds = ds.map(input_columns=["label"], operations=[type_cast2])
|
||||
ds = ds.map(input_columns=["valid_num"], operations=[type_cast3])
|
||||
ds = ds.batch(batch_size, drop_remainder=True)
|
||||
ds = ds.repeat(repeat_num)
|
||||
|
||||
return ds
|
||||
|
|
Loading…
Reference in New Issue