forked from mindspore-Ecosystem/mindspore
fix fastrcnn accuracy
This commit is contained in:
parent
4eafb21ea8
commit
88afef5e90
|
@ -441,6 +441,7 @@ def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, repeat_num=12, devi
|
|||
hwc_to_chw = C.HWC2CHW()
|
||||
normalize_op = C.Normalize((123.675, 116.28, 103.53), (58.395, 57.12, 57.375))
|
||||
horizontally_op = C.RandomHorizontalFlip(1)
|
||||
type_cast0 = CC.TypeCast(mstype.float32)
|
||||
type_cast1 = CC.TypeCast(mstype.float16)
|
||||
type_cast2 = CC.TypeCast(mstype.int32)
|
||||
type_cast3 = CC.TypeCast(mstype.bool_)
|
||||
|
@ -453,13 +454,15 @@ def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, repeat_num=12, devi
|
|||
|
||||
flip = (np.random.rand() < config.flip_ratio)
|
||||
if flip:
|
||||
ds = ds.map(input_columns=["image"], operations=[normalize_op, horizontally_op, hwc_to_chw, type_cast1],
|
||||
num_parallel_workers=24)
|
||||
ds = ds.map(input_columns=["image"], operations=[normalize_op, type_cast0, horizontally_op],
|
||||
num_parallel_workers=12)
|
||||
ds = ds.map(input_columns=["image", "image_shape", "box", "label", "valid_num"],
|
||||
operations=flipped_generation, num_parallel_workers=num_parallel_workers)
|
||||
else:
|
||||
ds = ds.map(input_columns=["image"], operations=[normalize_op, hwc_to_chw, type_cast1],
|
||||
num_parallel_workers=24)
|
||||
ds = ds.map(input_columns=["image"], operations=[normalize_op, type_cast0],
|
||||
num_parallel_workers=12)
|
||||
ds = ds.map(input_columns=["image"], operations=[hwc_to_chw, type_cast1],
|
||||
num_parallel_workers=12)
|
||||
|
||||
else:
|
||||
ds = ds.map(input_columns=["image", "annotation"],
|
||||
|
|
Loading…
Reference in New Issue