forked from mindspore-Ecosystem/mindspore
!41693 Add O1 level amp feature for MindSpore master branch
Merge pull request !41693 from Bert0108/amp_o1
This commit is contained in:
commit
90142f04fa
|
@ -29,6 +29,7 @@ from mindspore.context import GRAPH_MODE, PYNATIVE_MODE, set_context, get_contex
|
|||
from mindspore.version import __version__
|
||||
from mindspore.profiler import Profiler
|
||||
from mindspore.parallel import set_algo_parameters, get_algo_parameters, reset_algo_parameters
|
||||
from mindspore.rewrite import SymbolTree, ScopedValue, Node, NodeType, TreeNodeHelper
|
||||
|
||||
|
||||
__all__ = ["run_check"]
|
||||
|
@ -38,4 +39,5 @@ __all__.extend(train.__all__)
|
|||
__all__.extend(log.__all__)
|
||||
__all__.extend(context.__all__)
|
||||
__all__.extend(parallel.__all__)
|
||||
__all__.extend(rewrite.__all__)
|
||||
__all__.append("Profiler")
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
"""Auto mixed precision."""
|
||||
from __future__ import absolute_import
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore import nn
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore._checkparam import Rel
|
||||
|
@ -25,18 +26,40 @@ from mindspore.ops import functional as F
|
|||
from mindspore.parallel._utils import _get_pipeline_stages
|
||||
from mindspore.train.loss_scale_manager import DynamicLossScaleManager, LossScaleManager
|
||||
from mindspore import boost, context
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
AMP_WHITE_LIST = (
|
||||
nn.Dense,
|
||||
STREE = None
|
||||
|
||||
|
||||
AMP_WHITE_LIST_Cell = (
|
||||
nn.Conv1d,
|
||||
nn.Conv2d,
|
||||
nn.Conv3d,
|
||||
nn.Conv1dTranspose,
|
||||
nn.Conv2dTranspose,
|
||||
nn.Conv3dTranspose
|
||||
nn.Conv3dTranspose,
|
||||
nn.Dense,
|
||||
nn.LSTMCell,
|
||||
nn.RNNCell,
|
||||
nn.GRUCell
|
||||
)
|
||||
|
||||
|
||||
# Primitive operator types checked by `_insert_cast_operator`: a non-Tree node
# whose instance type is listed here is treated as being in the white list.
AMP_WHITE_LIST_OPS = (
    P.Conv2D,
    P.Conv3D,
    P.Conv2DTranspose,
    P.Conv3DTranspose,
    P.Conv2DBackpropInput,
    P.MatMul,
    P.BatchMatMul,
    P.PReLU,
    P.ReLU,
    P.Ger,
)
|
||||
|
||||
|
||||
AMP_BLACK_LIST = (
|
||||
nn.BatchNorm1d,
|
||||
nn.BatchNorm2d,
|
||||
|
@ -67,23 +90,102 @@ class _OutputTo32(nn.Cell):
|
|||
return F.cast(self._op(x), mstype.float32)
|
||||
|
||||
|
||||
def _auto_white_list(network, white_list=None):
|
||||
"""process the white list of network."""
|
||||
if white_list is None:
|
||||
white_list = AMP_WHITE_LIST
|
||||
cells = network.name_cells()
|
||||
change = False
|
||||
for name in cells:
|
||||
subcell = cells[name]
|
||||
if subcell == network:
|
||||
def _insert_cast_operator(stree):
|
||||
"""insert cast for operators in white_list."""
|
||||
new_cast_node = None
|
||||
for node in stree.nodes():
|
||||
if node.get_targets() is None:
|
||||
continue
|
||||
if isinstance(subcell, white_list):
|
||||
network._cells[name] = _OutputTo32(subcell.to_float(mstype.float16))
|
||||
change = True
|
||||
in_white_list = False
|
||||
if node.get_node_type() != ms.rewrite.NodeType.Tree:
|
||||
# insert cast before the primitive operators in white_list
|
||||
if node.get_instance_type() in AMP_WHITE_LIST_OPS:
|
||||
in_white_list = True
|
||||
for idx in range(len(node.get_inputs())):
|
||||
position = stree.before(node)
|
||||
new_node = P.Cast()
|
||||
arg = ms.rewrite.ScopedValue.create_name_values([node.get_inputs()[idx].get_targets()[0].value,
|
||||
"mindspore.float16"])
|
||||
new_cast_node = ms.rewrite.Node.create_call_cell(new_node,
|
||||
targets=['x_cast_{}'.format(node.get_name())],
|
||||
args=arg,
|
||||
name='incast_{}{}'.format(node.get_name(), idx))
|
||||
stree.insert(position, new_cast_node)
|
||||
node.set_arg_by_node(idx, new_cast_node)
|
||||
# insert cast before the Cell operators in white_list
|
||||
elif node.get_instance_type() in AMP_WHITE_LIST_Cell:
|
||||
in_white_list = True
|
||||
node.get_instance().to_float(mstype.float16)
|
||||
|
||||
# insert cast after the operators in white_list
|
||||
if in_white_list:
|
||||
position = stree.after(node)
|
||||
new_node = P.Cast()
|
||||
arg = ms.rewrite.ScopedValue.create_name_values([node.get_targets()[0].value,
|
||||
"mindspore.float32"])
|
||||
new_cast_node = ms.rewrite.Node.create_call_cell(new_node,
|
||||
targets=['x_cast_{}'.format(node.get_name())],
|
||||
args=arg,
|
||||
name='outcast_{}'.format(node.get_name()))
|
||||
for i in range(len(node.get_users())):
|
||||
follow_node = node.get_users()[i]
|
||||
stree.insert(position, new_cast_node)
|
||||
idx = follow_node.get_args().index(node.get_targets()[0])
|
||||
follow_node.set_arg_by_node(idx, new_cast_node)
|
||||
else:
|
||||
_auto_white_list(subcell, white_list)
|
||||
if isinstance(network, nn.SequentialCell) and change:
|
||||
network.cell_list = list(network.cells())
|
||||
substree = ms.rewrite.TreeNodeHelper.get_sub_tree(node)
|
||||
_insert_cast_operator(substree)
|
||||
|
||||
|
||||
def _removed_cast_pair(node):
    """Tell whether a Cast node and the Cast nodes consuming it can be removed.

    Args:
        node: A rewrite node holding a ``P.Cast``. Its second argument
            (``get_args()[1]``) is compared against the second argument of
            each of its users.

    Returns:
        bool, True only when ``node`` has at least one user, every user is a
        ``P.Cast``, all users share the same second argument, and that
        argument differs from the second argument of ``node``. False in every
        other case.
    """
    users = list(node.get_users())
    # Without any consumer there is no following cast to pair with.
    if not users:
        return False
    # Only Cast consumers can form a removable pair with this node.
    if any(user.get_instance_type() != P.Cast for user in users):
        return False

    node_dtype = node.get_args()[1]
    follow_node_dtype = users[0].get_args()[1]
    # The caller erases all following casts together, so they must agree.
    if any(user.get_args()[1] != follow_node_dtype for user in users[1:]):
        return False

    if follow_node_dtype != node_dtype:
        return True
    return False
|
||||
|
||||
|
||||
def _remove_duplicated_cast(stree):
    """Remove the duplicated cast operators from a symbol tree, in place.

    For every non-Tree node that is a ``P.Cast`` accepted by
    ``_removed_cast_pair``, each consumer of its following Cast nodes is
    rewired to the first input of the node, then the following Cast nodes and
    the node itself are erased. Tree nodes are handled by recursing into their
    sub-tree.

    Args:
        stree: The symbol tree to clean up.
    """
    for node in stree.nodes():
        if node.get_targets() is None:
            continue

        if node.get_node_type() == ms.rewrite.NodeType.Tree:
            substree = ms.rewrite.TreeNodeHelper.get_sub_tree(node)
            _remove_duplicated_cast(substree)
            continue

        if node.get_instance_type() != P.Cast or not _removed_cast_pair(node):
            continue

        # Remove the following cast nodes first. Iterate over snapshots of the
        # user lists: erasing a node or rewiring an argument may update those
        # lists, and indexing into a list that shrinks underneath the loop
        # would skip entries or run past the end.
        for follow_node in list(node.get_users()):
            follow_target = follow_node.get_targets()[0]
            for consumer in list(follow_node.get_users()):
                idx = consumer.get_args().index(follow_target)
                consumer.set_arg_by_node(idx, node.get_inputs()[0])
            stree.erase_node(follow_node)

        # Remove the current cast node.
        stree.erase_node(node)
|
||||
|
||||
|
||||
def _auto_white_list(network):
    """Rewrite `network` according to the white list and return the result.

    A symbol tree is created from `network` and kept in the module-level
    `STREE`; casts are inserted into it, duplicated casts are removed, and the
    network obtained from the tree is returned.
    """
    global STREE
    tree = ms.rewrite.SymbolTree.create(network)
    STREE = tree
    # Insert the casts first, then remove the duplicated ones.
    for rewrite_pass in (_insert_cast_operator, _remove_duplicated_cast):
        rewrite_pass(tree)
    return tree.get_network()
|
||||
|
||||
|
||||
def _auto_black_list(network, black_list=None):
|
||||
|
@ -125,13 +227,14 @@ def auto_mixed_precision(network, amp_level="O0"):
|
|||
if amp_level == "O0":
|
||||
pass
|
||||
elif amp_level == "O1":
|
||||
_auto_white_list(network)
|
||||
return _auto_white_list(network)
|
||||
elif amp_level == "O2":
|
||||
_auto_black_list(network)
|
||||
elif amp_level == "O3":
|
||||
network.to_float(mstype.float16)
|
||||
else:
|
||||
raise ValueError("The amp level {} is not supported".format(amp_level))
|
||||
return network
|
||||
|
||||
|
||||
def _do_keep_batchnorm_fp32(network):
|
||||
|
@ -214,7 +317,7 @@ def _check_level(level, boost_level):
|
|||
return level, enable_boost
|
||||
|
||||
|
||||
def _add_loss_network(network, loss_fn, cast_model_type):
|
||||
def _add_loss_network(network, loss_fn):
|
||||
"""Add loss network."""
|
||||
|
||||
class WithLossCell(nn.Cell):
|
||||
|
@ -233,10 +336,7 @@ def _add_loss_network(network, loss_fn, cast_model_type):
|
|||
return self._loss_fn(F.mixed_precision_cast(mstype.float32, out), label)
|
||||
|
||||
validator.check_value_type('loss_fn', loss_fn, nn.Cell)
|
||||
if cast_model_type == mstype.float16:
|
||||
network = WithLossCell(network, loss_fn)
|
||||
else:
|
||||
network = nn.WithLossCell(network, loss_fn)
|
||||
return network
|
||||
|
||||
|
||||
|
@ -304,7 +404,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
|
|||
auto_mixed_precision(network, level)
|
||||
|
||||
if loss_fn:
|
||||
network = _add_loss_network(network, loss_fn, config["cast_model_type"])
|
||||
network = _add_loss_network(network, loss_fn)
|
||||
|
||||
loss_scale = 1.0
|
||||
if config["loss_scale_manager"] is not None:
|
||||
|
|
|
@ -179,3 +179,48 @@ def test_sit_auto_mix_precision_model_o2():
|
|||
model_pynative.train(1, dataset2, dataset_sink_mode=False)
|
||||
out_pynative = model_pynative.predict(Tensor(input_data))
|
||||
allclose_nparray(out_graph.asnumpy(), out_pynative.asnumpy(), 0.001, 0.001)
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@security_off_wrap
def test_sit_auto_mix_precision_model_o1():
    """
    Feature: Test the O1 level auto mixed precision
    Description: Train and predict through the Model interface with amp_level="O1",
                 once in graph mode and once in PyNative mode.
    Expectation: success, the two predictions pass allclose_nparray.
    """
    def _build_model():
        # Each execution mode gets its own network, optimizer and loss.
        network = Net(3, 10)
        optimizer = nn.Momentum(params=network.trainable_params(), learning_rate=0.001, momentum=0.0009)
        loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
        return Model(network, loss_fn, optimizer, amp_level="O1")

    input_data = np.random.randn(32, 3, 224, 224).astype(np.float32)
    fake_data_args = {
        "size": 32,
        "batch_size": 32,
        "image_size": (3, 224, 224),
        "num_classes": 10,
        "fakedata_mode": FakeDataInitMode.OnesInit,
    }
    dataset1 = FakeData(**fake_data_args)
    dataset2 = FakeData(**fake_data_args)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    context.set_context(save_graphs=True, save_graphs_path='./test_amp_o1')
    model = _build_model()
    model.train(1, dataset1, dataset_sink_mode=False)
    clean_all_ir_files('./test_amp_o1/')
    out_graph = model.predict(Tensor(input_data))

    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE)
    model_pynative = _build_model()
    model_pynative.train(1, dataset2, dataset_sink_mode=False)
    out_pynative = model_pynative.predict(Tensor(input_data))

    allclose_nparray(out_graph.asnumpy(), out_pynative.asnumpy(), 0.001, 0.001)
|
||||
|
|
Loading…
Reference in New Issue