diff --git a/mindspore/python/mindspore/train/amp.py b/mindspore/python/mindspore/train/amp.py
index f01d551e445..a12a6d55ec2 100644
--- a/mindspore/python/mindspore/train/amp.py
+++ b/mindspore/python/mindspore/train/amp.py
@@ -27,6 +27,24 @@
 from .. import boost
 from .. import context
 
+amp_white_list = (
+    nn.Dense,
+    nn.Conv1d,
+    nn.Conv2d,
+    nn.Conv3d,
+    nn.Conv1dTranspose,
+    nn.Conv2dTranspose,
+    nn.Conv3dTranspose
+)
+
+amp_black_list = (
+    nn.BatchNorm1d,
+    nn.BatchNorm2d,
+    nn.BatchNorm3d,
+    nn.LayerNorm
+)
+
+
 class _OutputTo16(nn.Cell):
     "Wrap cell for amp. Cast network output back to float16"
 
@@ -38,6 +56,72 @@ class _OutputTo16(nn.Cell):
         return F.cast(self._op(x), mstype.float16)
 
 
+class _OutputTo32(nn.Cell):
+    "Wrap cell for amp. Cast network output back to float32"
+
+    def __init__(self, op):
+        super(_OutputTo32, self).__init__(auto_prefix=False)
+        self._op = op
+
+    def construct(self, x):
+        return F.cast(self._op(x), mstype.float32)
+
+
+def _auto_white_list(network, white_list=None):
+    """process the white list of network."""
+    if white_list is None:
+        white_list = amp_white_list
+    cells = network.name_cells()
+    change = False
+    for name in cells:
+        subcell = cells[name]
+        if subcell == network:
+            continue
+        elif isinstance(subcell, white_list):
+            network._cells[name] = _OutputTo32(subcell.to_float(mstype.float16))
+            change = True
+        else:
+            _auto_white_list(subcell, white_list)
+    if isinstance(network, nn.SequentialCell) and change:
+        network.cell_list = list(network.cells())
+
+
+def _auto_black_list(network, black_list=None, keep_norm_fp32=False):
+    """process the black list of network."""
+    if black_list is None:
+        black_list = amp_black_list
+    network.to_float(mstype.float16)
+    cells = network.name_cells()
+    change = False
+    for name in cells:
+        subcell = cells[name]
+        if subcell == network:
+            continue
+        elif isinstance(subcell, black_list) and keep_norm_fp32:
+            network._cells[name] = _OutputTo16(subcell.to_float(mstype.float32))
+            change = True
+        else:
+            _auto_black_list(subcell, black_list, keep_norm_fp32)
+    if isinstance(network, nn.SequentialCell) and change:
+        network.cell_list = list(network.cells())
+
+
+def auto_mixed_precision(network, amp_level="O0", keep_norm_fp32=False):
+    """auto mixed precision function."""
+    if amp_level == "O0":
+        pass
+    elif amp_level == "O1":
+        _auto_white_list(network)
+    elif amp_level == "O2":
+        _auto_black_list(network, keep_norm_fp32=keep_norm_fp32)
+    elif amp_level == "O3":
+        network.to_float(mstype.float16)
+        if keep_norm_fp32:
+            _do_keep_batchnorm_fp32(network)
+    else:
+        raise ValueError("The amp level {} is not supported".format(amp_level))
+
+
 def _do_keep_batchnorm_fp32(network):
     """Do keep batchnorm fp32."""
     cells = network.name_cells()
@@ -46,7 +130,7 @@
         subcell = cells[name]
         if subcell == network:
             continue
-        elif isinstance(subcell, (nn.BatchNorm2d, nn.BatchNorm1d)):
+        elif isinstance(subcell, amp_black_list):
             network._cells[name] = _OutputTo16(subcell.to_float(mstype.float32))
             change = True
         else:
@@ -91,9 +175,9 @@ def _check_kwargs(key_words):
 def _check_level(level, boost_level):
     """Check level."""
     if not isinstance(level, str):
-        raise TypeError("The argument `level` must be a string in ['O0', 'O2', 'O3', 'auto'], \
+        raise TypeError("The argument `level` must be a string in ['O0', 'O1', 'O2', 'O3', 'auto'], \
                         but got type {}.".format(type(level)))
-    validator.check('level', level, "", ['O0', 'O2', 'O3', 'auto'], Rel.IN)
+    validator.check('level', level, "", ['O0', 'O1', 'O2', 'O3', 'auto'], Rel.IN)
     validator.check('boost_level', boost_level, "", ['O0', 'O1', 'O2'], Rel.IN)
 
     if level == "auto":
@@ -145,9 +229,10 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
         loss_fn (Union[None, Cell]): Define the loss function. If None, the `network` should have the loss inside.
             Default: None.
         optimizer (Optimizer): Define the optimizer to update the Parameter.
-        level (str): Supports ["O0", "O2", "O3", "auto"]. Default: "O0".
+        level (str): Supports ["O0", "O1", "O2", "O3", "auto"]. Default: "O0".
 
             - "O0": Do not change.
+            - "O1": Cast the operators in white_list to float16, the remaining operators are kept in float32.
             - "O2": Cast network to float16, keep batchnorm and `loss_fn` (if set) run in float32,
               using dynamic loss scale.
             - "O3": Cast network to float16, with additional property `keep_batchnorm_fp32=False` .
@@ -191,11 +276,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
     _check_kwargs(kwargs)
     config = dict(_config_level.get(level), **kwargs)
 
-    if config["cast_model_type"] == mstype.float16:
-        network.to_float(mstype.float16)
-
-        if config["keep_batchnorm_fp32"]:
-            _do_keep_batchnorm_fp32(network)
+    auto_mixed_precision(network, level, config["keep_batchnorm_fp32"])
 
     if loss_fn:
         network = _add_loss_network(network, loss_fn, config["cast_model_type"])
diff --git a/mindspore/python/mindspore/train/model.py b/mindspore/python/mindspore/train/model.py
index 31212ff7469..9921996f765 100644
--- a/mindspore/python/mindspore/train/model.py
+++ b/mindspore/python/mindspore/train/model.py
@@ -147,6 +147,7 @@ class Model:
             training. Supports ["O0", "O1", "O2"]. Default: "O0".
 
             - "O0": Do not change.
+            - "O1": Cast the operators in white_list to float16, the remaining operators are kept in float32.
             - "O1": Enable the boost mode, the performance is improved by about 20%, and
               the accuracy is the same as the original accuracy.
             - "O2": Enable the boost mode, the performance is improved by about 30%, and
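
Usage sketch (not part of the patch): a minimal illustration of the `auto_mixed_precision` helper added above, assuming this patch is applied. The toy networks below are placeholders invented for the example, not code from the change.

# Illustrative only; the networks are hypothetical placeholders.
import mindspore.nn as nn
from mindspore.train.amp import auto_mixed_precision

# A toy network mixing white-list cells (Dense) and black-list cells (BatchNorm1d).
net = nn.SequentialCell([nn.Dense(16, 32), nn.BatchNorm1d(32), nn.ReLU(), nn.Dense(32, 10)])

# "O1": white-listed cells are switched to float16 and wrapped with _OutputTo32,
# so their outputs are cast back to float32; all other cells stay in float32.
auto_mixed_precision(net, amp_level="O1")

# "O2": the whole network is cast to float16; with keep_norm_fp32=True the
# normalization cells from amp_black_list run in float32 and are wrapped with _OutputTo16.
net2 = nn.SequentialCell([nn.Dense(16, 32), nn.BatchNorm1d(32)])
auto_mixed_precision(net2, amp_level="O2", keep_norm_fp32=True)

build_train_network routes through the same helper after this change, so passing level="O2" or "O3" there applies the identical transformation before the loss and loss-scale wrappers are added.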