forked from mindspore-Ecosystem/mindspore
!22255 recompute_interface_modify
Merge pull request !22255 from yao_yf/recompute_interface_modify
This commit is contained in:
commit
d87d0e07c2
|
@ -118,6 +118,7 @@ class Cell(Cell_):
|
|||
self.cell_type = None
|
||||
self._auto_parallel_compile_and_run = False
|
||||
self.cast = Cast()
|
||||
self._has_config_recompute = False
|
||||
|
||||
def __getstate__(self):
|
||||
base = Cell_.__getstate__(self)
|
||||
|
@ -1298,6 +1299,9 @@ class Cell(Cell_):
|
|||
self._scope = self._scope[len(prefix):]
|
||||
|
||||
def _mp_comm_recompute(self, mp_comm_recompute=True):
|
||||
"""
|
||||
Set the model parallel communication in cell recomputed.
|
||||
"""
|
||||
for _, value in self._primitives.items():
|
||||
if value:
|
||||
value.add_prim_attr("recompute_comm_op", mp_comm_recompute)
|
||||
|
@ -1305,17 +1309,25 @@ class Cell(Cell_):
|
|||
cell._mp_comm_recompute(mp_comm_recompute)
|
||||
|
||||
def _recompute(self, mode=True, output_recompute=False):
|
||||
"""
|
||||
Set the cell recomputed.
|
||||
"""
|
||||
if context.get_context("mode") == context.PYNATIVE_MODE:
|
||||
raise TypeError("Recompute is not supported in pynative mode currently.")
|
||||
Validator.check_bool(mode)
|
||||
Validator.check_bool(output_recompute)
|
||||
if not self._has_config_recompute:
|
||||
self._has_config_recompute = True
|
||||
else:
|
||||
raise RuntimeError("The recompute interface can be configured only once."
|
||||
" When the parent cell is configured, the child cell should not be configured")
|
||||
self._set_recompute_scope(mode)
|
||||
if mode and not output_recompute:
|
||||
self.add_flags(output_no_recompute=True)
|
||||
for cell in self.cells():
|
||||
cell._recompute(mode, True)
|
||||
|
||||
@args_type_check(mode=bool, output_recompute=bool, mp_comm_recompute=bool)
|
||||
@args_type_check(mp_comm_recompute=bool, optimizer_shard_comm_recompute=bool)
|
||||
def recompute(self, **kwargs):
|
||||
"""
|
||||
Set the cell recomputed. All the primitive in the cell will be set recomputed. If a primitive
|
||||
|
@ -1328,28 +1340,25 @@ class Cell(Cell_):
|
|||
is not guaranteed currently.
|
||||
- If the recompute api of a primitive in this cell is also called, the recompute mode of this
|
||||
primitive is subject to the recompute api of the primitive.
|
||||
- The interface can be configured only once.
|
||||
Therefore, when the parent cell is configured, the child cell should not be configured.
|
||||
- When the memory remains after applying the recompute, configuring 'mp_comm_recompute=True'
|
||||
to improve performance if necessary.
|
||||
|
||||
Args:
|
||||
mode (bool): Specifies whether the cell is recomputed. Default: True.
|
||||
output_recompute (bool): Specifies whether the output of this cell is recomputed when
|
||||
the mode is true. Note that when the mode is false, this arg is not working. Default: False.
|
||||
mp_comm_recompute (bool): Specifies whether the model parallel communication operators in the
|
||||
cell is recomputed in auto parallel or semi auto parallel mode. Default: True.
|
||||
mp_comm_recompute (bool): Specifies whether the model parallel communication operators
|
||||
in the cell are recomputed in auto parallel or semi auto parallel mode. Default: True.
|
||||
optimizer_shard_comm_recompute (bool): Specifies whether the communication operator allgathers
|
||||
introduced by optimizer shard are recomputed in auto parallel or semi auto parallel mode. Default: False.
|
||||
"""
|
||||
if not kwargs:
|
||||
self._recompute()
|
||||
if 'mode' in kwargs.keys() or 'output_recompute' in kwargs.keys():
|
||||
mode = True
|
||||
output_recompute = False
|
||||
if 'mode' in kwargs.keys():
|
||||
mode = kwargs['mode']
|
||||
if 'output_recompute' in kwargs.keys():
|
||||
output_recompute = kwargs['output_recompute']
|
||||
self._recompute(mode, output_recompute)
|
||||
self._recompute()
|
||||
if 'mp_comm_recompute' in kwargs.keys():
|
||||
self._mp_comm_recompute(kwargs['mp_comm_recompute'])
|
||||
if 'optimizer_shard_comm_recompute' in kwargs.keys():
|
||||
raise ValueError("Currently, the communication operator allgathers introduced by optimizer shard "
|
||||
"are not support recomputation")
|
||||
for key, _ in kwargs.items():
|
||||
if key not in ('mode', 'output_recompute', 'mp_comm_recompute'):
|
||||
if key not in ('mp_comm_recompute', 'optimizer_shard_comm_recompute'):
|
||||
raise ValueError("Recompute keyword %s is not recognized!" % key)
|
||||
|
||||
|
||||
|
|
|
@ -45,6 +45,7 @@ approvers:
|
|||
- zhoufeng54
|
||||
- zlq2020
|
||||
- zqstar
|
||||
- yao_yf
|
||||
reviewers:
|
||||
- nicholas_yhr
|
||||
- liubuyu
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import pytest
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
|
||||
|
@ -35,48 +36,18 @@ def test_set_recompute_true():
|
|||
net.pool.recompute()
|
||||
assert net.pool.get_scope() == recompute_prefix
|
||||
|
||||
|
||||
def test_set_recompute_false():
|
||||
def test_set_recompute_true_with_mp_comm_recompute():
|
||||
net = Net()
|
||||
net.pool.recompute(mode=False)
|
||||
assert net.pool.get_scope() is None
|
||||
net.pool.recompute(mp_comm_recompute=True)
|
||||
assert net.pool.get_scope() == recompute_prefix
|
||||
|
||||
def test_set_recompute_true_with_mp_comm_recompute_false():
|
||||
net = Net()
|
||||
net.pool.recompute(mp_comm_recompute=False)
|
||||
assert net.pool.get_scope() == recompute_prefix
|
||||
|
||||
def test_set_recompute_true_twice():
|
||||
net = Net()
|
||||
net.pool.recompute()
|
||||
net.pool.recompute()
|
||||
assert net.pool.get_scope() == recompute_prefix
|
||||
|
||||
|
||||
def test_set_recompute_false_twice():
|
||||
net = Net()
|
||||
net.pool.recompute(mode=False)
|
||||
net.pool.recompute(mode=False)
|
||||
assert net.pool.get_scope() is None
|
||||
|
||||
|
||||
def test_reset_recompute1():
|
||||
net = Net()
|
||||
net.pool.recompute(mode=True)
|
||||
net.pool.recompute(mode=False)
|
||||
assert net.pool.get_scope() == ""
|
||||
|
||||
|
||||
def test_reset_recompute2():
|
||||
net = Net()
|
||||
net.pool.recompute(mode=False)
|
||||
net.pool.recompute(mode=True)
|
||||
assert net.pool.get_scope() == recompute_prefix
|
||||
|
||||
|
||||
def test_set_scope_and_set_recompute_repeatedly():
|
||||
net = Net()
|
||||
net.pool.recompute(mode=True)
|
||||
assert net.pool.get_scope() == recompute_prefix
|
||||
net.pool.recompute(mode=False)
|
||||
assert net.pool.get_scope() == ""
|
||||
net.pool.recompute(mode=True)
|
||||
assert net.pool.get_scope() == recompute_prefix
|
||||
net.pool.recompute(mode=False)
|
||||
assert net.pool.get_scope() == ""
|
||||
with pytest.raises(RuntimeError):
|
||||
net.pool.recompute()
|
||||
|
|
|
@ -43,9 +43,6 @@ class DenseMutMulNet(nn.Cell):
|
|||
self.fc2 = nn.Dense(128, 768, activation='relu')
|
||||
self.fc3 = nn.Dense(128, 768, activation='relu')
|
||||
self.fc4 = nn.Dense(768, 768, activation='relu')
|
||||
self.fc1.recompute()
|
||||
self.fc2.recompute()
|
||||
self.fc3.recompute()
|
||||
self.fc1.matmul.shard(((1, 1), (1, 8)))
|
||||
self.fc2.matmul.shard(((1, 1), (1, 8)))
|
||||
self.fc3.matmul.shard(((1, 1), (1, 8)))
|
||||
|
@ -55,7 +52,6 @@ class DenseMutMulNet(nn.Cell):
|
|||
self.matmul1 = P.MatMul()
|
||||
self.matmul2 = P.MatMul()
|
||||
self.matmul_cell = MatMulCell()
|
||||
self.matmul_cell.recompute()
|
||||
self.fc1.recompute(mp_comm_recompute=False)
|
||||
self.fc2.recompute(mp_comm_recompute=False)
|
||||
self.fc3.recompute(mp_comm_recompute=False)
|
||||
|
|
Loading…
Reference in New Issue