forked from mindspore-Ecosystem/mindspore
change group conv dtype in gpu resnext50
This commit is contained in:
parent
3a16925fa2
commit
9ef6e72c8c
|
@ -44,9 +44,6 @@ def auto_mixed_precision(network):
|
|||
elif name == 'fc':
|
||||
network.insert_child_to_cell(name, OutputTo(subcell, mstype.float32))
|
||||
change = True
|
||||
elif name == 'conv2':
|
||||
subcell.to_float(mstype.float32)
|
||||
change = True
|
||||
elif isinstance(subcell, (nn.BatchNorm2d, nn.BatchNorm1d)):
|
||||
network.insert_child_to_cell(name, OutputTo(subcell.to_float(mstype.float32), mstype.float16))
|
||||
change = True
|
||||
|
|
|
@ -36,7 +36,6 @@ from src.warmup_cosine_annealing_lr import warmup_cosine_annealing_lr
|
|||
from src.utils.logging import get_logger
|
||||
from src.utils.optimizers__init__ import get_param_groups
|
||||
from src.image_classification import get_network
|
||||
from src.utils.auto_mixed_precision import auto_mixed_precision
|
||||
from src.config import config
|
||||
|
||||
|
||||
|
@ -273,8 +272,8 @@ def train(cloud_args=None):
|
|||
model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager,
|
||||
metrics={'acc'}, amp_level="O3")
|
||||
else:
|
||||
auto_mixed_precision(network)
|
||||
model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager, metrics={'acc'})
|
||||
model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager,
|
||||
metrics={'acc'}, amp_level="O2")
|
||||
|
||||
# checkpoint save
|
||||
progress_cb = ProgressMonitor(args)
|
||||
|
|
Loading…
Reference in New Issue