From 8d4f831fdcb3e582389d23b4e80a02f9e6e62138 Mon Sep 17 00:00:00 2001 From: yanghaitao Date: Thu, 18 Jun 2020 10:13:17 +0800 Subject: [PATCH] fix fastrcnn eval failed --- model_zoo/faster_rcnn/src/dataset.py | 30 ++++++++++++---------------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/model_zoo/faster_rcnn/src/dataset.py b/model_zoo/faster_rcnn/src/dataset.py index 4f2d029be4b..d64de093919 100644 --- a/model_zoo/faster_rcnn/src/dataset.py +++ b/model_zoo/faster_rcnn/src/dataset.py @@ -464,27 +464,23 @@ def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, repeat_num=12, devi num_parallel_workers=num_parallel_workers) ds = ds.map(input_columns=["image", "image_shape", "box", "label", "valid_num"], operations=flipped_generation, num_parallel_workers=4) - - # transpose_column from python to c - ds = ds.map(input_columns=["image"], operations=[hwc_to_chw, type_cast1]) - ds = ds.map(input_columns=["image_shape"], operations=[type_cast1]) - ds = ds.map(input_columns=["box"], operations=[type_cast1]) - ds = ds.map(input_columns=["label"], operations=[type_cast2]) - ds = ds.map(input_columns=["valid_num"], operations=[type_cast3]) - ds = ds.batch(batch_size, drop_remainder=True) - ds = ds.repeat(repeat_num) else: ds = ds.map(input_columns=["image", "annotation"], output_columns=["image", "image_shape", "box", "label", "valid_num"], columns_order=["image", "image_shape", "box", "label", "valid_num"], operations=compose_map_func, num_parallel_workers=num_parallel_workers) - # transpose_column from python to c - ds = ds.map(input_columns=["image"], operations=[hwc_to_chw, type_cast1]) - ds = ds.map(input_columns=["image_shape"], operations=[type_cast1]) - ds = ds.map(input_columns=["box"], operations=[type_cast1]) - ds = ds.map(input_columns=["label"], operations=[type_cast2]) - ds = ds.map(input_columns=["valid_num"], operations=[type_cast3]) - ds = ds.batch(batch_size, drop_remainder=True) - ds = ds.repeat(repeat_num) + + ds = ds.map(input_columns=["image"], operations=[normalize_op, type_cast0], + num_parallel_workers=num_parallel_workers) + + # transpose_column from python to c + ds = ds.map(input_columns=["image"], operations=[hwc_to_chw, type_cast1]) + ds = ds.map(input_columns=["image_shape"], operations=[type_cast1]) + ds = ds.map(input_columns=["box"], operations=[type_cast1]) + ds = ds.map(input_columns=["label"], operations=[type_cast2]) + ds = ds.map(input_columns=["valid_num"], operations=[type_cast3]) + ds = ds.batch(batch_size, drop_remainder=True) + ds = ds.repeat(repeat_num) + return ds