From 88afef5e90814c191e89cd2e32c5b3cdb52b3759 Mon Sep 17 00:00:00 2001 From: yanghaitao Date: Wed, 1 Jul 2020 15:43:01 +0800 Subject: [PATCH] fix fastrcnn accuracy --- model_zoo/faster_rcnn/src/dataset.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/model_zoo/faster_rcnn/src/dataset.py b/model_zoo/faster_rcnn/src/dataset.py index 133824dd247..346ed5a6cd1 100644 --- a/model_zoo/faster_rcnn/src/dataset.py +++ b/model_zoo/faster_rcnn/src/dataset.py @@ -441,6 +441,7 @@ def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, repeat_num=12, devi hwc_to_chw = C.HWC2CHW() normalize_op = C.Normalize((123.675, 116.28, 103.53), (58.395, 57.12, 57.375)) horizontally_op = C.RandomHorizontalFlip(1) + type_cast0 = CC.TypeCast(mstype.float32) type_cast1 = CC.TypeCast(mstype.float16) type_cast2 = CC.TypeCast(mstype.int32) type_cast3 = CC.TypeCast(mstype.bool_) @@ -453,13 +454,15 @@ def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, repeat_num=12, devi flip = (np.random.rand() < config.flip_ratio) if flip: - ds = ds.map(input_columns=["image"], operations=[normalize_op, horizontally_op, hwc_to_chw, type_cast1], - num_parallel_workers=24) + ds = ds.map(input_columns=["image"], operations=[normalize_op, type_cast0, horizontally_op], + num_parallel_workers=12) ds = ds.map(input_columns=["image", "image_shape", "box", "label", "valid_num"], operations=flipped_generation, num_parallel_workers=num_parallel_workers) else: - ds = ds.map(input_columns=["image"], operations=[normalize_op, hwc_to_chw, type_cast1], - num_parallel_workers=24) + ds = ds.map(input_columns=["image"], operations=[normalize_op, type_cast0], + num_parallel_workers=12) + ds = ds.map(input_columns=["image"], operations=[hwc_to_chw, type_cast1], + num_parallel_workers=12) else: ds = ds.map(input_columns=["image", "annotation"],