!2262 fix fast_rcnn eval failed

Merge pull request !2262 from yanghaitao/yht_fasterrcn
This commit is contained in:
mindspore-ci-bot 2020-06-18 11:44:42 +08:00 committed by Gitee
commit b106c2204a
1 changed files with 13 additions and 17 deletions

View File

@ -464,27 +464,23 @@ def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, repeat_num=12, devi
num_parallel_workers=num_parallel_workers)
ds = ds.map(input_columns=["image", "image_shape", "box", "label", "valid_num"],
operations=flipped_generation, num_parallel_workers=4)
# transpose_column from python to c
ds = ds.map(input_columns=["image"], operations=[hwc_to_chw, type_cast1])
ds = ds.map(input_columns=["image_shape"], operations=[type_cast1])
ds = ds.map(input_columns=["box"], operations=[type_cast1])
ds = ds.map(input_columns=["label"], operations=[type_cast2])
ds = ds.map(input_columns=["valid_num"], operations=[type_cast3])
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.repeat(repeat_num)
else:
ds = ds.map(input_columns=["image", "annotation"],
output_columns=["image", "image_shape", "box", "label", "valid_num"],
columns_order=["image", "image_shape", "box", "label", "valid_num"],
operations=compose_map_func,
num_parallel_workers=num_parallel_workers)
# transpose_column from python to c
ds = ds.map(input_columns=["image"], operations=[hwc_to_chw, type_cast1])
ds = ds.map(input_columns=["image_shape"], operations=[type_cast1])
ds = ds.map(input_columns=["box"], operations=[type_cast1])
ds = ds.map(input_columns=["label"], operations=[type_cast2])
ds = ds.map(input_columns=["valid_num"], operations=[type_cast3])
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.repeat(repeat_num)
ds = ds.map(input_columns=["image"], operations=[normalize_op, type_cast0],
num_parallel_workers=num_parallel_workers)
# transpose_column from python to c
ds = ds.map(input_columns=["image"], operations=[hwc_to_chw, type_cast1])
ds = ds.map(input_columns=["image_shape"], operations=[type_cast1])
ds = ds.map(input_columns=["box"], operations=[type_cast1])
ds = ds.map(input_columns=["label"], operations=[type_cast2])
ds = ds.map(input_columns=["valid_num"], operations=[type_cast3])
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.repeat(repeat_num)
return ds