!39394 Add new wrap functions for auto mixed precision strategy

Merge pull request !39394 from Bert0108/new_amp_interface
i-robot 2022-08-05 08:34:06 +00:00 committed by Gitee
commit e59582f242
2 changed files with 91 additions and 9 deletions


@@ -27,6 +27,24 @@ from .. import boost
from .. import context
amp_white_list = (
    nn.Dense,
    nn.Conv1d,
    nn.Conv2d,
    nn.Conv3d,
    nn.Conv1dTranspose,
    nn.Conv2dTranspose,
    nn.Conv3dTranspose
)

amp_black_list = (
    nn.BatchNorm1d,
    nn.BatchNorm2d,
    nn.BatchNorm3d,
    nn.LayerNorm
)
class _OutputTo16(nn.Cell):
    "Wrap cell for amp. Cast network output back to float16"
@@ -38,6 +56,72 @@ class _OutputTo16(nn.Cell):
        return F.cast(self._op(x), mstype.float16)

class _OutputTo32(nn.Cell):
    "Wrap cell for amp. Cast network output back to float32"

    def __init__(self, op):
        super(_OutputTo32, self).__init__(auto_prefix=False)
        self._op = op

    def construct(self, x):
        return F.cast(self._op(x), mstype.float32)
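
(Aside, not part of this diff: the pattern below shows how these wrappers are combined with `Cell.to_float` — the wrapped cell computes in float16 but hands float32 to the next layer. A minimal sketch, assuming `_OutputTo32` from this file is in scope:)

    import numpy as np
    import mindspore as ms
    from mindspore import nn, Tensor

    dense = nn.Dense(16, 10)
    wrapped = _OutputTo32(dense.to_float(ms.float16))  # Dense runs in float16, output cast back to float32

    x = Tensor(np.random.randn(4, 16).astype(np.float32))
    print(wrapped(x).dtype)  # expected: Float32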

def _auto_white_list(network, white_list=None):
    """process the white list of network."""
    if white_list is None:
        white_list = amp_white_list
    cells = network.name_cells()
    change = False
    for name in cells:
        subcell = cells[name]
        if subcell == network:
            continue
        elif isinstance(subcell, white_list):
            network._cells[name] = _OutputTo32(subcell.to_float(mstype.float16))
            change = True
        else:
            _auto_white_list(subcell, white_list)
    if isinstance(network, nn.SequentialCell) and change:
        network.cell_list = list(network.cells())
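
(Aside, illustrative only: the effect of `_auto_white_list` can be seen by listing the child cells of a small network after the pass; this sketch assumes the helpers above are in scope.)

    net = nn.SequentialCell([nn.Dense(8, 8), nn.ReLU()])
    _auto_white_list(net)
    for name, cell in net.name_cells().items():
        print(name, type(cell).__name__)
    # the nn.Dense child is expected to be replaced by _OutputTo32; nn.ReLU is left untouched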

def _auto_black_list(network, black_list=None, keep_norm_fp32=False):
    """process the black list of network."""
    if black_list is None:
        black_list = amp_black_list
    network.to_float(mstype.float16)
    cells = network.name_cells()
    change = False
    for name in cells:
        subcell = cells[name]
        if subcell == network:
            continue
        elif isinstance(subcell, black_list) and keep_norm_fp32:
            network._cells[name] = _OutputTo16(subcell.to_float(mstype.float32))
            change = True
        else:
            _auto_black_list(subcell, black_list, keep_norm_fp32)
    if isinstance(network, nn.SequentialCell) and change:
        network.cell_list = list(network.cells())

def auto_mixed_precision(network, amp_level="O0", keep_norm_fp32=False):
    """auto mixed precision function."""
    if amp_level == "O0":
        pass
    elif amp_level == "O1":
        _auto_white_list(network)
    elif amp_level == "O2":
        _auto_black_list(network, keep_norm_fp32=keep_norm_fp32)
    elif amp_level == "O3":
        network.to_float(mstype.float16)
        if keep_norm_fp32:
            _do_keep_batchnorm_fp32(network)
    else:
        raise ValueError("The amp level {} is not supported".format(amp_level))
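
(Aside, not part of this diff: a hedged sketch of how the new entry point is meant to be called, using a toy network.)

    net = nn.SequentialCell([nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()])

    # O1: only white-list cells (here the Conv2d) are cast to float16 and wrapped to emit float32.
    auto_mixed_precision(net, amp_level="O1")

    # O2: the whole network is cast to float16; with keep_norm_fp32=True the BatchNorm is
    # kept in float32 and wrapped to emit float16 (apply to a fresh network, not on top of O1).
    # auto_mixed_precision(net, amp_level="O2", keep_norm_fp32=True)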

def _do_keep_batchnorm_fp32(network):
    """Do keep batchnorm fp32."""
    cells = network.name_cells()
@@ -46,7 +130,7 @@ def _do_keep_batchnorm_fp32(network):
        subcell = cells[name]
        if subcell == network:
            continue
-        elif isinstance(subcell, (nn.BatchNorm2d, nn.BatchNorm1d)):
+        elif isinstance(subcell, amp_black_list):
            network._cells[name] = _OutputTo16(subcell.to_float(mstype.float32))
            change = True
        else:
@@ -91,9 +175,9 @@ def _check_kwargs(key_words):
def _check_level(level, boost_level):
    """Check level."""
    if not isinstance(level, str):
-        raise TypeError("The argument `level` must be a string in ['O0', 'O2', 'O3', 'auto'], \
+        raise TypeError("The argument `level` must be a string in ['O0', 'O1', 'O2', 'O3', 'auto'], \
            but got type {}.".format(type(level)))
-    validator.check('level', level, "", ['O0', 'O2', 'O3', 'auto'], Rel.IN)
+    validator.check('level', level, "", ['O0', 'O1', 'O2', 'O3', 'auto'], Rel.IN)
    validator.check('boost_level', boost_level, "", ['O0', 'O1', 'O2'], Rel.IN)
    if level == "auto":
@@ -145,9 +229,10 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
        loss_fn (Union[None, Cell]): Define the loss function. If None, the `network` should have the loss inside.
            Default: None.
        optimizer (Optimizer): Define the optimizer to update the Parameter.
-        level (str): Supports ["O0", "O2", "O3", "auto"]. Default: "O0".
+        level (str): Supports ["O0", "O1", "O2", "O3", "auto"]. Default: "O0".

            - "O0": Do not change.
+            - "O1": Cast the operators in white_list to float16, the remaining operators are kept in float32.
            - "O2": Cast network to float16, keep batchnorm and `loss_fn` (if set) run in float32,
              using dynamic loss scale.
            - "O3": Cast network to float16, with additional property `keep_batchnorm_fp32=False` .
@@ -191,11 +276,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
    _check_kwargs(kwargs)
    config = dict(_config_level.get(level), **kwargs)
    if config["cast_model_type"] == mstype.float16:
-        network.to_float(mstype.float16)
-        if config["keep_batchnorm_fp32"]:
-            _do_keep_batchnorm_fp32(network)
+        auto_mixed_precision(network, level, config["keep_batchnorm_fp32"])
    if loss_fn:
        network = _add_loss_network(network, loss_fn, config["cast_model_type"])


@@ -147,6 +147,7 @@ class Model:
            training. Supports ["O0", "O1", "O2"]. Default: "O0".

            - "O0": Do not change.
+            - "O1": Cast the operators in white_list to float16, the remaining operators are kept in float32.
            - "O1": Enable the boost mode, the performance is improved by about 20%, and
              the accuracy is the same as the original accuracy.
            - "O2": Enable the boost mode, the performance is improved by about 30%, and