add amp o1 level
This commit is contained in:
parent
3324fea63f
commit
8e26ebdfbc
|
@ -29,6 +29,7 @@ from mindspore.context import GRAPH_MODE, PYNATIVE_MODE, set_context, get_contex
|
||||||
from mindspore.version import __version__
|
from mindspore.version import __version__
|
||||||
from mindspore.profiler import Profiler
|
from mindspore.profiler import Profiler
|
||||||
from mindspore.parallel import set_algo_parameters, get_algo_parameters, reset_algo_parameters
|
from mindspore.parallel import set_algo_parameters, get_algo_parameters, reset_algo_parameters
|
||||||
|
from mindspore.rewrite import SymbolTree, ScopedValue, Node, NodeType, TreeNodeHelper
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["run_check"]
|
__all__ = ["run_check"]
|
||||||
|
@ -38,4 +39,5 @@ __all__.extend(train.__all__)
|
||||||
__all__.extend(log.__all__)
|
__all__.extend(log.__all__)
|
||||||
__all__.extend(context.__all__)
|
__all__.extend(context.__all__)
|
||||||
__all__.extend(parallel.__all__)
|
__all__.extend(parallel.__all__)
|
||||||
|
__all__.extend(rewrite.__all__)
|
||||||
__all__.append("Profiler")
|
__all__.append("Profiler")
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
"""Auto mixed precision."""
|
"""Auto mixed precision."""
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
|
|
||||||
|
import mindspore as ms
|
||||||
from mindspore import nn
|
from mindspore import nn
|
||||||
from mindspore._checkparam import Validator as validator
|
from mindspore._checkparam import Validator as validator
|
||||||
from mindspore._checkparam import Rel
|
from mindspore._checkparam import Rel
|
||||||
|
@ -25,18 +26,40 @@ from mindspore.ops import functional as F
|
||||||
from mindspore.parallel._utils import _get_pipeline_stages
|
from mindspore.parallel._utils import _get_pipeline_stages
|
||||||
from mindspore.train.loss_scale_manager import DynamicLossScaleManager, LossScaleManager
|
from mindspore.train.loss_scale_manager import DynamicLossScaleManager, LossScaleManager
|
||||||
from mindspore import boost, context
|
from mindspore import boost, context
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
|
||||||
|
|
||||||
AMP_WHITE_LIST = (
|
STREE = None
|
||||||
nn.Dense,
|
|
||||||
|
|
||||||
|
AMP_WHITE_LIST_Cell = (
|
||||||
nn.Conv1d,
|
nn.Conv1d,
|
||||||
nn.Conv2d,
|
nn.Conv2d,
|
||||||
nn.Conv3d,
|
nn.Conv3d,
|
||||||
nn.Conv1dTranspose,
|
nn.Conv1dTranspose,
|
||||||
nn.Conv2dTranspose,
|
nn.Conv2dTranspose,
|
||||||
nn.Conv3dTranspose
|
nn.Conv3dTranspose,
|
||||||
|
nn.Dense,
|
||||||
|
nn.LSTMCell,
|
||||||
|
nn.RNNCell,
|
||||||
|
nn.GRUCell
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
AMP_WHITE_LIST_OPS = (
|
||||||
|
P.Conv2D,
|
||||||
|
P.Conv3D,
|
||||||
|
P.Conv2DTranspose,
|
||||||
|
P.Conv3DTranspose,
|
||||||
|
P.Conv2DBackpropInput,
|
||||||
|
P.MatMul,
|
||||||
|
P.BatchMatMul,
|
||||||
|
P.PReLU,
|
||||||
|
P.ReLU,
|
||||||
|
P.Ger
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
AMP_BLACK_LIST = (
|
AMP_BLACK_LIST = (
|
||||||
nn.BatchNorm1d,
|
nn.BatchNorm1d,
|
||||||
nn.BatchNorm2d,
|
nn.BatchNorm2d,
|
||||||
|
@ -67,23 +90,102 @@ class _OutputTo32(nn.Cell):
|
||||||
return F.cast(self._op(x), mstype.float32)
|
return F.cast(self._op(x), mstype.float32)
|
||||||
|
|
||||||
|
|
||||||
def _auto_white_list(network, white_list=None):
|
def _insert_cast_operator(stree):
|
||||||
"""process the white list of network."""
|
"""insert cast for operators in white_list."""
|
||||||
if white_list is None:
|
new_cast_node = None
|
||||||
white_list = AMP_WHITE_LIST
|
for node in stree.nodes():
|
||||||
cells = network.name_cells()
|
if node.get_targets() is None:
|
||||||
change = False
|
|
||||||
for name in cells:
|
|
||||||
subcell = cells[name]
|
|
||||||
if subcell == network:
|
|
||||||
continue
|
continue
|
||||||
if isinstance(subcell, white_list):
|
in_white_list = False
|
||||||
network._cells[name] = _OutputTo32(subcell.to_float(mstype.float16))
|
if node.get_node_type() != ms.rewrite.NodeType.Tree:
|
||||||
change = True
|
# insert cast before the primitive operators in white_list
|
||||||
|
if node.get_instance_type() in AMP_WHITE_LIST_OPS:
|
||||||
|
in_white_list = True
|
||||||
|
for idx in range(len(node.get_inputs())):
|
||||||
|
position = stree.before(node)
|
||||||
|
new_node = P.Cast()
|
||||||
|
arg = ms.rewrite.ScopedValue.create_name_values([node.get_inputs()[idx].get_targets()[0].value,
|
||||||
|
"mindspore.float16"])
|
||||||
|
new_cast_node = ms.rewrite.Node.create_call_cell(new_node,
|
||||||
|
targets=['x_cast_{}'.format(node.get_name())],
|
||||||
|
args=arg,
|
||||||
|
name='incast_{}{}'.format(node.get_name(), idx))
|
||||||
|
stree.insert(position, new_cast_node)
|
||||||
|
node.set_arg_by_node(idx, new_cast_node)
|
||||||
|
# insert cast before the Cell operators in white_list
|
||||||
|
elif node.get_instance_type() in AMP_WHITE_LIST_Cell:
|
||||||
|
in_white_list = True
|
||||||
|
node.get_instance().to_float(mstype.float16)
|
||||||
|
|
||||||
|
# insert cast after the operators in white_list
|
||||||
|
if in_white_list:
|
||||||
|
position = stree.after(node)
|
||||||
|
new_node = P.Cast()
|
||||||
|
arg = ms.rewrite.ScopedValue.create_name_values([node.get_targets()[0].value,
|
||||||
|
"mindspore.float32"])
|
||||||
|
new_cast_node = ms.rewrite.Node.create_call_cell(new_node,
|
||||||
|
targets=['x_cast_{}'.format(node.get_name())],
|
||||||
|
args=arg,
|
||||||
|
name='outcast_{}'.format(node.get_name()))
|
||||||
|
for i in range(len(node.get_users())):
|
||||||
|
follow_node = node.get_users()[i]
|
||||||
|
stree.insert(position, new_cast_node)
|
||||||
|
idx = follow_node.get_args().index(node.get_targets()[0])
|
||||||
|
follow_node.set_arg_by_node(idx, new_cast_node)
|
||||||
else:
|
else:
|
||||||
_auto_white_list(subcell, white_list)
|
substree = ms.rewrite.TreeNodeHelper.get_sub_tree(node)
|
||||||
if isinstance(network, nn.SequentialCell) and change:
|
_insert_cast_operator(substree)
|
||||||
network.cell_list = list(network.cells())
|
|
||||||
|
|
||||||
|
def _removed_cast_pair(node):
|
||||||
|
"""the cast pairs should be removed."""
|
||||||
|
for i in range(len(node.get_users())):
|
||||||
|
follow_node = node.get_users()[i]
|
||||||
|
if follow_node.get_instance_type() != P.Cast:
|
||||||
|
return False
|
||||||
|
node_dtype = node.get_args()[1]
|
||||||
|
if len(node.get_users()).__trunc__() == 0:
|
||||||
|
return False
|
||||||
|
follow_node_dtype = node.get_users()[0].get_args()[1]
|
||||||
|
for i in range(1, len(node.get_users())):
|
||||||
|
dtype = node.get_users()[i].get_args()[1]
|
||||||
|
if dtype == follow_node_dtype:
|
||||||
|
continue
|
||||||
|
if i == len(node.get_users()) - 1 and follow_node_dtype != node_dtype:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _remove_duplicated_cast(stree):
|
||||||
|
"""remove the duplicated cast operators."""
|
||||||
|
for node in stree.nodes():
|
||||||
|
if node.get_targets() is None:
|
||||||
|
continue
|
||||||
|
if node.get_node_type() != ms.rewrite.NodeType.Tree:
|
||||||
|
if node.get_instance_type() == P.Cast and _removed_cast_pair(node):
|
||||||
|
# remove the following cast node first
|
||||||
|
len_users = len(node.get_users())
|
||||||
|
for i in range(len_users):
|
||||||
|
follow_node = node.get_users()[i]
|
||||||
|
for n in follow_node.get_users():
|
||||||
|
idx = n.get_args().index(follow_node.get_targets()[0])
|
||||||
|
n.set_arg_by_node(idx, node.get_inputs()[0])
|
||||||
|
stree.erase_node(follow_node)
|
||||||
|
# remove the current cast node
|
||||||
|
stree.erase_node(node)
|
||||||
|
else:
|
||||||
|
substree = ms.rewrite.TreeNodeHelper.get_sub_tree(node)
|
||||||
|
_remove_duplicated_cast(substree)
|
||||||
|
|
||||||
|
|
||||||
|
def _auto_white_list(network):
|
||||||
|
"""process the white list of network."""
|
||||||
|
global STREE
|
||||||
|
STREE = ms.rewrite.SymbolTree.create(network)
|
||||||
|
_insert_cast_operator(STREE)
|
||||||
|
_remove_duplicated_cast(STREE)
|
||||||
|
return STREE.get_network()
|
||||||
|
|
||||||
|
|
||||||
def _auto_black_list(network, black_list=None):
|
def _auto_black_list(network, black_list=None):
|
||||||
|
@ -125,13 +227,14 @@ def auto_mixed_precision(network, amp_level="O0"):
|
||||||
if amp_level == "O0":
|
if amp_level == "O0":
|
||||||
pass
|
pass
|
||||||
elif amp_level == "O1":
|
elif amp_level == "O1":
|
||||||
_auto_white_list(network)
|
return _auto_white_list(network)
|
||||||
elif amp_level == "O2":
|
elif amp_level == "O2":
|
||||||
_auto_black_list(network)
|
_auto_black_list(network)
|
||||||
elif amp_level == "O3":
|
elif amp_level == "O3":
|
||||||
network.to_float(mstype.float16)
|
network.to_float(mstype.float16)
|
||||||
else:
|
else:
|
||||||
raise ValueError("The amp level {} is not supported".format(amp_level))
|
raise ValueError("The amp level {} is not supported".format(amp_level))
|
||||||
|
return network
|
||||||
|
|
||||||
|
|
||||||
def _do_keep_batchnorm_fp32(network):
|
def _do_keep_batchnorm_fp32(network):
|
||||||
|
@ -214,7 +317,7 @@ def _check_level(level, boost_level):
|
||||||
return level, enable_boost
|
return level, enable_boost
|
||||||
|
|
||||||
|
|
||||||
def _add_loss_network(network, loss_fn, cast_model_type):
|
def _add_loss_network(network, loss_fn):
|
||||||
"""Add loss network."""
|
"""Add loss network."""
|
||||||
|
|
||||||
class WithLossCell(nn.Cell):
|
class WithLossCell(nn.Cell):
|
||||||
|
@ -233,10 +336,7 @@ def _add_loss_network(network, loss_fn, cast_model_type):
|
||||||
return self._loss_fn(F.mixed_precision_cast(mstype.float32, out), label)
|
return self._loss_fn(F.mixed_precision_cast(mstype.float32, out), label)
|
||||||
|
|
||||||
validator.check_value_type('loss_fn', loss_fn, nn.Cell)
|
validator.check_value_type('loss_fn', loss_fn, nn.Cell)
|
||||||
if cast_model_type == mstype.float16:
|
network = WithLossCell(network, loss_fn)
|
||||||
network = WithLossCell(network, loss_fn)
|
|
||||||
else:
|
|
||||||
network = nn.WithLossCell(network, loss_fn)
|
|
||||||
return network
|
return network
|
||||||
|
|
||||||
|
|
||||||
|
@ -304,7 +404,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
|
||||||
auto_mixed_precision(network, level)
|
auto_mixed_precision(network, level)
|
||||||
|
|
||||||
if loss_fn:
|
if loss_fn:
|
||||||
network = _add_loss_network(network, loss_fn, config["cast_model_type"])
|
network = _add_loss_network(network, loss_fn)
|
||||||
|
|
||||||
loss_scale = 1.0
|
loss_scale = 1.0
|
||||||
if config["loss_scale_manager"] is not None:
|
if config["loss_scale_manager"] is not None:
|
||||||
|
|
|
@ -179,3 +179,48 @@ def test_sit_auto_mix_precision_model_o2():
|
||||||
model_pynative.train(1, dataset2, dataset_sink_mode=False)
|
model_pynative.train(1, dataset2, dataset_sink_mode=False)
|
||||||
out_pynative = model_pynative.predict(Tensor(input_data))
|
out_pynative = model_pynative.predict(Tensor(input_data))
|
||||||
allclose_nparray(out_graph.asnumpy(), out_pynative.asnumpy(), 0.001, 0.001)
|
allclose_nparray(out_graph.asnumpy(), out_pynative.asnumpy(), 0.001, 0.001)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@security_off_wrap
|
||||||
|
def test_sit_auto_mix_precision_model_o1():
|
||||||
|
"""
|
||||||
|
Feature: Test the O1 level auto mixed precision
|
||||||
|
Description: input O1 level to Model interface
|
||||||
|
Expectation: success.
|
||||||
|
"""
|
||||||
|
input_data = np.random.randn(32, 3, 224, 224).astype(np.float32)
|
||||||
|
dataset1 = FakeData(size=32,
|
||||||
|
batch_size=32,
|
||||||
|
image_size=(3, 224, 224),
|
||||||
|
num_classes=10,
|
||||||
|
fakedata_mode=FakeDataInitMode.OnesInit)
|
||||||
|
dataset2 = FakeData(size=32,
|
||||||
|
batch_size=32,
|
||||||
|
image_size=(3, 224, 224),
|
||||||
|
num_classes=10,
|
||||||
|
fakedata_mode=FakeDataInitMode.OnesInit)
|
||||||
|
# graph mode
|
||||||
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
context.set_context(save_graphs=True, save_graphs_path='./test_amp_o1')
|
||||||
|
net = Net(3, 10)
|
||||||
|
opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.001, momentum=0.0009)
|
||||||
|
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
|
||||||
|
model = Model(net, loss, opt, amp_level="O1")
|
||||||
|
model.train(1, dataset1, dataset_sink_mode=False)
|
||||||
|
clean_all_ir_files('./test_amp_o1/')
|
||||||
|
out_graph = model.predict(Tensor(input_data))
|
||||||
|
|
||||||
|
# pynative mode
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE)
|
||||||
|
net_pynative = Net(3, 10)
|
||||||
|
opt_pynative = nn.Momentum(params=net_pynative.trainable_params(), learning_rate=0.001, momentum=0.0009)
|
||||||
|
loss_pynative = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
|
||||||
|
model_pynative = Model(net_pynative, loss_pynative, opt_pynative, amp_level="O1")
|
||||||
|
model_pynative.train(1, dataset2, dataset_sink_mode=False)
|
||||||
|
out_pynative = model_pynative.predict(Tensor(input_data))
|
||||||
|
allclose_nparray(out_graph.asnumpy(), out_pynative.asnumpy(), 0.001, 0.001)
|
||||||
|
|
Loading…
Reference in New Issue