add amp o1 level

This commit is contained in:
Bert0108 2022-09-08 16:48:31 +08:00 committed by YouhuiBai
parent 3324fea63f
commit 8e26ebdfbc
3 changed files with 172 additions and 25 deletions

View File

@ -29,6 +29,7 @@ from mindspore.context import GRAPH_MODE, PYNATIVE_MODE, set_context, get_contex
from mindspore.version import __version__ from mindspore.version import __version__
from mindspore.profiler import Profiler from mindspore.profiler import Profiler
from mindspore.parallel import set_algo_parameters, get_algo_parameters, reset_algo_parameters from mindspore.parallel import set_algo_parameters, get_algo_parameters, reset_algo_parameters
from mindspore.rewrite import SymbolTree, ScopedValue, Node, NodeType, TreeNodeHelper
__all__ = ["run_check"] __all__ = ["run_check"]
@ -38,4 +39,5 @@ __all__.extend(train.__all__)
__all__.extend(log.__all__) __all__.extend(log.__all__)
__all__.extend(context.__all__) __all__.extend(context.__all__)
__all__.extend(parallel.__all__) __all__.extend(parallel.__all__)
__all__.extend(rewrite.__all__)
__all__.append("Profiler") __all__.append("Profiler")

View File

@ -15,6 +15,7 @@
"""Auto mixed precision.""" """Auto mixed precision."""
from __future__ import absolute_import from __future__ import absolute_import
import mindspore as ms
from mindspore import nn from mindspore import nn
from mindspore._checkparam import Validator as validator from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel from mindspore._checkparam import Rel
@ -25,18 +26,40 @@ from mindspore.ops import functional as F
from mindspore.parallel._utils import _get_pipeline_stages from mindspore.parallel._utils import _get_pipeline_stages
from mindspore.train.loss_scale_manager import DynamicLossScaleManager, LossScaleManager from mindspore.train.loss_scale_manager import DynamicLossScaleManager, LossScaleManager
from mindspore import boost, context from mindspore import boost, context
from mindspore.ops import operations as P
AMP_WHITE_LIST = ( STREE = None
nn.Dense,
AMP_WHITE_LIST_Cell = (
nn.Conv1d, nn.Conv1d,
nn.Conv2d, nn.Conv2d,
nn.Conv3d, nn.Conv3d,
nn.Conv1dTranspose, nn.Conv1dTranspose,
nn.Conv2dTranspose, nn.Conv2dTranspose,
nn.Conv3dTranspose nn.Conv3dTranspose,
nn.Dense,
nn.LSTMCell,
nn.RNNCell,
nn.GRUCell
) )
AMP_WHITE_LIST_OPS = (
P.Conv2D,
P.Conv3D,
P.Conv2DTranspose,
P.Conv3DTranspose,
P.Conv2DBackpropInput,
P.MatMul,
P.BatchMatMul,
P.PReLU,
P.ReLU,
P.Ger
)
AMP_BLACK_LIST = ( AMP_BLACK_LIST = (
nn.BatchNorm1d, nn.BatchNorm1d,
nn.BatchNorm2d, nn.BatchNorm2d,
@ -67,23 +90,102 @@ class _OutputTo32(nn.Cell):
return F.cast(self._op(x), mstype.float32) return F.cast(self._op(x), mstype.float32)
def _auto_white_list(network, white_list=None): def _insert_cast_operator(stree):
"""process the white list of network.""" """insert cast for operators in white_list."""
if white_list is None: new_cast_node = None
white_list = AMP_WHITE_LIST for node in stree.nodes():
cells = network.name_cells() if node.get_targets() is None:
change = False
for name in cells:
subcell = cells[name]
if subcell == network:
continue continue
if isinstance(subcell, white_list): in_white_list = False
network._cells[name] = _OutputTo32(subcell.to_float(mstype.float16)) if node.get_node_type() != ms.rewrite.NodeType.Tree:
change = True # insert cast before the primitive operators in white_list
if node.get_instance_type() in AMP_WHITE_LIST_OPS:
in_white_list = True
for idx in range(len(node.get_inputs())):
position = stree.before(node)
new_node = P.Cast()
arg = ms.rewrite.ScopedValue.create_name_values([node.get_inputs()[idx].get_targets()[0].value,
"mindspore.float16"])
new_cast_node = ms.rewrite.Node.create_call_cell(new_node,
targets=['x_cast_{}'.format(node.get_name())],
args=arg,
name='incast_{}{}'.format(node.get_name(), idx))
stree.insert(position, new_cast_node)
node.set_arg_by_node(idx, new_cast_node)
# insert cast before the Cell operators in white_list
elif node.get_instance_type() in AMP_WHITE_LIST_Cell:
in_white_list = True
node.get_instance().to_float(mstype.float16)
# insert cast after the operators in white_list
if in_white_list:
position = stree.after(node)
new_node = P.Cast()
arg = ms.rewrite.ScopedValue.create_name_values([node.get_targets()[0].value,
"mindspore.float32"])
new_cast_node = ms.rewrite.Node.create_call_cell(new_node,
targets=['x_cast_{}'.format(node.get_name())],
args=arg,
name='outcast_{}'.format(node.get_name()))
for i in range(len(node.get_users())):
follow_node = node.get_users()[i]
stree.insert(position, new_cast_node)
idx = follow_node.get_args().index(node.get_targets()[0])
follow_node.set_arg_by_node(idx, new_cast_node)
else: else:
_auto_white_list(subcell, white_list) substree = ms.rewrite.TreeNodeHelper.get_sub_tree(node)
if isinstance(network, nn.SequentialCell) and change: _insert_cast_operator(substree)
network.cell_list = list(network.cells())
def _removed_cast_pair(node):
"""the cast pairs should be removed."""
for i in range(len(node.get_users())):
follow_node = node.get_users()[i]
if follow_node.get_instance_type() != P.Cast:
return False
node_dtype = node.get_args()[1]
if len(node.get_users()).__trunc__() == 0:
return False
follow_node_dtype = node.get_users()[0].get_args()[1]
for i in range(1, len(node.get_users())):
dtype = node.get_users()[i].get_args()[1]
if dtype == follow_node_dtype:
continue
if i == len(node.get_users()) - 1 and follow_node_dtype != node_dtype:
return True
return False
def _remove_duplicated_cast(stree):
"""remove the duplicated cast operators."""
for node in stree.nodes():
if node.get_targets() is None:
continue
if node.get_node_type() != ms.rewrite.NodeType.Tree:
if node.get_instance_type() == P.Cast and _removed_cast_pair(node):
# remove the following cast node first
len_users = len(node.get_users())
for i in range(len_users):
follow_node = node.get_users()[i]
for n in follow_node.get_users():
idx = n.get_args().index(follow_node.get_targets()[0])
n.set_arg_by_node(idx, node.get_inputs()[0])
stree.erase_node(follow_node)
# remove the current cast node
stree.erase_node(node)
else:
substree = ms.rewrite.TreeNodeHelper.get_sub_tree(node)
_remove_duplicated_cast(substree)
def _auto_white_list(network):
"""process the white list of network."""
global STREE
STREE = ms.rewrite.SymbolTree.create(network)
_insert_cast_operator(STREE)
_remove_duplicated_cast(STREE)
return STREE.get_network()
def _auto_black_list(network, black_list=None): def _auto_black_list(network, black_list=None):
@ -125,13 +227,14 @@ def auto_mixed_precision(network, amp_level="O0"):
if amp_level == "O0": if amp_level == "O0":
pass pass
elif amp_level == "O1": elif amp_level == "O1":
_auto_white_list(network) return _auto_white_list(network)
elif amp_level == "O2": elif amp_level == "O2":
_auto_black_list(network) _auto_black_list(network)
elif amp_level == "O3": elif amp_level == "O3":
network.to_float(mstype.float16) network.to_float(mstype.float16)
else: else:
raise ValueError("The amp level {} is not supported".format(amp_level)) raise ValueError("The amp level {} is not supported".format(amp_level))
return network
def _do_keep_batchnorm_fp32(network): def _do_keep_batchnorm_fp32(network):
@ -214,7 +317,7 @@ def _check_level(level, boost_level):
return level, enable_boost return level, enable_boost
def _add_loss_network(network, loss_fn, cast_model_type): def _add_loss_network(network, loss_fn):
"""Add loss network.""" """Add loss network."""
class WithLossCell(nn.Cell): class WithLossCell(nn.Cell):
@ -233,10 +336,7 @@ def _add_loss_network(network, loss_fn, cast_model_type):
return self._loss_fn(F.mixed_precision_cast(mstype.float32, out), label) return self._loss_fn(F.mixed_precision_cast(mstype.float32, out), label)
validator.check_value_type('loss_fn', loss_fn, nn.Cell) validator.check_value_type('loss_fn', loss_fn, nn.Cell)
if cast_model_type == mstype.float16: network = WithLossCell(network, loss_fn)
network = WithLossCell(network, loss_fn)
else:
network = nn.WithLossCell(network, loss_fn)
return network return network
@ -304,7 +404,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
auto_mixed_precision(network, level) auto_mixed_precision(network, level)
if loss_fn: if loss_fn:
network = _add_loss_network(network, loss_fn, config["cast_model_type"]) network = _add_loss_network(network, loss_fn)
loss_scale = 1.0 loss_scale = 1.0
if config["loss_scale_manager"] is not None: if config["loss_scale_manager"] is not None:

View File

@ -179,3 +179,48 @@ def test_sit_auto_mix_precision_model_o2():
model_pynative.train(1, dataset2, dataset_sink_mode=False) model_pynative.train(1, dataset2, dataset_sink_mode=False)
out_pynative = model_pynative.predict(Tensor(input_data)) out_pynative = model_pynative.predict(Tensor(input_data))
allclose_nparray(out_graph.asnumpy(), out_pynative.asnumpy(), 0.001, 0.001) allclose_nparray(out_graph.asnumpy(), out_pynative.asnumpy(), 0.001, 0.001)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@security_off_wrap
def test_sit_auto_mix_precision_model_o1():
"""
Feature: Test the O1 level auto mixed precision
Description: input O1 level to Model interface
Expectation: success.
"""
input_data = np.random.randn(32, 3, 224, 224).astype(np.float32)
dataset1 = FakeData(size=32,
batch_size=32,
image_size=(3, 224, 224),
num_classes=10,
fakedata_mode=FakeDataInitMode.OnesInit)
dataset2 = FakeData(size=32,
batch_size=32,
image_size=(3, 224, 224),
num_classes=10,
fakedata_mode=FakeDataInitMode.OnesInit)
# graph mode
context.set_context(mode=context.GRAPH_MODE)
context.set_context(save_graphs=True, save_graphs_path='./test_amp_o1')
net = Net(3, 10)
opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.001, momentum=0.0009)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
model = Model(net, loss, opt, amp_level="O1")
model.train(1, dataset1, dataset_sink_mode=False)
clean_all_ir_files('./test_amp_o1/')
out_graph = model.predict(Tensor(input_data))
# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
net_pynative = Net(3, 10)
opt_pynative = nn.Momentum(params=net_pynative.trainable_params(), learning_rate=0.001, momentum=0.0009)
loss_pynative = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
model_pynative = Model(net_pynative, loss_pynative, opt_pynative, amp_level="O1")
model_pynative.train(1, dataset2, dataset_sink_mode=False)
out_pynative = model_pynative.predict(Tensor(input_data))
allclose_nparray(out_graph.asnumpy(), out_pynative.asnumpy(), 0.001, 0.001)