!22255 recompute_interface_modify

Merge pull request !22255 from yao_yf/recompute_interface_modify
This commit is contained in:
i-robot 2021-09-01 08:09:36 +00:00 committed by Gitee
commit d87d0e07c2
4 changed files with 37 additions and 60 deletions


@@ -118,6 +118,7 @@ class Cell(Cell_):
self.cell_type = None
self._auto_parallel_compile_and_run = False
self.cast = Cast()
self._has_config_recompute = False
def __getstate__(self):
base = Cell_.__getstate__(self)
@@ -1298,6 +1299,9 @@ class Cell(Cell_):
self._scope = self._scope[len(prefix):]
def _mp_comm_recompute(self, mp_comm_recompute=True):
"""
Set whether the model parallel communication operators in the cell are recomputed.
"""
for _, value in self._primitives.items():
if value:
value.add_prim_attr("recompute_comm_op", mp_comm_recompute)
@@ -1305,17 +1309,25 @@ class Cell(Cell_):
cell._mp_comm_recompute(mp_comm_recompute)
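The method above follows a simple tag-and-recurse pattern: every primitive registered on the cell receives the 'recompute_comm_op' attribute, and the call then descends into each child cell so the whole subtree is tagged. A minimal standalone sketch of that pattern (tag_subtree is a hypothetical helper for illustration, not MindSpore API):

    def tag_subtree(cell, attr_name, attr_value):
        # Attach attr_name=attr_value to every primitive owned by this cell,
        # then recurse into child cells so the entire subtree carries the attribute.
        for _, prim in cell._primitives.items():
            if prim:  # skip slots that hold no primitive
                prim.add_prim_attr(attr_name, attr_value)
        for child in cell.cells():
            tag_subtree(child, attr_name, attr_value)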
def _recompute(self, mode=True, output_recompute=False):
"""
Set the cell to be recomputed.
"""
if context.get_context("mode") == context.PYNATIVE_MODE:
raise TypeError("Recompute is not supported in pynative mode currently.")
Validator.check_bool(mode)
Validator.check_bool(output_recompute)
if not self._has_config_recompute:
self._has_config_recompute = True
else:
raise RuntimeError("The recompute interface can be configured only once."
" When the parent cell is configured, the child cell should not be configured")
self._set_recompute_scope(mode)
if mode and not output_recompute:
self.add_flags(output_no_recompute=True)
for cell in self.cells():
cell._recompute(mode, True)
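Together with the _has_config_recompute flag initialized above, this makes recompute configuration one-shot per subtree: the first call flips the flag on the cell and, through the recursion, on all of its children, so a later call on either the parent or a child raises. A minimal usage sketch (the Net class here is hypothetical):

    import mindspore.context as context
    import mindspore.nn as nn

    context.set_context(mode=context.GRAPH_MODE)  # recompute is rejected in pynative mode

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.dense = nn.Dense(4, 4)

        def construct(self, x):
            return self.dense(x)

    net = Net()
    net.recompute()            # first call: tags the whole subtree and sets the flag
    try:
        net.dense.recompute()  # the child was already configured via the parent
    except RuntimeError as err:
        print(err)             # "The recompute interface can be configured only once. ..."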
@args_type_check(mode=bool, output_recompute=bool, mp_comm_recompute=bool)
@args_type_check(mp_comm_recompute=bool, optimizer_shard_comm_recompute=bool)
def recompute(self, **kwargs):
"""
Set the cell to be recomputed. All the primitives in the cell will be set to be recomputed. If a primitive
@@ -1328,28 +1340,25 @@ class Cell(Cell_):
is not guaranteed currently.
- If the recompute api of a primitive in this cell is also called, the recompute mode of this
primitive is subject to the recompute api of the primitive.
- The interface can be configured only once.
Therefore, when the parent cell is configured, the child cell should not be configured.
- When memory remains after applying the recompute, configure 'mp_comm_recompute=False'
to improve performance if necessary.
Args:
mode (bool): Specifies whether the cell is recomputed. Default: True.
output_recompute (bool): Specifies whether the output of this cell is recomputed when
the mode is true. Note that when the mode is false, this arg is not working. Default: False.
mp_comm_recompute (bool): Specifies whether the model parallel communication operators in the
cell is recomputed in auto parallel or semi auto parallel mode. Default: True.
mp_comm_recompute (bool): Specifies whether the model parallel communication operators
in the cell are recomputed in auto parallel or semi auto parallel mode. Default: True.
optimizer_shard_comm_recompute (bool): Specifies whether the communication operator allgathers
introduced by optimizer shard are recomputed in auto parallel or semi auto parallel mode. Default: False.
"""
if not kwargs:
self._recompute()
if 'mode' in kwargs.keys() or 'output_recompute' in kwargs.keys():
mode = True
output_recompute = False
if 'mode' in kwargs.keys():
mode = kwargs['mode']
if 'output_recompute' in kwargs.keys():
output_recompute = kwargs['output_recompute']
self._recompute(mode, output_recompute)
self._recompute()
if 'mp_comm_recompute' in kwargs.keys():
self._mp_comm_recompute(kwargs['mp_comm_recompute'])
if 'optimizer_shard_comm_recompute' in kwargs.keys():
raise ValueError("Currently, the communication operator allgathers introduced by optimizer shard "
"are not support recomputation")
for key, _ in kwargs.items():
if key not in ('mode', 'output_recompute', 'mp_comm_recompute'):
if key not in ('mp_comm_recompute', 'optimizer_shard_comm_recompute'):
raise ValueError("Recompute keyword %s is not recognized!" % key)


@@ -45,6 +45,7 @@ approvers:
- zhoufeng54
- zlq2020
- zqstar
- yao_yf
reviewers:
- nicholas_yhr
- liubuyu


@@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
import pytest
import mindspore.context as context
import mindspore.nn as nn
@@ -35,48 +36,18 @@ def test_set_recompute_true():
net.pool.recompute()
assert net.pool.get_scope() == recompute_prefix
def test_set_recompute_false():
def test_set_recompute_true_with_mp_comm_recompute():
net = Net()
net.pool.recompute(mode=False)
assert net.pool.get_scope() is None
net.pool.recompute(mp_comm_recompute=True)
assert net.pool.get_scope() == recompute_prefix
def test_set_recompute_true_with_mp_comm_recompute_false():
net = Net()
net.pool.recompute(mp_comm_recompute=False)
assert net.pool.get_scope() == recompute_prefix
def test_set_recompute_true_twice():
net = Net()
net.pool.recompute()
net.pool.recompute()
assert net.pool.get_scope() == recompute_prefix
def test_set_recompute_false_twice():
net = Net()
net.pool.recompute(mode=False)
net.pool.recompute(mode=False)
assert net.pool.get_scope() is None
def test_reset_recompute1():
net = Net()
net.pool.recompute(mode=True)
net.pool.recompute(mode=False)
assert net.pool.get_scope() == ""
def test_reset_recompute2():
net = Net()
net.pool.recompute(mode=False)
net.pool.recompute(mode=True)
assert net.pool.get_scope() == recompute_prefix
def test_set_scope_and_set_recompute_repeatedly():
net = Net()
net.pool.recompute(mode=True)
assert net.pool.get_scope() == recompute_prefix
net.pool.recompute(mode=False)
assert net.pool.get_scope() == ""
net.pool.recompute(mode=True)
assert net.pool.get_scope() == recompute_prefix
net.pool.recompute(mode=False)
assert net.pool.get_scope() == ""
with pytest.raises(RuntimeError):
net.pool.recompute()


@@ -43,9 +43,6 @@ class DenseMutMulNet(nn.Cell):
self.fc2 = nn.Dense(128, 768, activation='relu')
self.fc3 = nn.Dense(128, 768, activation='relu')
self.fc4 = nn.Dense(768, 768, activation='relu')
self.fc1.recompute()
self.fc2.recompute()
self.fc3.recompute()
self.fc1.matmul.shard(((1, 1), (1, 8)))
self.fc2.matmul.shard(((1, 1), (1, 8)))
self.fc3.matmul.shard(((1, 1), (1, 8)))
@@ -55,7 +52,6 @@ class DenseMutMulNet(nn.Cell):
self.matmul1 = P.MatMul()
self.matmul2 = P.MatMul()
self.matmul_cell = MatMulCell()
self.matmul_cell.recompute()
self.fc1.recompute(mp_comm_recompute=False)
self.fc2.recompute(mp_comm_recompute=False)
self.fc3.recompute(mp_comm_recompute=False)