From 9ef6e72c8cd307b26b097cfc6f51d0be1686aea1 Mon Sep 17 00:00:00 2001
From: zhaoting <zhaoting23@huawei.com>
Date: Tue, 25 Aug 2020 17:09:25 +0800
Subject: [PATCH] change group conv dtype in gpu resnext50

---
 .../official/cv/resnext50/src/utils/auto_mixed_precision.py  | 3 ---
 model_zoo/official/cv/resnext50/train.py                     | 5 ++---
 2 files changed, 2 insertions(+), 6 deletions(-)

diff --git a/model_zoo/official/cv/resnext50/src/utils/auto_mixed_precision.py b/model_zoo/official/cv/resnext50/src/utils/auto_mixed_precision.py
index f8e27f5b522..18895f2b336 100644
--- a/model_zoo/official/cv/resnext50/src/utils/auto_mixed_precision.py
+++ b/model_zoo/official/cv/resnext50/src/utils/auto_mixed_precision.py
@@ -44,9 +44,6 @@ def auto_mixed_precision(network):
         elif name == 'fc':
             network.insert_child_to_cell(name, OutputTo(subcell, mstype.float32))
             change = True
-        elif name == 'conv2':
-            subcell.to_float(mstype.float32)
-            change = True
         elif isinstance(subcell, (nn.BatchNorm2d, nn.BatchNorm1d)):
             network.insert_child_to_cell(name, OutputTo(subcell.to_float(mstype.float32), mstype.float16))
             change = True
diff --git a/model_zoo/official/cv/resnext50/train.py b/model_zoo/official/cv/resnext50/train.py
index 6b0eaae03bf..d2cb72d5d21 100644
--- a/model_zoo/official/cv/resnext50/train.py
+++ b/model_zoo/official/cv/resnext50/train.py
@@ -36,7 +36,6 @@ from src.warmup_cosine_annealing_lr import warmup_cosine_annealing_lr
 from src.utils.logging import get_logger
 from src.utils.optimizers__init__ import get_param_groups
 from src.image_classification import get_network
-from src.utils.auto_mixed_precision import auto_mixed_precision
 from src.config import config
 
 
@@ -273,8 +272,8 @@ def train(cloud_args=None):
         model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager,
                       metrics={'acc'}, amp_level="O3")
     else:
-        auto_mixed_precision(network)
-        model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager, metrics={'acc'})
+        model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager,
+                      metrics={'acc'}, amp_level="O2")
 
     # checkpoint save
     progress_cb = ProgressMonitor(args)