!996 [Auto parallel] Supporting VirtualDataset having three inputs

Merge pull request !996 from Xiaoda/supporting-virtualdataset-has-three-inputs
This commit is contained in:
mindspore-ci-bot 2020-05-11 18:59:09 +08:00 committed by Gitee
commit 57085bb18d
3 changed files with 57 additions and 2 deletions

View File

@ -18,7 +18,7 @@ Wrap cells for networks.
Use the Wrapper to combine the loss or build the training steps.
"""
from .cell_wrapper import TrainOneStepCell, WithLossCell, WithGradCell, WithEvalCell, DataWrapper, \
ParameterUpdate, GetNextSingleOp
ParameterUpdate, GetNextSingleOp, VirtualDatasetCellTriple
from .loss_scale import TrainOneStepWithLossScaleCell, DynamicLossScaleUpdateCell, FixedLossScaleUpdateCell
from .grad_reducer import DistributedGradReducer
@ -33,5 +33,6 @@ __all__ = [
"DistributedGradReducer",
"ParameterUpdate",
"DynamicLossScaleUpdateCell",
"FixedLossScaleUpdateCell"
"FixedLossScaleUpdateCell",
"VirtualDatasetCellTriple"
]

View File

@ -278,6 +278,36 @@ class _VirtualDatasetCell(Cell):
return self._backbone(data_, label_)
class VirtualDatasetCellTriple(Cell):
    """
    Wrap a backbone network with a virtual dataset so its data-parallel input
    layout can be converted to a model-parallel layout.

    VirtualDatasetCellTriple is a virtual primitive: it never appears in the
    final executing graph. Its inputs and outputs keep the data-parallel
    pattern, and tensor-redistribution primitives are inserted dynamically
    while the graph is compiled.

    Note:
        Only used in semi auto parallel and auto parallel mode. It accepts
        three inputs, unlike the two-input _VirtualDatasetCell.

    Args:
        backbone (Cell): The target network to wrap.

    Examples:
        >>> net = Net()
        >>> net = VirtualDatasetCellTriple(net)
    """

    def __init__(self, backbone):
        super(VirtualDatasetCellTriple, self).__init__(auto_prefix=False)
        # Keep a handle on the wrapped network and the virtual-dataset marker.
        self._backbone = backbone
        self._virtual_dataset = _VirtualDataset()

    def construct(self, a, b, c):
        # Route all three inputs through the virtual dataset marker before
        # forwarding them to the wrapped backbone.
        in0, in1, in2 = self._virtual_dataset(a, b, c)
        return self._backbone(in0, in1, in2)
class WithEvalCell(Cell):
r"""
Cell that returns loss, output and label for evaluation.

View File

@ -21,6 +21,7 @@ import mindspore as ms
from mindspore.common.api import _executor
from mindspore.ops import composite as C
from mindspore.ops.operations.comm_ops import _VirtualDataset
from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple
from mindspore import context
@ -73,6 +74,29 @@ def test_virtual_dataset_3_input():
net.set_auto_parallel()
_executor.compile(net, x, y, b)
def test_virtualdataset_cell_3_inputs():
    """Compile a three-input backbone wrapped in VirtualDatasetCellTriple under auto parallel."""

    class ParallelNet(nn.Cell):
        # Small matmul/gelu/matmul backbone; strategy0 is accepted for
        # signature parity with sibling tests but intentionally unused.
        def __init__(self, strategy0, strategy1, strategy2, strategy3):
            super().__init__()
            self.matmul1 = P.MatMul().set_strategy(strategy1)
            self.matmul2 = P.MatMul().set_strategy(strategy2)
            self.gelu = P.Gelu().set_strategy(strategy3)

        def construct(self, x, y, b):
            hidden = self.matmul1(x, y)
            activated = self.gelu(hidden)
            return self.matmul2(activated, b)

    backbone = ParallelNet(None, None, None, None)
    net = GradWrap(VirtualDatasetCellTriple(NetWithLoss(backbone)))
    context.set_context(save_graphs=True)
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    x, y, b = (Tensor(np.ones(shape), dtype=ms.float32)
               for shape in ([128, 32], [32, 64], [64, 2048]))
    net.set_auto_parallel()
    _executor.compile(net, x, y, b)
if __name__ == '__main__':
    # Run both virtual-dataset compile tests when invoked directly; the
    # triple-input test was previously defined but never called here.
    test_virtual_dataset_3_input()
    test_virtualdataset_cell_3_inputs()