forked from mindspore-Ecosystem/mindspore
!49041 混合精度增加用户自定义黑白名单接口
Merge pull request !49041 from GuoZhibin/add_support_of_amp_custom_list
This commit is contained in:
commit
14a4a97270
|
@ -27,12 +27,14 @@ from mindspore.parallel._utils import _get_pipeline_stages
|
|||
from mindspore.train.loss_scale_manager import DynamicLossScaleManager, LossScaleManager
|
||||
from mindspore import boost, context
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import Primitive
|
||||
from mindspore import log as logger
|
||||
|
||||
|
||||
STREE = None
|
||||
|
||||
|
||||
AMP_WHITE_LIST_Cell = (
|
||||
AMP_WHITE_LIST = [
|
||||
nn.Conv1d,
|
||||
nn.Conv2d,
|
||||
nn.Conv3d,
|
||||
|
@ -42,11 +44,7 @@ AMP_WHITE_LIST_Cell = (
|
|||
nn.Dense,
|
||||
nn.LSTMCell,
|
||||
nn.RNNCell,
|
||||
nn.GRUCell
|
||||
)
|
||||
|
||||
|
||||
AMP_WHITE_LIST_OPS = (
|
||||
nn.GRUCell,
|
||||
P.Conv2D,
|
||||
P.Conv3D,
|
||||
P.Conv2DTranspose,
|
||||
|
@ -57,15 +55,15 @@ AMP_WHITE_LIST_OPS = (
|
|||
P.PReLU,
|
||||
P.ReLU,
|
||||
P.Ger
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
AMP_BLACK_LIST = (
|
||||
AMP_BLACK_LIST = [
|
||||
nn.BatchNorm1d,
|
||||
nn.BatchNorm2d,
|
||||
nn.BatchNorm3d,
|
||||
nn.LayerNorm
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
class _OutputTo16(nn.Cell):
|
||||
|
@ -93,20 +91,25 @@ class _OutputTo32(nn.Cell):
|
|||
return F.mixed_precision_cast(mstype.float32, out)
|
||||
|
||||
|
||||
def _insert_cast_operator(stree):
|
||||
def _insert_cast_operator(stree, white_list):
|
||||
"""insert cast for operators in white_list."""
|
||||
new_cast_node = None
|
||||
for node in stree.nodes():
|
||||
if node.get_targets() is None:
|
||||
continue
|
||||
in_white_list = False
|
||||
if node.get_node_type() == ms.rewrite.NodeType.CellContainer:
|
||||
for n in node.get_handler().node_list:
|
||||
if n.get_node_type() == ms.rewrite.NodeType.Tree:
|
||||
_insert_cast_operator(ms.rewrite.TreeNodeHelper.get_sub_tree(ms.rewrite.Node(n)))
|
||||
elif node.get_node_type() != ms.rewrite.NodeType.Tree:
|
||||
_insert_cast_operator(ms.rewrite.TreeNodeHelper.get_sub_tree(ms.rewrite.Node(n)), white_list)
|
||||
elif node.get_node_type() == ms.rewrite.NodeType.Tree:
|
||||
substree = ms.rewrite.TreeNodeHelper.get_sub_tree(node)
|
||||
_insert_cast_operator(substree, white_list)
|
||||
else:
|
||||
if node.get_instance_type() not in white_list:
|
||||
continue
|
||||
in_white_list = False
|
||||
# insert cast before the primitive operators in white_list
|
||||
if node.get_instance_type() in AMP_WHITE_LIST_OPS:
|
||||
if issubclass(node.get_instance_type(), Primitive):
|
||||
in_white_list = True
|
||||
for idx in range(len(node.get_inputs())):
|
||||
position = stree.before(node)
|
||||
|
@ -120,7 +123,7 @@ def _insert_cast_operator(stree):
|
|||
stree.insert(position, new_cast_node)
|
||||
node.set_arg_by_node(idx, new_cast_node)
|
||||
# insert cast before the Cell operators in white_list
|
||||
elif node.get_instance_type() in AMP_WHITE_LIST_Cell:
|
||||
elif issubclass(node.get_instance_type(), nn.Cell):
|
||||
in_white_list = True
|
||||
node.get_instance().to_float(mstype.float16)
|
||||
|
||||
|
@ -139,9 +142,6 @@ def _insert_cast_operator(stree):
|
|||
stree.insert(position, new_cast_node)
|
||||
idx = follow_node.get_args().index(node.get_targets()[0])
|
||||
follow_node.set_arg_by_node(idx, new_cast_node)
|
||||
else:
|
||||
substree = ms.rewrite.TreeNodeHelper.get_sub_tree(node)
|
||||
_insert_cast_operator(substree)
|
||||
|
||||
|
||||
def _removed_cast_pair(node):
|
||||
|
@ -190,19 +190,17 @@ def _remove_duplicated_cast(stree):
|
|||
_remove_duplicated_cast(substree)
|
||||
|
||||
|
||||
def _auto_white_list(network):
|
||||
def _auto_white_list(network, white_list):
|
||||
"""process the white list of network."""
|
||||
global STREE
|
||||
STREE = ms.rewrite.SymbolTree.create(network)
|
||||
_insert_cast_operator(STREE)
|
||||
_insert_cast_operator(STREE, white_list)
|
||||
_remove_duplicated_cast(STREE)
|
||||
return STREE.get_network()
|
||||
|
||||
|
||||
def _auto_black_list(network, black_list=None):
|
||||
def _auto_black_list(network, black_list):
|
||||
"""process the black list of network."""
|
||||
if black_list is None:
|
||||
black_list = AMP_BLACK_LIST
|
||||
network.to_float(mstype.float16)
|
||||
cells = network.name_cells()
|
||||
change = False
|
||||
|
@ -210,7 +208,7 @@ def _auto_black_list(network, black_list=None):
|
|||
subcell = cells[name]
|
||||
if subcell == network:
|
||||
continue
|
||||
if isinstance(subcell, black_list):
|
||||
if isinstance(subcell, tuple(black_list)):
|
||||
network._cells[name] = _OutputTo16(subcell.to_float(mstype.float32))
|
||||
change = True
|
||||
else:
|
||||
|
@ -219,7 +217,83 @@ def _auto_black_list(network, black_list=None):
|
|||
network.cell_list = list(network.cells())
|
||||
|
||||
|
||||
def auto_mixed_precision(network, amp_level="O0"):
|
||||
def _custom_list_check(custom_list: {str: list}):
|
||||
"""
|
||||
check whether custom_list is valid
|
||||
|
||||
Raises:
|
||||
TypeError: The type of parameter custom_list is not dict.
|
||||
TypeError: The type of key in custom_list is not string.
|
||||
TypeError: The type of value in custom_list is not list.
|
||||
TypeError: The subclass of value in white_list is not one of ['Cell', 'Primitive'].
|
||||
TypeError: The subclass of value in black_list is not one of ['Cell', 'Primitive'].
|
||||
ValueError: The key in custom_list is not one of ['white_list', 'black_list'].
|
||||
ValueError: The white list and the black list have the same element.
|
||||
"""
|
||||
if custom_list is None:
|
||||
return
|
||||
|
||||
if not isinstance(custom_list, dict):
|
||||
raise TypeError(f"The type of parameter custom_list should be dict, but got {type(custom_list)}")
|
||||
|
||||
for key, value in custom_list.items():
|
||||
if key not in ("white_list", "black_list"):
|
||||
raise ValueError(f"The key in custom_list should be one of 'white_list' and 'black_list', but got {key}")
|
||||
if value is None:
|
||||
# internal list will be used if value is None
|
||||
continue
|
||||
if not isinstance(value, list):
|
||||
raise TypeError(f"The type of value in custom_list should be list, but got {type(value)}")
|
||||
if key == "white_list":
|
||||
for elem in value:
|
||||
if not issubclass(elem, nn.Cell) and not issubclass(elem, Primitive):
|
||||
raise TypeError(f"The subclass of value in white_list should be one of 'Cell' and 'Primitive', "
|
||||
f"but got {elem}")
|
||||
elif key == "black_list":
|
||||
for elem in value:
|
||||
if not issubclass(elem, nn.Cell) and not issubclass(elem, Primitive):
|
||||
raise TypeError(f"The subclass of value in black_list should be one of 'Cell' and 'Primitive', "
|
||||
f"but got {elem}")
|
||||
if 'white_list' in custom_list and 'black_list' in custom_list:
|
||||
elem_intersction = list(set(custom_list['white_list']).intersection(custom_list['black_list']))
|
||||
if elem_intersction:
|
||||
raise ValueError(f"{elem_intersction} cannot be in white list and black list at the same time")
|
||||
|
||||
|
||||
def _get_amp_lists(amp_level, custom_list: {str: list}):
    """
    Select the amp white list and black list to use.

    Starts from the module defaults and replaces either list with the matching
    non-None entry of custom_list. A warning is logged when a supplied list has
    no effect for the given amp_level, and, for level 'O2', for every default
    black list member that is missing from the supplied black list.

    Args:
        amp_level (str): The amp level the lists are selected for.
        custom_list (dict[str, list]): Optional custom lists keyed by
            'white_list' / 'black_list'; None selects the defaults.

    Returns:
        tuple, (white_list, black_list).
    """
    if custom_list is None:
        return AMP_WHITE_LIST, AMP_BLACK_LIST

    custom_white = custom_list.get('white_list')
    custom_black = custom_list.get('black_list')
    has_custom_white = custom_white is not None
    has_custom_black = custom_black is not None

    white_list = custom_white if has_custom_white else AMP_WHITE_LIST
    black_list = custom_black if has_custom_black else AMP_BLACK_LIST

    if amp_level in ('O0', 'O3'):
        # neither list is consulted at these levels
        if has_custom_white or has_custom_black:
            logger.warning(f"amp_level is {amp_level}, custom_list will not be used.")
    elif amp_level == 'O1':
        if has_custom_black:
            logger.warning(f"amp_level is {amp_level}, black_list in custom_list will not be used.")
    elif amp_level == 'O2':
        if has_custom_white:
            logger.warning(f"amp_level is {amp_level}, white_list in custom_list will not be used.")
        if has_custom_black:
            # report every default entry the custom black list leaves out
            dropped = [elem for elem in AMP_BLACK_LIST if elem not in custom_black]
            for elem in dropped:
                logger.warning(f"{elem} is removed from internal black list.")

    return white_list, black_list
|
||||
|
||||
|
||||
def auto_mixed_precision(network, amp_level="O0", custom_list: {str: list}=None):
|
||||
"""
|
||||
auto mixed precision function.
|
||||
|
||||
|
@ -232,6 +306,20 @@ def auto_mixed_precision(network, amp_level="O0"):
|
|||
- "O2": Cast network to float16, keep operators in black_list run in float32,
|
||||
- "O3": Cast network to float16.
|
||||
|
||||
custom_list (dict[str, list]): Use custom amp black list and amp white list instead of default lists.
|
||||
The type of key in custom_list is string, and supported keys are ["white_list", "black_list"].
|
||||
|
||||
- "white_list": The white list of auto mixed precision. Used when amp_level is O1.
|
||||
- "black_list": The black list of auto mixed precision. Used when amp_level is O2.
|
||||
|
||||
if custom_list is None, default white list and black list will be used.
|
||||
if custom_list["white_list"] is None, default white list will be used.
|
||||
if custom_list["black_list"] is None, default black list will be used.
|
||||
Format: custom_list = {"white_list":[Primitive or Cell], "black_list":[Primitive or Cell]}
|
||||
It is not recommended to delete members in the default black list.
|
||||
Default blacklist: [nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.LayerNorm]
|
||||
Default: None.
|
||||
|
||||
Raises:
|
||||
ValueError: If amp level is not supported.
|
||||
|
||||
|
@ -243,12 +331,16 @@ def auto_mixed_precision(network, amp_level="O0"):
|
|||
"""
|
||||
if not isinstance(network, nn.Cell):
|
||||
raise TypeError("The network type should be Cell.")
|
||||
|
||||
_custom_list_check(custom_list)
|
||||
white_list, black_list = _get_amp_lists(amp_level, custom_list)
|
||||
|
||||
if amp_level == "O0":
|
||||
pass
|
||||
elif amp_level == "O1":
|
||||
return _auto_white_list(network)
|
||||
return _auto_white_list(network, white_list)
|
||||
elif amp_level == "O2":
|
||||
_auto_black_list(network)
|
||||
_auto_black_list(network, black_list)
|
||||
elif amp_level == "O3":
|
||||
network.to_float(mstype.float16)
|
||||
else:
|
||||
|
@ -266,7 +358,7 @@ def _do_keep_batchnorm_fp32(network):
|
|||
subcell = cells[name]
|
||||
if subcell == network:
|
||||
continue
|
||||
elif isinstance(subcell, AMP_BLACK_LIST):
|
||||
elif isinstance(subcell, nn.Cell) and isinstance(subcell, tuple(AMP_BLACK_LIST)):
|
||||
network._cells[name] = _OutputTo16(subcell.to_float(mstype.float32))
|
||||
change = True
|
||||
else:
|
||||
|
@ -361,7 +453,8 @@ def _add_loss_network(network, loss_fn, cast_model_type):
|
|||
return network
|
||||
|
||||
|
||||
def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_level='O0', **kwargs):
|
||||
def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_level='O0',
|
||||
custom_list=None, **kwargs):
|
||||
"""
|
||||
Build the mixed precision training cell automatically.
|
||||
|
||||
|
@ -406,6 +499,19 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
|
|||
loss_scale_manager (Union[None, LossScaleManager]): If not None, must be subclass of
|
||||
:class:`mindspore.amp.LossScaleManager` for scaling the loss. If set, the `level` setting will
|
||||
take no effect on this property.
|
||||
custom_list (dict[str, list]): Use custom amp black list and amp white list instead of default lists.
|
||||
The type of key in custom_list is string, and supported keys are ["white_list", "black_list"].
|
||||
|
||||
- "white_list": The white list of auto mixed precision. Used when amp_level is O1.
|
||||
- "black_list": The black list of auto mixed precision. Used when amp_level is O2.
|
||||
|
||||
if custom_list is None, default white list and black list will be used.
|
||||
if custom_list["white_list"] is None, default white list will be used.
|
||||
if custom_list["black_list"] is None, default black list will be used.
|
||||
Format: custom_list = {"white_list":[Primitive or Cell], "black_list":[Primitive or Cell]}
|
||||
It is not recommended to delete members in the default black list.
|
||||
Default blacklist: [nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.LayerNorm]
|
||||
Default: None.
|
||||
|
||||
Raises:
|
||||
ValueError: If device is CPU, property `loss_scale_manager` is not `None` or `FixedLossScaleManager`
|
||||
|
@ -437,7 +543,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
|
|||
elif config["cast_model_type"] == mstype.float32 and level in ("O2", "O3"):
|
||||
pass
|
||||
else:
|
||||
network = auto_mixed_precision(network, level)
|
||||
network = auto_mixed_precision(network, level, custom_list)
|
||||
|
||||
if loss_fn:
|
||||
network = _add_loss_network(network, loss_fn, config["cast_model_type"])
|
||||
|
|
Loading…
Reference in New Issue