forked from mindspore-Ecosystem/mindspore
!996 [Auto parallel] Supporting VirtualDataset having three inputs
Merge pull request !996 from Xiaoda/supporting-virtualdataset-has-three-inputs
This commit is contained in:
commit
57085bb18d
|
@ -18,7 +18,7 @@ Wrap cells for networks.
|
|||
Use the Wrapper to combine the loss or build the training steps.
|
||||
"""
|
||||
from .cell_wrapper import TrainOneStepCell, WithLossCell, WithGradCell, WithEvalCell, DataWrapper, \
|
||||
ParameterUpdate, GetNextSingleOp
|
||||
ParameterUpdate, GetNextSingleOp, VirtualDatasetCellTriple
|
||||
from .loss_scale import TrainOneStepWithLossScaleCell, DynamicLossScaleUpdateCell, FixedLossScaleUpdateCell
|
||||
from .grad_reducer import DistributedGradReducer
|
||||
|
||||
|
@ -33,5 +33,6 @@ __all__ = [
|
|||
"DistributedGradReducer",
|
||||
"ParameterUpdate",
|
||||
"DynamicLossScaleUpdateCell",
|
||||
"FixedLossScaleUpdateCell"
|
||||
"FixedLossScaleUpdateCell",
|
||||
"VirtualDatasetCellTriple"
|
||||
]
|
||||
|
|
|
@ -278,6 +278,36 @@ class _VirtualDatasetCell(Cell):
|
|||
return self._backbone(data_, label_)
|
||||
|
||||
|
||||
class VirtualDatasetCellTriple(Cell):
    """
    Wrap the network with a virtual dataset to convert a data parallel layout
    to a model parallel layout.

    VirtualDatasetCellTriple is a virtual Primitive: it does not appear in the
    final executing graph. Inputs and outputs of VirtualDatasetCellTriple are
    distributed in the data parallel pattern; tensor redistribution Primitives
    are inserted dynamically during graph compilation.

    Note:
        Only used in semi auto parallel and auto parallel mode. It takes three
        inputs, in contrast to the two inputs of _VirtualDatasetCell.

    Args:
        backbone (Cell): The target network to wrap.

    Examples:
        >>> net = Net()
        >>> net = VirtualDatasetCellTriple(net)
    """

    def __init__(self, backbone):
        super(VirtualDatasetCellTriple, self).__init__(auto_prefix=False)
        # Wrapped network; the virtual dataset op marks its inputs for
        # redistribution during auto-parallel graph compilation.
        self._backbone = backbone
        self._virtual_dataset = _VirtualDataset()

    def construct(self, a, b, c):
        # Route all three inputs through the virtual dataset op, then feed
        # the (possibly redistributed) tensors to the backbone unchanged.
        data0, data1, data2 = self._virtual_dataset(a, b, c)
        return self._backbone(data0, data1, data2)
|
||||
|
||||
|
||||
class WithEvalCell(Cell):
|
||||
r"""
|
||||
Cell that returns loss, output and label for evaluation.
|
||||
|
|
|
@ -21,6 +21,7 @@ import mindspore as ms
|
|||
from mindspore.common.api import _executor
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops.operations.comm_ops import _VirtualDataset
|
||||
from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple
|
||||
from mindspore import context
|
||||
|
||||
|
||||
|
@ -73,6 +74,29 @@ def test_virtual_dataset_3_input():
|
|||
net.set_auto_parallel()
|
||||
_executor.compile(net, x, y, b)
|
||||
|
||||
def test_virtualdataset_cell_3_inputs():
    """Compile a three-input network wrapped in VirtualDatasetCellTriple
    under auto-parallel mode; passes if graph compilation succeeds."""
    class Net(nn.Cell):
        def __init__(self, strategy0, strategy1, strategy2, strategy3):
            super().__init__()
            # Per-operator parallel strategies; None lets auto-parallel decide.
            self.matmul1 = P.MatMul().set_strategy(strategy1)
            self.matmul2 = P.MatMul().set_strategy(strategy2)
            self.gelu = P.Gelu().set_strategy(strategy3)

        def construct(self, x, y, b):
            hidden = self.gelu(self.matmul1(x, y))
            return self.matmul2(hidden, b)

    backbone = NetWithLoss(Net(None, None, None, None))
    net = GradWrap(VirtualDatasetCellTriple(backbone))
    context.set_context(save_graphs=True)
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    net.set_auto_parallel()

    # Shapes chain: (128,32) @ (32,64) -> gelu -> (128,64) @ (64,2048).
    in_x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    in_y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    in_b = Tensor(np.ones([64, 2048]), dtype=ms.float32)
    _executor.compile(net, in_x, in_y, in_b)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Run both virtual-dataset compile tests when executed as a script.
    # The original entry point only ran the first test and never invoked
    # the newly added test_virtualdataset_cell_3_inputs.
    test_virtual_dataset_3_input()
    test_virtualdataset_cell_3_inputs()
|
||||
|
|
Loading…
Reference in New Issue