!39394 Add new wrap functions for auto mixed precision strategy
Merge pull request !39394 from Bert0108/new_amp_interface
commit e59582f242
@@ -27,6 +27,24 @@ from .. import boost
 from .. import context
 
+
+amp_white_list = (
+    nn.Dense,
+    nn.Conv1d,
+    nn.Conv2d,
+    nn.Conv3d,
+    nn.Conv1dTranspose,
+    nn.Conv2dTranspose,
+    nn.Conv3dTranspose
+)
+
+amp_black_list = (
+    nn.BatchNorm1d,
+    nn.BatchNorm2d,
+    nn.BatchNorm3d,
+    nn.LayerNorm
+)
 
 
 class _OutputTo16(nn.Cell):
     "Wrap cell for amp. Cast network output back to float16"
 
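Both lists are plain tuples of Cell classes, which is what lets the traversal functions below test membership with a single isinstance call. A minimal sketch (illustrative, not part of this diff):

    import mindspore.nn as nn

    # A tuple of classes is a valid second argument to isinstance, so one
    # call checks a subcell against every white-listed layer type at once.
    layer = nn.Dense(16, 8)
    print(isinstance(layer, (nn.Dense, nn.Conv2d)))      # True -> cast to float16
    print(isinstance(nn.ReLU(), (nn.Dense, nn.Conv2d)))  # False -> left in float32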
@@ -38,6 +56,72 @@ class _OutputTo16(nn.Cell):
         return F.cast(self._op(x), mstype.float16)
 
 
+class _OutputTo32(nn.Cell):
+    "Wrap cell for amp. Cast network output back to float32"
+
+    def __init__(self, op):
+        super(_OutputTo32, self).__init__(auto_prefix=False)
+        self._op = op
+
+    def construct(self, x):
+        return F.cast(self._op(x), mstype.float32)
+
+
+def _auto_white_list(network, white_list=None):
+    """process the white list of network."""
+    if white_list is None:
+        white_list = amp_white_list
+    cells = network.name_cells()
+    change = False
+    for name in cells:
+        subcell = cells[name]
+        if subcell == network:
+            continue
+        elif isinstance(subcell, white_list):
+            network._cells[name] = _OutputTo32(subcell.to_float(mstype.float16))
+            change = True
+        else:
+            _auto_white_list(subcell, white_list)
+    if isinstance(network, nn.SequentialCell) and change:
+        network.cell_list = list(network.cells())
+
+
+def _auto_black_list(network, black_list=None, keep_norm_fp32=False):
+    """process the black list of network."""
+    if black_list is None:
+        black_list = amp_black_list
+    network.to_float(mstype.float16)
+    cells = network.name_cells()
+    change = False
+    for name in cells:
+        subcell = cells[name]
+        if subcell == network:
+            continue
+        elif isinstance(subcell, black_list) and keep_norm_fp32:
+            network._cells[name] = _OutputTo16(subcell.to_float(mstype.float32))
+            change = True
+        else:
+            _auto_black_list(subcell, black_list, keep_norm_fp32)
+    if isinstance(network, nn.SequentialCell) and change:
+        network.cell_list = list(network.cells())
+
+
+def auto_mixed_precision(network, amp_level="O0", keep_norm_fp32=False):
+    """auto mixed precision function."""
+    if amp_level == "O0":
+        pass
+    elif amp_level == "O1":
+        _auto_white_list(network)
+    elif amp_level == "O2":
+        _auto_black_list(network, keep_norm_fp32=keep_norm_fp32)
+    elif amp_level == "O3":
+        network.to_float(mstype.float16)
+        if keep_norm_fp32:
+            _do_keep_batchnorm_fp32(network)
+    else:
+        raise ValueError("The amp level {} is not supported".format(amp_level))
+
+
 def _do_keep_batchnorm_fp32(network):
     """Do keep batchnorm fp32."""
     cells = network.name_cells()
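Taken together, the new functions give a standalone entry point: _auto_white_list wraps each white-list cell in _OutputTo32 after casting it to float16, and _auto_black_list does the inverse for black-list cells. A minimal usage sketch (illustrative, not part of this diff; the layer sizes are arbitrary and the import assumes the module is reachable as mindspore.train.amp):

    import mindspore.nn as nn
    from mindspore.train.amp import auto_mixed_precision  # assumed export path

    net = nn.SequentialCell([nn.Dense(16, 16), nn.ReLU(), nn.Dense(16, 10)])
    auto_mixed_precision(net, amp_level="O1")
    # Both Dense cells now compute in float16; each is wrapped in _OutputTo32,
    # so the tensor handed to ReLU (and the final output) is float32 again.
    # ReLU is on neither list and is left untouched.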
@@ -46,7 +130,7 @@ def _do_keep_batchnorm_fp32(network):
         subcell = cells[name]
         if subcell == network:
             continue
-        elif isinstance(subcell, (nn.BatchNorm2d, nn.BatchNorm1d)):
+        elif isinstance(subcell, amp_black_list):
             network._cells[name] = _OutputTo16(subcell.to_float(mstype.float32))
             change = True
         else:
@@ -91,9 +175,9 @@ def _check_kwargs(key_words):
 def _check_level(level, boost_level):
     """Check level."""
     if not isinstance(level, str):
-        raise TypeError("The argument `level` must be a string in ['O0', 'O2', 'O3', 'auto'], \
+        raise TypeError("The argument `level` must be a string in ['O0', 'O1', 'O2', 'O3', 'auto'], \
                         but got type {}.".format(type(level)))
-    validator.check('level', level, "", ['O0', 'O2', 'O3', 'auto'], Rel.IN)
+    validator.check('level', level, "", ['O0', 'O1', 'O2', 'O3', 'auto'], Rel.IN)
     validator.check('boost_level', boost_level, "", ['O0', 'O1', 'O2'], Rel.IN)
 
     if level == "auto":
@@ -145,9 +229,10 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
         loss_fn (Union[None, Cell]): Define the loss function. If None, the `network` should have the loss inside.
             Default: None.
         optimizer (Optimizer): Define the optimizer to update the Parameter.
-        level (str): Supports ["O0", "O2", "O3", "auto"]. Default: "O0".
+        level (str): Supports ["O0", "O1", "O2", "O3", "auto"]. Default: "O0".
 
             - "O0": Do not change.
+            - "O1": Cast the operators in white_list to float16, the remaining operators are kept in float32.
             - "O2": Cast network to float16, keep batchnorm and `loss_fn` (if set) run in float32,
               using dynamic loss scale.
             - "O3": Cast network to float16, with additional property `keep_batchnorm_fp32=False` .
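A usage sketch for the extended level argument (illustrative, not part of this diff; the loss and optimizer choices are arbitrary and the import assumes the same module path as above):

    import mindspore.nn as nn
    from mindspore.train.amp import build_train_network  # assumed export path

    net = nn.Dense(16, 10)
    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
    opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)

    # With the new "O1", only white-list cells (Dense, Conv*) are cast to
    # float16; everything else, including the loss, stays in float32.
    train_net = build_train_network(net, opt, loss, level="O1")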
@@ -191,11 +276,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
     _check_kwargs(kwargs)
     config = dict(_config_level.get(level), **kwargs)
 
-    if config["cast_model_type"] == mstype.float16:
-        network.to_float(mstype.float16)
-
-        if config["keep_batchnorm_fp32"]:
-            _do_keep_batchnorm_fp32(network)
-
+    auto_mixed_precision(network, level, config["keep_batchnorm_fp32"])
+
     if loss_fn:
         network = _add_loss_network(network, loss_fn, config["cast_model_type"])
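With this change build_train_network no longer casts the network inline: all four levels funnel through auto_mixed_precision, so the white-list and black-list handling added above is the single code path for both the functional API and the Model-based API below.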
@@ -147,6 +147,7 @@ class Model:
             training. Supports ["O0", "O1", "O2"]. Default: "O0".
 
             - "O0": Do not change.
+            - "O1": Cast the operators in white_list to float16, the remaining operators are kept in float32.
             - "O1": Enable the boost mode, the performance is improved by about 20%, and
               the accuracy is the same as the original accuracy.
             - "O2": Enable the boost mode, the performance is improved by about 30%, and
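The same level reaches the high-level API through Model's amp_level argument; a minimal sketch (illustrative, not part of this diff; network and hyperparameter choices are arbitrary):

    import mindspore.nn as nn
    from mindspore import Model

    net = nn.SequentialCell([nn.Dense(16, 16), nn.ReLU(), nn.Dense(16, 10)])
    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
    opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)

    # amp_level="O1" applies the white-list cast; boost_level keeps its
    # default "O0", so the boost-mode levels documented above stay off.
    model = Model(net, loss_fn=loss, optimizer=opt, amp_level="O1")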