From 8e26ebdfbc7855ca2b8947a6a98b878c3edff3fc Mon Sep 17 00:00:00 2001
From: Bert0108
Date: Thu, 8 Sep 2022 16:48:31 +0800
Subject: [PATCH] add AMP O1 level

---
 mindspore/python/mindspore/__init__.py       |   2 +
 mindspore/python/mindspore/train/amp.py      | 150 ++++++++++++++++-----
 tests/st/mix_precision/test_mix_precision.py |  45 ++++++
 3 files changed, 172 insertions(+), 25 deletions(-)

diff --git a/mindspore/python/mindspore/__init__.py b/mindspore/python/mindspore/__init__.py
index 2909bf044bb..e67f0006ba2 100755
--- a/mindspore/python/mindspore/__init__.py
+++ b/mindspore/python/mindspore/__init__.py
@@ -29,6 +29,7 @@ from mindspore.context import GRAPH_MODE, PYNATIVE_MODE, set_context, get_context
 from mindspore.version import __version__
 from mindspore.profiler import Profiler
 from mindspore.parallel import set_algo_parameters, get_algo_parameters, reset_algo_parameters
+from mindspore.rewrite import SymbolTree, ScopedValue, Node, NodeType, TreeNodeHelper
 
 __all__ = ["run_check"]
 
@@ -38,4 +39,5 @@ __all__.extend(train.__all__)
 __all__.extend(log.__all__)
 __all__.extend(context.__all__)
 __all__.extend(parallel.__all__)
+__all__.extend(rewrite.__all__)
 __all__.append("Profiler")
diff --git a/mindspore/python/mindspore/train/amp.py b/mindspore/python/mindspore/train/amp.py
index 21116b0dfcb..5a3b46eb07f 100644
--- a/mindspore/python/mindspore/train/amp.py
+++ b/mindspore/python/mindspore/train/amp.py
@@ -15,6 +15,7 @@
 """Auto mixed precision."""
 from __future__ import absolute_import
 
+import mindspore as ms
 from mindspore import nn
 from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
@@ -25,18 +26,40 @@ from mindspore.ops import functional as F
 from mindspore.parallel._utils import _get_pipeline_stages
 from mindspore.train.loss_scale_manager import DynamicLossScaleManager, LossScaleManager
 from mindspore import boost, context
+from mindspore.ops import operations as P
 
-AMP_WHITE_LIST = (
-    nn.Dense,
+STREE = None
+
+
+AMP_WHITE_LIST_CELL = (
     nn.Conv1d,
     nn.Conv2d,
     nn.Conv3d,
     nn.Conv1dTranspose,
     nn.Conv2dTranspose,
-    nn.Conv3dTranspose
+    nn.Conv3dTranspose,
+    nn.Dense,
+    nn.LSTMCell,
+    nn.RNNCell,
+    nn.GRUCell
 )
 
+AMP_WHITE_LIST_OPS = (
+    P.Conv2D,
+    P.Conv3D,
+    P.Conv2DTranspose,
+    P.Conv3DTranspose,
+    P.Conv2DBackpropInput,
+    P.MatMul,
+    P.BatchMatMul,
+    P.PReLU,
+    P.ReLU,
+    P.Ger
+)
+
+
 AMP_BLACK_LIST = (
     nn.BatchNorm1d,
     nn.BatchNorm2d,
@@ -67,23 +90,102 @@ class _OutputTo32(nn.Cell):
         return F.cast(self._op(x), mstype.float32)
 
 
-def _auto_white_list(network, white_list=None):
-    """process the white list of network."""
-    if white_list is None:
-        white_list = AMP_WHITE_LIST
-    cells = network.name_cells()
-    change = False
-    for name in cells:
-        subcell = cells[name]
-        if subcell == network:
-            continue
-        if isinstance(subcell, white_list):
-            network._cells[name] = _OutputTo32(subcell.to_float(mstype.float16))
-            change = True
-        else:
-            _auto_white_list(subcell, white_list)
-    if isinstance(network, nn.SequentialCell) and change:
-        network.cell_list = list(network.cells())
+def _insert_cast_operator(stree):
+    """Insert cast operators before and after the nodes in the white lists."""
+    new_cast_node = None
+    for node in stree.nodes():
+        if node.get_targets() is None:
+            continue
+        in_white_list = False
+        if node.get_node_type() != ms.rewrite.NodeType.Tree:
+            # insert cast to float16 before the primitive operators in the white list
+            if node.get_instance_type() in AMP_WHITE_LIST_OPS:
+                in_white_list = True
+                for idx in range(len(node.get_inputs())):
+                    position = stree.before(node)
+                    new_node = P.Cast()
+                    arg = ms.rewrite.ScopedValue.create_name_values([node.get_inputs()[idx].get_targets()[0].value,
+                                                                     "mindspore.float16"])
+                    # one unique target per input, so the casts do not shadow each other
+                    new_cast_node = ms.rewrite.Node.create_call_cell(new_node,
+                                                                     targets=['x_cast_{}_{}'.format(node.get_name(), idx)],
+                                                                     args=arg,
+                                                                     name='incast_{}{}'.format(node.get_name(), idx))
+                    stree.insert(position, new_cast_node)
+                    node.set_arg_by_node(idx, new_cast_node)
+            # run the Cell operators in the white list in float16
+            elif node.get_instance_type() in AMP_WHITE_LIST_CELL:
+                in_white_list = True
+                node.get_instance().to_float(mstype.float16)
+
+            # insert cast back to float32 after the operators in the white lists
+            if in_white_list:
+                # snapshot the users before the insertion; afterwards the new
+                # cast node would itself show up in node.get_users()
+                follow_nodes = list(node.get_users())
+                position = stree.after(node)
+                new_node = P.Cast()
+                arg = ms.rewrite.ScopedValue.create_name_values([node.get_targets()[0].value,
+                                                                 "mindspore.float32"])
+                new_cast_node = ms.rewrite.Node.create_call_cell(new_node,
+                                                                 targets=['x_cast_{}'.format(node.get_name())],
+                                                                 args=arg,
+                                                                 name='outcast_{}'.format(node.get_name()))
+                stree.insert(position, new_cast_node)
+                for follow_node in follow_nodes:
+                    idx = follow_node.get_args().index(node.get_targets()[0])
+                    follow_node.set_arg_by_node(idx, new_cast_node)
+        else:
+            substree = ms.rewrite.TreeNodeHelper.get_sub_tree(node)
+            _insert_cast_operator(substree)
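+
+
+# Illustrative sketch (not generated code): for a white-listed primitive node
+# named `matmul` computing `y = matmul(a, b)`, the pass above produces
+#     x_cast_matmul_0 = cast(a, mindspore.float16)
+#     x_cast_matmul_1 = cast(b, mindspore.float16)
+#     y = matmul(x_cast_matmul_0, x_cast_matmul_1)
+#     x_cast_matmul = cast(y, mindspore.float32)
+# and every former user of `y` reads `x_cast_matmul` instead.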
+
+
+def _removed_cast_pair(node):
+    """Check whether the cast node and all the casts consuming it form removable pairs."""
+    follow_nodes = node.get_users()
+    if not follow_nodes:
+        return False
+    # every user of the cast must be another cast
+    for follow_node in follow_nodes:
+        if follow_node.get_instance_type() != P.Cast:
+            return False
+    node_dtype = node.get_args()[1]
+    follow_node_dtype = follow_nodes[0].get_args()[1]
+    # all the following casts must share one target dtype
+    for follow_node in follow_nodes[1:]:
+        if follow_node.get_args()[1] != follow_node_dtype:
+            return False
+    # the two casts cancel out only when their dtypes differ
+    return follow_node_dtype != node_dtype
+
+
+def _remove_duplicated_cast(stree):
+    """Remove the redundant cast pairs left by _insert_cast_operator."""
+    for node in stree.nodes():
+        if node.get_targets() is None:
+            continue
+        if node.get_node_type() != ms.rewrite.NodeType.Tree:
+            if node.get_instance_type() == P.Cast and _removed_cast_pair(node):
+                # remove the following cast nodes first, reconnecting their
+                # users to the input of the current cast
+                for follow_node in list(node.get_users()):
+                    for n in list(follow_node.get_users()):
+                        idx = n.get_args().index(follow_node.get_targets()[0])
+                        n.set_arg_by_node(idx, node.get_inputs()[0])
+                    stree.erase_node(follow_node)
+                # then remove the current cast node
+                stree.erase_node(node)
+        else:
+            substree = ms.rewrite.TreeNodeHelper.get_sub_tree(node)
+            _remove_duplicated_cast(substree)
+
+
+def _auto_white_list(network):
+    """Process the white list of the network with the rewrite tool."""
+    global STREE
+    STREE = ms.rewrite.SymbolTree.create(network)
+    _insert_cast_operator(STREE)
+    _remove_duplicated_cast(STREE)
+    return STREE.get_network()
 
 
 def _auto_black_list(network, black_list=None):
@@ -125,13 +227,14 @@ def auto_mixed_precision(network, amp_level="O0"):
     if amp_level == "O0":
         pass
     elif amp_level == "O1":
-        _auto_white_list(network)
+        return _auto_white_list(network)
     elif amp_level == "O2":
         _auto_black_list(network)
     elif amp_level == "O3":
         network.to_float(mstype.float16)
     else:
         raise ValueError("The amp level {} is not supported".format(amp_level))
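+    # O1 hands back the network rebuilt by SymbolTree, so callers must use the
+    # returned network; the other levels modify `network` in place.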
network.""" class WithLossCell(nn.Cell): @@ -233,10 +336,7 @@ def _add_loss_network(network, loss_fn, cast_model_type): return self._loss_fn(F.mixed_precision_cast(mstype.float32, out), label) validator.check_value_type('loss_fn', loss_fn, nn.Cell) - if cast_model_type == mstype.float16: - network = WithLossCell(network, loss_fn) - else: - network = nn.WithLossCell(network, loss_fn) + network = WithLossCell(network, loss_fn) return network @@ -304,7 +404,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve auto_mixed_precision(network, level) if loss_fn: - network = _add_loss_network(network, loss_fn, config["cast_model_type"]) + network = _add_loss_network(network, loss_fn) loss_scale = 1.0 if config["loss_scale_manager"] is not None: diff --git a/tests/st/mix_precision/test_mix_precision.py b/tests/st/mix_precision/test_mix_precision.py index 2114145da8b..1583e9345f1 100644 --- a/tests/st/mix_precision/test_mix_precision.py +++ b/tests/st/mix_precision/test_mix_precision.py @@ -179,3 +179,48 @@ def test_sit_auto_mix_precision_model_o2(): model_pynative.train(1, dataset2, dataset_sink_mode=False) out_pynative = model_pynative.predict(Tensor(input_data)) allclose_nparray(out_graph.asnumpy(), out_pynative.asnumpy(), 0.001, 0.001) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +@security_off_wrap +def test_sit_auto_mix_precision_model_o1(): + """ + Feature: Test the O1 level auto mixed precision + Description: input O1 level to Model interface + Expectation: success. + """ + input_data = np.random.randn(32, 3, 224, 224).astype(np.float32) + dataset1 = FakeData(size=32, + batch_size=32, + image_size=(3, 224, 224), + num_classes=10, + fakedata_mode=FakeDataInitMode.OnesInit) + dataset2 = FakeData(size=32, + batch_size=32, + image_size=(3, 224, 224), + num_classes=10, + fakedata_mode=FakeDataInitMode.OnesInit) + # graph mode + context.set_context(mode=context.GRAPH_MODE) + context.set_context(save_graphs=True, save_graphs_path='./test_amp_o1') + net = Net(3, 10) + opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.001, momentum=0.0009) + loss = nn.SoftmaxCrossEntropyWithLogits(sparse=False) + model = Model(net, loss, opt, amp_level="O1") + model.train(1, dataset1, dataset_sink_mode=False) + clean_all_ir_files('./test_amp_o1/') + out_graph = model.predict(Tensor(input_data)) + + # pynative mode + context.set_context(mode=context.PYNATIVE_MODE) + net_pynative = Net(3, 10) + opt_pynative = nn.Momentum(params=net_pynative.trainable_params(), learning_rate=0.001, momentum=0.0009) + loss_pynative = nn.SoftmaxCrossEntropyWithLogits(sparse=False) + model_pynative = Model(net_pynative, loss_pynative, opt_pynative, amp_level="O1") + model_pynative.train(1, dataset2, dataset_sink_mode=False) + out_pynative = model_pynative.predict(Tensor(input_data)) + allclose_nparray(out_graph.asnumpy(), out_pynative.asnumpy(), 0.001, 0.001)