!49041 混合精度增加用户自定义黑白名单接口

Merge pull request !49041 from GuoZhibin/add_support_of_amp_custom_list
This commit is contained in:
i-robot 2023-03-03 10:10:36 +00:00 committed by Gitee
commit 14a4a97270
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 136 additions and 30 deletions

View File

@ -27,12 +27,14 @@ from mindspore.parallel._utils import _get_pipeline_stages
from mindspore.train.loss_scale_manager import DynamicLossScaleManager, LossScaleManager
from mindspore import boost, context
from mindspore.ops import operations as P
from mindspore.ops import Primitive
from mindspore import log as logger
STREE = None
AMP_WHITE_LIST_Cell = (
AMP_WHITE_LIST = [
nn.Conv1d,
nn.Conv2d,
nn.Conv3d,
@ -42,11 +44,7 @@ AMP_WHITE_LIST_Cell = (
nn.Dense,
nn.LSTMCell,
nn.RNNCell,
nn.GRUCell
)
AMP_WHITE_LIST_OPS = (
nn.GRUCell,
P.Conv2D,
P.Conv3D,
P.Conv2DTranspose,
@ -57,15 +55,15 @@ AMP_WHITE_LIST_OPS = (
P.PReLU,
P.ReLU,
P.Ger
)
]
AMP_BLACK_LIST = (
AMP_BLACK_LIST = [
nn.BatchNorm1d,
nn.BatchNorm2d,
nn.BatchNorm3d,
nn.LayerNorm
)
]
class _OutputTo16(nn.Cell):
@ -93,20 +91,25 @@ class _OutputTo32(nn.Cell):
return F.mixed_precision_cast(mstype.float32, out)
def _insert_cast_operator(stree):
def _insert_cast_operator(stree, white_list):
"""insert cast for operators in white_list."""
new_cast_node = None
for node in stree.nodes():
if node.get_targets() is None:
continue
in_white_list = False
if node.get_node_type() == ms.rewrite.NodeType.CellContainer:
for n in node.get_handler().node_list:
if n.get_node_type() == ms.rewrite.NodeType.Tree:
_insert_cast_operator(ms.rewrite.TreeNodeHelper.get_sub_tree(ms.rewrite.Node(n)))
elif node.get_node_type() != ms.rewrite.NodeType.Tree:
_insert_cast_operator(ms.rewrite.TreeNodeHelper.get_sub_tree(ms.rewrite.Node(n)), white_list)
elif node.get_node_type() == ms.rewrite.NodeType.Tree:
substree = ms.rewrite.TreeNodeHelper.get_sub_tree(node)
_insert_cast_operator(substree, white_list)
else:
if node.get_instance_type() not in white_list:
continue
in_white_list = False
# insert cast before the primitive operators in white_list
if node.get_instance_type() in AMP_WHITE_LIST_OPS:
if issubclass(node.get_instance_type(), Primitive):
in_white_list = True
for idx in range(len(node.get_inputs())):
position = stree.before(node)
@ -120,7 +123,7 @@ def _insert_cast_operator(stree):
stree.insert(position, new_cast_node)
node.set_arg_by_node(idx, new_cast_node)
# insert cast before the Cell operators in white_list
elif node.get_instance_type() in AMP_WHITE_LIST_Cell:
elif issubclass(node.get_instance_type(), nn.Cell):
in_white_list = True
node.get_instance().to_float(mstype.float16)
@ -139,9 +142,6 @@ def _insert_cast_operator(stree):
stree.insert(position, new_cast_node)
idx = follow_node.get_args().index(node.get_targets()[0])
follow_node.set_arg_by_node(idx, new_cast_node)
else:
substree = ms.rewrite.TreeNodeHelper.get_sub_tree(node)
_insert_cast_operator(substree)
def _removed_cast_pair(node):
@ -190,19 +190,17 @@ def _remove_duplicated_cast(stree):
_remove_duplicated_cast(substree)
def _auto_white_list(network):
def _auto_white_list(network, white_list):
"""process the white list of network."""
global STREE
STREE = ms.rewrite.SymbolTree.create(network)
_insert_cast_operator(STREE)
_insert_cast_operator(STREE, white_list)
_remove_duplicated_cast(STREE)
return STREE.get_network()
def _auto_black_list(network, black_list=None):
def _auto_black_list(network, black_list):
"""process the black list of network."""
if black_list is None:
black_list = AMP_BLACK_LIST
network.to_float(mstype.float16)
cells = network.name_cells()
change = False
@ -210,7 +208,7 @@ def _auto_black_list(network, black_list=None):
subcell = cells[name]
if subcell == network:
continue
if isinstance(subcell, black_list):
if isinstance(subcell, tuple(black_list)):
network._cells[name] = _OutputTo16(subcell.to_float(mstype.float32))
change = True
else:
@ -219,7 +217,83 @@ def _auto_black_list(network, black_list=None):
network.cell_list = list(network.cells())
def auto_mixed_precision(network, amp_level="O0"):
def _custom_list_check(custom_list: {str: list}):
"""
check whether custom_list is valid
Raises:
TypeError: The type of parameter custom_list is not dict.
TypeError: The type of key in custom_list is not string.
TypeError: The type of value in custom_list is not list.
TypeError: The subclass of value in white_list is not one of ['Cell', 'Primitive'].
TypeError: The subclass of value in black_list is not one of ['Cell', 'Primitive'].
ValueError: The key in custom_list is not one of ['white_list', 'black_list'].
ValueError: The white list and the black list have the same element.
"""
if custom_list is None:
return
if not isinstance(custom_list, dict):
raise TypeError(f"The type of parameter custom_list should be dict, but got {type(custom_list)}")
for key, value in custom_list.items():
if key not in ("white_list", "black_list"):
raise ValueError(f"The key in custom_list should be one of 'white_list' and 'black_list', but got {key}")
if value is None:
# internal list will be used if value is None
continue
if not isinstance(value, list):
raise TypeError(f"The type of value in custom_list should be list, but got {type(value)}")
if key == "white_list":
for elem in value:
if not issubclass(elem, nn.Cell) and not issubclass(elem, Primitive):
raise TypeError(f"The subclass of value in white_list should be one of 'Cell' and 'Primitive', "
f"but got {elem}")
elif key == "black_list":
for elem in value:
if not issubclass(elem, nn.Cell) and not issubclass(elem, Primitive):
raise TypeError(f"The subclass of value in black_list should be one of 'Cell' and 'Primitive', "
f"but got {elem}")
if 'white_list' in custom_list and 'black_list' in custom_list:
elem_intersction = list(set(custom_list['white_list']).intersection(custom_list['black_list']))
if elem_intersction:
raise ValueError(f"{elem_intersction} cannot be in white list and black list at the same time")
def _get_amp_lists(amp_level, custom_list: {str: list}):
""" get amp black list and white list from custom lists and default lists """
white_list_updated = False
black_list_updated = False
white_list = AMP_WHITE_LIST
black_list = AMP_BLACK_LIST
if custom_list is None:
return white_list, black_list
if 'white_list' in custom_list and custom_list['white_list'] is not None:
white_list = custom_list['white_list']
white_list_updated = True
if 'black_list' in custom_list and custom_list['black_list'] is not None:
black_list = custom_list['black_list']
black_list_updated = True
if amp_level in ('O0', 'O3'):
if white_list_updated or black_list_updated:
logger.warning(f"amp_level is {amp_level}, custom_list will not be used.")
elif amp_level == 'O1' and black_list_updated:
logger.warning(f"amp_level is {amp_level}, black_list in custom_list will not be used.")
elif amp_level == 'O2':
if white_list_updated:
logger.warning(f"amp_level is {amp_level}, white_list in custom_list will not be used.")
if black_list_updated:
for elem in AMP_BLACK_LIST:
if elem not in custom_list['black_list']:
logger.warning(f"{elem} is removed from internal black list.")
return white_list, black_list
def auto_mixed_precision(network, amp_level="O0", custom_list: {str: list}=None):
"""
auto mixed precision function.
@ -232,6 +306,20 @@ def auto_mixed_precision(network, amp_level="O0"):
- "O2": Cast network to float16, keep operators in black_list run in float32,
- "O3": Cast network to float16.
custom_list (dict[str, list]): Use custom amp black list and amp white list instead of default lists.
The type of key in custom_list is string, and supported keys are ["white_list", "black_list"].
- "white_list": The white list of auto mixed precision. Used when amp_level is O1.
- "black_list": The black list of auto mixed precision. Used when amp_level is O2.
if custom_list is None, default white list and black list will be used.
if custom_list["white_list"] is None, default white list will be used.
if custom_list["black_list"] is None, default black list will be used.
Format: custom_list = {"white_list":[Primitive or Cell], "black_list":[Primitive or Cell]}
It is not recommended to delete members in the default black list.
Default blacklist: [nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.LayerNorm]
Default: None.
Raises:
ValueError: If amp level is not supported.
@ -243,12 +331,16 @@ def auto_mixed_precision(network, amp_level="O0"):
"""
if not isinstance(network, nn.Cell):
raise TypeError("The network type should be Cell.")
_custom_list_check(custom_list)
white_list, black_list = _get_amp_lists(amp_level, custom_list)
if amp_level == "O0":
pass
elif amp_level == "O1":
return _auto_white_list(network)
return _auto_white_list(network, white_list)
elif amp_level == "O2":
_auto_black_list(network)
_auto_black_list(network, black_list)
elif amp_level == "O3":
network.to_float(mstype.float16)
else:
@ -266,7 +358,7 @@ def _do_keep_batchnorm_fp32(network):
subcell = cells[name]
if subcell == network:
continue
elif isinstance(subcell, AMP_BLACK_LIST):
elif isinstance(subcell, nn.Cell) and isinstance(subcell, tuple(AMP_BLACK_LIST)):
network._cells[name] = _OutputTo16(subcell.to_float(mstype.float32))
change = True
else:
@ -361,7 +453,8 @@ def _add_loss_network(network, loss_fn, cast_model_type):
return network
def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_level='O0', **kwargs):
def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_level='O0',
custom_list=None, **kwargs):
"""
Build the mixed precision training cell automatically.
@ -406,6 +499,19 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
loss_scale_manager (Union[None, LossScaleManager]): If not None, must be subclass of
:class:`mindspore.amp.LossScaleManager` for scaling the loss. If set, the `level` setting will
take no effect on this property.
custom_list (dict[str, list]): Use custom amp black list and amp white list instead of default lists.
The type of key in custom_list is string, and supported keys are ["white_list", "black_list"].
- "white_list": The white list of auto mixed precision. Used when amp_level is O1.
- "black_list": The black list of auto mixed precision. Used when amp_level is O2.
if custom_list is None, default white list and black list will be used.
if custom_list["white_list"] is None, default white list will be used.
if custom_list["black_list"] is None, default black list will be used.
Format: custom_list = {"white_list":[Primitive or Cell], "black_list":[Primitive or Cell]}
It is not recommended to delete members in the default black list.
Default blacklist: [nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.LayerNorm]
Default: None.
Raises:
ValueError: If device is CPU, property `loss_scale_manager` is not `None` or `FixedLossScaleManager`
@ -437,7 +543,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
elif config["cast_model_type"] == mstype.float32 and level in ("O2", "O3"):
pass
else:
network = auto_mixed_precision(network, level)
network = auto_mixed_precision(network, level, custom_list)
if loss_fn:
network = _add_loss_network(network, loss_fn, config["cast_model_type"])