change group conv dtype in gpu resnext50

This commit is contained in:
zhaoting 2020-08-25 17:09:25 +08:00
parent 3a16925fa2
commit 9ef6e72c8c
2 changed files with 2 additions and 6 deletions

View File

@ -44,9 +44,6 @@ def auto_mixed_precision(network):
elif name == 'fc':
network.insert_child_to_cell(name, OutputTo(subcell, mstype.float32))
change = True
elif name == 'conv2':
subcell.to_float(mstype.float32)
change = True
elif isinstance(subcell, (nn.BatchNorm2d, nn.BatchNorm1d)):
network.insert_child_to_cell(name, OutputTo(subcell.to_float(mstype.float32), mstype.float16))
change = True

View File

@ -36,7 +36,6 @@ from src.warmup_cosine_annealing_lr import warmup_cosine_annealing_lr
from src.utils.logging import get_logger
from src.utils.optimizers__init__ import get_param_groups
from src.image_classification import get_network
from src.utils.auto_mixed_precision import auto_mixed_precision
from src.config import config
@ -273,8 +272,8 @@ def train(cloud_args=None):
model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager,
metrics={'acc'}, amp_level="O3")
else:
auto_mixed_precision(network)
model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager, metrics={'acc'})
model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager,
metrics={'acc'}, amp_level="O2")
# checkpoint save
progress_cb = ProgressMonitor(args)