diff --git a/model_zoo/faster_rcnn/src/dataset.py b/model_zoo/faster_rcnn/src/dataset.py index 4f2d029be4b..d64de093919 100644 --- a/model_zoo/faster_rcnn/src/dataset.py +++ b/model_zoo/faster_rcnn/src/dataset.py @@ -464,27 +464,23 @@ def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, repeat_num=12, devi num_parallel_workers=num_parallel_workers) ds = ds.map(input_columns=["image", "image_shape", "box", "label", "valid_num"], operations=flipped_generation, num_parallel_workers=4) - - # transpose_column from python to c - ds = ds.map(input_columns=["image"], operations=[hwc_to_chw, type_cast1]) - ds = ds.map(input_columns=["image_shape"], operations=[type_cast1]) - ds = ds.map(input_columns=["box"], operations=[type_cast1]) - ds = ds.map(input_columns=["label"], operations=[type_cast2]) - ds = ds.map(input_columns=["valid_num"], operations=[type_cast3]) - ds = ds.batch(batch_size, drop_remainder=True) - ds = ds.repeat(repeat_num) else: ds = ds.map(input_columns=["image", "annotation"], output_columns=["image", "image_shape", "box", "label", "valid_num"], columns_order=["image", "image_shape", "box", "label", "valid_num"], operations=compose_map_func, num_parallel_workers=num_parallel_workers) - # transpose_column from python to c - ds = ds.map(input_columns=["image"], operations=[hwc_to_chw, type_cast1]) - ds = ds.map(input_columns=["image_shape"], operations=[type_cast1]) - ds = ds.map(input_columns=["box"], operations=[type_cast1]) - ds = ds.map(input_columns=["label"], operations=[type_cast2]) - ds = ds.map(input_columns=["valid_num"], operations=[type_cast3]) - ds = ds.batch(batch_size, drop_remainder=True) - ds = ds.repeat(repeat_num) + + ds = ds.map(input_columns=["image"], operations=[normalize_op, type_cast0], + num_parallel_workers=num_parallel_workers) + + # transpose_column from python to c + ds = ds.map(input_columns=["image"], operations=[hwc_to_chw, type_cast1]) + ds = ds.map(input_columns=["image_shape"], operations=[type_cast1]) + ds = ds.map(input_columns=["box"], operations=[type_cast1]) + ds = ds.map(input_columns=["label"], operations=[type_cast2]) + ds = ds.map(input_columns=["valid_num"], operations=[type_cast3]) + ds = ds.batch(batch_size, drop_remainder=True) + ds = ds.repeat(repeat_num) + return ds