From 9ef6e72c8cd307b26b097cfc6f51d0be1686aea1 Mon Sep 17 00:00:00 2001 From: zhaoting <zhaoting23@huawei.com> Date: Tue, 25 Aug 2020 17:09:25 +0800 Subject: [PATCH] change group conv dtype in gpu resnext50 --- .../official/cv/resnext50/src/utils/auto_mixed_precision.py | 3 --- model_zoo/official/cv/resnext50/train.py | 5 ++--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/model_zoo/official/cv/resnext50/src/utils/auto_mixed_precision.py b/model_zoo/official/cv/resnext50/src/utils/auto_mixed_precision.py index f8e27f5b522..18895f2b336 100644 --- a/model_zoo/official/cv/resnext50/src/utils/auto_mixed_precision.py +++ b/model_zoo/official/cv/resnext50/src/utils/auto_mixed_precision.py @@ -44,9 +44,6 @@ def auto_mixed_precision(network): elif name == 'fc': network.insert_child_to_cell(name, OutputTo(subcell, mstype.float32)) change = True - elif name == 'conv2': - subcell.to_float(mstype.float32) - change = True elif isinstance(subcell, (nn.BatchNorm2d, nn.BatchNorm1d)): network.insert_child_to_cell(name, OutputTo(subcell.to_float(mstype.float32), mstype.float16)) change = True diff --git a/model_zoo/official/cv/resnext50/train.py b/model_zoo/official/cv/resnext50/train.py index 6b0eaae03bf..d2cb72d5d21 100644 --- a/model_zoo/official/cv/resnext50/train.py +++ b/model_zoo/official/cv/resnext50/train.py @@ -36,7 +36,6 @@ from src.warmup_cosine_annealing_lr import warmup_cosine_annealing_lr from src.utils.logging import get_logger from src.utils.optimizers__init__ import get_param_groups from src.image_classification import get_network -from src.utils.auto_mixed_precision import auto_mixed_precision from src.config import config @@ -273,8 +272,8 @@ def train(cloud_args=None): model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager, metrics={'acc'}, amp_level="O3") else: - auto_mixed_precision(network) - model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager, metrics={'acc'}) + model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager, + metrics={'acc'}, amp_level="O2") # checkpoint save progress_cb = ProgressMonitor(args)