!39394 Add new wrap functions for auto mixed precision strategy
Merge pull request !39394 from Bert0108/new_amp_interface
This commit is contained in:
commit
e59582f242
|
@ -27,6 +27,24 @@ from .. import boost
|
|||
from .. import context
|
||||
|
||||
|
||||
# Cell types promoted to float16 under amp_level "O1" (their outputs are
# cast back to float32 by _OutputTo32). Dense and convolution layers are
# numerically safe and fast in half precision.
amp_white_list = (
    nn.Dense,
    nn.Conv1d,
    nn.Conv2d,
    nn.Conv3d,
    nn.Conv1dTranspose,
    nn.Conv2dTranspose,
    nn.Conv3dTranspose
)
|
||||
|
||||
# Cell types kept in float32 under amp_level "O2" when keep_norm_fp32 is set
# (their outputs are cast back to float16 by _OutputTo16). Normalization
# layers are numerically sensitive in half precision.
amp_black_list = (
    nn.BatchNorm1d,
    nn.BatchNorm2d,
    nn.BatchNorm3d,
    nn.LayerNorm
)
|
||||
|
||||
|
||||
class _OutputTo16(nn.Cell):
|
||||
"Wrap cell for amp. Cast network output back to float16"
|
||||
|
||||
|
@ -38,6 +56,72 @@ class _OutputTo16(nn.Cell):
|
|||
return F.cast(self._op(x), mstype.float16)
|
||||
|
||||
|
||||
class _OutputTo32(nn.Cell):
    """Amp wrapper cell: run the wrapped op, then cast its output to float32."""

    def __init__(self, op):
        # auto_prefix=False keeps the wrapped cell's parameter names unchanged.
        super().__init__(auto_prefix=False)
        self._op = op

    def construct(self, x):
        output = self._op(x)
        return F.cast(output, mstype.float32)
|
||||
|
||||
|
||||
def _auto_white_list(network, white_list=None):
    """Recursively wrap white-listed subcells of `network` to run in float16.

    Each matching subcell is cast to float16 and wrapped in _OutputTo32 so its
    output returns to float32. `white_list` defaults to `amp_white_list`.
    """
    if white_list is None:
        white_list = amp_white_list
    modified = False
    for name, subcell in network.name_cells().items():
        if subcell == network:
            # name_cells() may report the cell itself; skip to avoid recursing forever.
            continue
        if isinstance(subcell, white_list):
            network._cells[name] = _OutputTo32(subcell.to_float(mstype.float16))
            modified = True
        else:
            _auto_white_list(subcell, white_list)
    if modified and isinstance(network, nn.SequentialCell):
        # Rebuild the cached cell list so the wrappers take effect in execution order.
        network.cell_list = list(network.cells())
|
||||
|
||||
|
||||
def _auto_black_list(network, black_list=None, keep_norm_fp32=False):
    """Cast `network` to float16, optionally keeping black-listed subcells in float32.

    When `keep_norm_fp32` is set, each black-listed subcell is forced back to
    float32 and wrapped in _OutputTo16 so its output re-enters the float16
    graph. `black_list` defaults to `amp_black_list`.
    """
    if black_list is None:
        black_list = amp_black_list
    network.to_float(mstype.float16)
    modified = False
    for name, subcell in network.name_cells().items():
        if subcell == network:
            # name_cells() may report the cell itself; skip to avoid recursing forever.
            continue
        if keep_norm_fp32 and isinstance(subcell, black_list):
            network._cells[name] = _OutputTo16(subcell.to_float(mstype.float32))
            modified = True
        else:
            _auto_black_list(subcell, black_list, keep_norm_fp32)
    if modified and isinstance(network, nn.SequentialCell):
        # Rebuild the cached cell list so the wrappers take effect in execution order.
        network.cell_list = list(network.cells())
|
||||
|
||||
|
||||
def auto_mixed_precision(network, amp_level="O0", keep_norm_fp32=False):
    """Apply automatic mixed precision to `network` in place.

    amp_level:
        "O0" - leave the network unchanged.
        "O1" - cast white-listed cells to float16.
        "O2" - cast the network to float16, optionally keeping black-listed
               cells in float32 (see `keep_norm_fp32`).
        "O3" - cast the whole network to float16; if `keep_norm_fp32` is set,
               batchnorm cells are restored to float32.
    Raises ValueError for any other level.
    """
    if amp_level == "O0":
        return
    if amp_level == "O1":
        _auto_white_list(network)
    elif amp_level == "O2":
        _auto_black_list(network, keep_norm_fp32=keep_norm_fp32)
    elif amp_level == "O3":
        network.to_float(mstype.float16)
        if keep_norm_fp32:
            _do_keep_batchnorm_fp32(network)
    else:
        raise ValueError("The amp level {} is not supported".format(amp_level))
|
||||
|
||||
|
||||
def _do_keep_batchnorm_fp32(network):
|
||||
"""Do keep batchnorm fp32."""
|
||||
cells = network.name_cells()
|
||||
|
@ -46,7 +130,7 @@ def _do_keep_batchnorm_fp32(network):
|
|||
subcell = cells[name]
|
||||
if subcell == network:
|
||||
continue
|
||||
elif isinstance(subcell, (nn.BatchNorm2d, nn.BatchNorm1d)):
|
||||
elif isinstance(subcell, amp_black_list):
|
||||
network._cells[name] = _OutputTo16(subcell.to_float(mstype.float32))
|
||||
change = True
|
||||
else:
|
||||
|
@ -91,9 +175,9 @@ def _check_kwargs(key_words):
|
|||
def _check_level(level, boost_level):
|
||||
"""Check level."""
|
||||
if not isinstance(level, str):
|
||||
raise TypeError("The argument `level` must be a string in ['O0', 'O2', 'O3', 'auto'], \
|
||||
raise TypeError("The argument `level` must be a string in ['O0', 'O1', 'O2', 'O3', 'auto'], \
|
||||
but got type {}.".format(type(level)))
|
||||
validator.check('level', level, "", ['O0', 'O2', 'O3', 'auto'], Rel.IN)
|
||||
validator.check('level', level, "", ['O0', 'O1', 'O2', 'O3', 'auto'], Rel.IN)
|
||||
validator.check('boost_level', boost_level, "", ['O0', 'O1', 'O2'], Rel.IN)
|
||||
|
||||
if level == "auto":
|
||||
|
@ -145,9 +229,10 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
|
|||
loss_fn (Union[None, Cell]): Define the loss function. If None, the `network` should have the loss inside.
|
||||
Default: None.
|
||||
optimizer (Optimizer): Define the optimizer to update the Parameter.
|
||||
level (str): Supports ["O0", "O2", "O3", "auto"]. Default: "O0".
|
||||
level (str): Supports ["O0", "O1", "O2", "O3", "auto"]. Default: "O0".
|
||||
|
||||
- "O0": Do not change.
|
||||
- "O1": Cast the operators in white_list to float16, the remaining operators are kept in float32.
|
||||
- "O2": Cast network to float16, keep batchnorm and `loss_fn` (if set) run in float32,
|
||||
using dynamic loss scale.
|
||||
- "O3": Cast network to float16, with additional property `keep_batchnorm_fp32=False` .
|
||||
|
@ -191,11 +276,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
|
|||
_check_kwargs(kwargs)
|
||||
config = dict(_config_level.get(level), **kwargs)
|
||||
|
||||
if config["cast_model_type"] == mstype.float16:
|
||||
network.to_float(mstype.float16)
|
||||
|
||||
if config["keep_batchnorm_fp32"]:
|
||||
_do_keep_batchnorm_fp32(network)
|
||||
auto_mixed_precision(network, level, config["keep_batchnorm_fp32"])
|
||||
|
||||
if loss_fn:
|
||||
network = _add_loss_network(network, loss_fn, config["cast_model_type"])
|
||||
|
|
|
@ -147,6 +147,7 @@ class Model:
|
|||
training. Supports ["O0", "O1", "O2"]. Default: "O0".
|
||||
|
||||
- "O0": Do not change.
|
||||
- "O1": Cast the operators in white_list to float16, the remaining operators are kept in float32.
|
||||
- "O1": Enable the boost mode, the performance is improved by about 20%, and
|
||||
the accuracy is the same as the original accuracy.
|
||||
- "O2": Enable the boost mode, the performance is improved by about 30%, and
|
||||
|
|
Loading…
Reference in New Issue