!8983 support variable-length virtual_dataset in auto parallel

From: @yao_yf
Reviewed-by: 
Signed-off-by:
mindspore-ci-bot 2020-11-28 11:07:27 +08:00 committed by Gitee
commit a146f982bc
6 changed files with 61 additions and 50 deletions
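
In brief: `_VirtualDatasetCell.construct` changes from the fixed `(data, label)` pair to `*inputs`, and the single-argument special cases in the `_VirtualDataset` primitive and the C++ eliminator are removed, so the virtual dataset now forwards any number of input tensors. A minimal usage sketch, not part of the diff below; the three-input backbone is hypothetical:

import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell


class ThreeInputNet(nn.Cell):
    """Hypothetical backbone taking three dataset tensors."""
    def __init__(self):
        super(ThreeInputNet, self).__init__()
        self.add = P.TensorAdd()

    def construct(self, x, y, z):
        return self.add(self.add(x, y), z)


# The wrapper previously accepted exactly (data, label); it now forwards
# any number of dataset tensors through the _VirtualDataset primitive.
net = _VirtualDatasetCell(ThreeInputNet())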


@@ -90,10 +90,6 @@ class VirtualDatasetEliminater : public AnfVisitor {
     std::vector<AnfNodePtr> args;
     (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args));
-    if (args.size() == 1) {
-      return args.front();
-    }
     (void)args.insert(args.begin(), NewValueNode(prim::kPrimMakeTuple));
     return node->func_graph()->NewCNode(args);


@@ -304,9 +304,9 @@ class _VirtualDatasetCell(Cell):
         self._backbone = backbone
         self._virtual_dataset = _VirtualDataset()
 
-    def construct(self, data, label):
-        data_, label_ = self._virtual_dataset(data, label)
-        return self._backbone(data_, label_)
+    def construct(self, *inputs):
+        output = self._virtual_dataset(*inputs)
+        return self._backbone(*output)
 
 
 class VirtualDatasetCellTriple(Cell):


@@ -689,13 +689,9 @@ class _VirtualDataset(PrimitiveWithInfer):
         """init"""
 
     def infer_shape(self, *args):
-        if len(args) == 1:
-            return args[0]
         return args
 
    def infer_dtype(self, *args):
-        if len(args) == 1:
-            return args[0]
         return args
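
With the `len(args) == 1` branches gone, `infer_shape` and `infer_dtype` always return the tuple of per-input shapes and dtypes, matching the eliminator change above (which now always rebuilds a MakeTuple). That keeps the primitive's output uniformly unpackable by the variadic `construct`. A plain-Python analogue of the before/after contract, for illustration only:

def infer_shape_old(*args):
    # old: a single input collapsed to its bare shape
    if len(args) == 1:
        return args[0]
    return args

def infer_shape_new(*args):
    # new: always a tuple, even for one input
    return args

assert infer_shape_old((32, 3, 224, 224)) == (32, 3, 224, 224)
assert infer_shape_new((32, 3, 224, 224)) == ((32, 3, 224, 224),)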


@@ -0,0 +1,35 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+from mindspore import Tensor
+from mindspore import context
+from mindspore.communication.management import init
+from mindspore.parallel import set_algo_parameters
+from mindspore.train.model import Model
+from mindspore.context import ParallelMode
+from .test_auto_parallel_resnet import resnet50
+
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+context.set_context(device_id=0)
+init()
+
+def test_train_32k_8p(batch_size=32, num_classes=32768):
+    dev_num = 8
+    context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num)
+    set_algo_parameters(elementwise_op_strategy_follow=True)
+    np.random.seed(6)
+    input_np = Tensor(np.ones([batch_size, 3, 224, 224]).astype(np.float32))
+    net = resnet50(num_classes)
+    model = Model(net)
+    model.predict(input_np)
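
The new test covers the single-input end of the variable-length range: `model.predict(input_np)` feeds one tensor through the auto-parallel pipeline, which now reaches the backbone as a 1-tuple rather than a bare tensor. A stripped-down sketch of the same path without the ResNet dependency; `OneInputNet` is hypothetical:

import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell


class OneInputNet(nn.Cell):
    """Hypothetical single-input backbone."""
    def __init__(self):
        super(OneInputNet, self).__init__()
        self.relu = P.ReLU()

    def construct(self, x):
        return self.relu(x)


# One input behaves like many: the virtual dataset yields a 1-tuple,
# which construct(*inputs) unpacks before calling the backbone.
net = _VirtualDatasetCell(OneInputNet())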


@@ -21,7 +21,7 @@ from mindspore import context
 from mindspore.common.api import _executor
 from mindspore.ops import composite as C
 from mindspore.ops import operations as P
-from mindspore.ops.operations.comm_ops import _VirtualDataset
+from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
 
 context.set_context(mode=context.GRAPH_MODE)
@@ -32,7 +32,6 @@ grad_all = C.GradOperation(get_all=True)
 class NetWithLoss(nn.Cell):
     def __init__(self, network, strategy3, strategy4, axis):
         super(NetWithLoss, self).__init__()
-        self.virtual_dataset = _VirtualDataset()
         self.one_hot = P.OneHot(axis=axis).shard(strategy3)
         self.on_value = Tensor(2.0, ms.float32)
         self.off_value = Tensor(1.0, ms.float32)
@@ -40,9 +39,8 @@ class NetWithLoss(nn.Cell):
         self.network = network
 
     def construct(self, x, y, b):
-        b_virtual = self.virtual_dataset(b)
         predict = self.network(x, y)
-        label = self.one_hot(b_virtual, 64, self.on_value, self.off_value)
+        label = self.one_hot(b, 64, self.on_value, self.off_value)
         return self.loss(predict, label)[0]
@@ -68,7 +66,7 @@ class Net(nn.Cell):
 
 def compile_graph(strategy1, strategy2, strategy3, strategy4, auto=False, onthot_axis=-1):
-    net = GradWrap(NetWithLoss(Net(strategy1, strategy2), strategy3, strategy4, axis=onthot_axis))
+    net = GradWrap(_VirtualDatasetCell(NetWithLoss(Net(strategy1, strategy2), strategy3, strategy4, axis=onthot_axis)))
     net.set_auto_parallel()
     if auto:
         context.set_auto_parallel_context(parallel_mode="auto_parallel")


@@ -26,7 +26,7 @@ from mindspore.nn.optim.momentum import Momentum
 from mindspore.ops import composite as C
 from mindspore.ops import functional as F
 from mindspore.ops import operations as P
-from mindspore.ops.operations.comm_ops import _VirtualDataset
+from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
 from mindspore.parallel import set_algo_parameters
 from mindspore.train import Model
 from mindspore.context import ParallelMode
@@ -204,14 +204,12 @@ class GradWrap(nn.Cell):
 class ReshapeNet1(nn.Cell):
     def __init__(self, strategy0):
         super(ReshapeNet1, self).__init__()
-        self.virtual_dataset = _VirtualDataset()
         self.reshape = P.Reshape()
         self.matmul = P.MatMul().shard(strategy0)
         self.matmul_weight = Parameter(Tensor(np.ones([25088, 256]), dtype=ms.float32), name="weight")
         self.reshape2 = P.Reshape()
 
     def construct(self, x):
-        x = self.virtual_dataset(x)
         x = self.reshape(x, (256, 25088))
         x = self.matmul(x, self.matmul_weight)
         x = self.reshape2(x, (256 * 256,))
@@ -221,7 +219,6 @@ class ReshapeNet1(nn.Cell):
 class ReshapeNet2(nn.Cell):
     def __init__(self, strategy0):
         super(ReshapeNet2, self).__init__()
-        self.virtual_dataset = _VirtualDataset()
         self.reshape = P.Reshape()
         self.matmul = P.MatMul().shard(strategy0)
         self.matmul_weight = Parameter(Tensor(np.ones([25088, 256]), dtype=ms.float32), name="weight")
@@ -230,7 +227,6 @@ class ReshapeNet2(nn.Cell):
         self.reshape3 = P.Reshape()
 
     def construct(self, x):
-        x = self.virtual_dataset(x)
         x = self.reshape(x, (256, 25088))
         x = self.matmul(x, self.matmul_weight)
         x = self.reshape2(x, (256 * 256,))
@@ -242,7 +238,6 @@ class ReshapeNet2(nn.Cell):
 class ReshapeNet3(nn.Cell):
     def __init__(self, strategy0):
         super(ReshapeNet3, self).__init__()
-        self.virtual_dataset = _VirtualDataset()
         self.reshape = P.Reshape()
         self.matmul = P.MatMul().shard(strategy0)
         self.matmul_weight = Parameter(Tensor(np.ones([25088, 256]), dtype=ms.float32), name="weight")
@@ -251,7 +246,6 @@ class ReshapeNet3(nn.Cell):
         self.reshape3 = P.Reshape()
 
     def construct(self, x):
-        x = self.virtual_dataset(x)
         x = self.reshape(x, (256, 25088))
         x = self.matmul(x, self.matmul_weight)
         x = self.reshape2(x, (256 * 256,))
@@ -263,14 +257,12 @@ class ReshapeNet3(nn.Cell):
 class ReshapeNet4(nn.Cell):
     def __init__(self, strategy0):
         super(ReshapeNet4, self).__init__()
-        self.virtual_dataset = _VirtualDataset()
         self.reshape = P.Reshape()
         self.reshape2 = P.Reshape()
         self.matmul = P.MatMul().shard(strategy0)
         self.matmul_weight = Parameter(Tensor(np.ones([25088, 256]), dtype=ms.float32), name="weight")
 
     def construct(self, x):
-        x = self.virtual_dataset(x)
         x = self.reshape(x, (256, 25088))
         w = self.reshape2(self.matmul_weight, (25088, 256))
         x = self.matmul(x, w)
@@ -280,14 +272,12 @@ class ReshapeNet4(nn.Cell):
 class ReshapeNet5(nn.Cell):
     def __init__(self, strategy0):
         super(ReshapeNet5, self).__init__()
-        self.virtual_dataset = _VirtualDataset()
         self.reshape = P.Reshape()
         self.matmul1 = P.MatMul().shard(strategy0)
         self.matmul1_weight = Parameter(Tensor(np.ones([25088, 256]), dtype=ms.float32), name="weight")
         self.matmul2 = P.MatMul().shard(strategy0)
 
     def construct(self, x):
-        x = self.virtual_dataset(x)
         x = self.reshape(x, (256, 25088))
         matmul1_o = self.matmul1(x, self.matmul1_weight)
         matmul2_o = self.matmul2(matmul1_o, x)
@@ -297,7 +287,6 @@ class ReshapeNet5(nn.Cell):
 class ReshapeNet6(nn.Cell):
     def __init__(self, strategy0):
         super(ReshapeNet6, self).__init__()
-        self.virtual_dataset = _VirtualDataset()
         self.reshape = P.Reshape()
         self.matmul1_1 = P.MatMul().shard(strategy0)
         self.matmul1_2 = P.MatMul().shard(strategy0)
@@ -306,7 +295,6 @@ class ReshapeNet6(nn.Cell):
         self.add = P.TensorAdd()
 
     def construct(self, x):
-        x = self.virtual_dataset(x)
         x = self.reshape(x, (256, 25088))
         matmul1_1_o = self.matmul1_1(x, self.matmul1_weight)
         matmul1_2_o = self.matmul1_2(x, self.matmul1_weight)
@@ -334,32 +322,32 @@ def reshape_net2(backbone):
 def test_reshape_net1_1():
-    reshape_net2(ReshapeNet1(((1, 8), (8, 1))))
+    reshape_net2(_VirtualDatasetCell(ReshapeNet1(((1, 8), (8, 1)))))
 
 
 def test_reshape_net1_2():
-    reshape_net2(ReshapeNet1(((1, 8), (8, 2))))
+    reshape_net2(_VirtualDatasetCell(ReshapeNet1(((1, 8), (8, 2)))))
 
 
 def test_reshape_net2_1():
-    reshape_net2(ReshapeNet2(((1, 8), (8, 1))))
+    reshape_net2(_VirtualDatasetCell(ReshapeNet2(((1, 8), (8, 1)))))
 
 
 def test_reshape_net2_2():
-    reshape_net2(ReshapeNet2(((1, 8), (8, 2))))
+    reshape_net2(_VirtualDatasetCell(ReshapeNet2(((1, 8), (8, 2)))))
 
 
 def test_reshape_net3_1():
-    reshape_net2(ReshapeNet3(((1, 8), (8, 1))))
+    reshape_net2(_VirtualDatasetCell(ReshapeNet3(((1, 8), (8, 1)))))
 
 
 def test_reshape_net3_2():
-    reshape_net2(ReshapeNet3(((1, 8), (8, 2))))
+    reshape_net2(_VirtualDatasetCell(ReshapeNet3(((1, 8), (8, 2)))))
 
 
 def test_reshape_net4_1():
     try:
-        reshape_net2(ReshapeNet4(((1, 8), (8, 1))))
+        reshape_net2(_VirtualDatasetCell(ReshapeNet4(((1, 8), (8, 1)))))
     except ValueError:
         pass
     except TypeError:
@@ -370,7 +358,7 @@ def test_reshape_net4_1():
 def test_reshape_net4_2():
     try:
-        reshape_net2(ReshapeNet4(((1, 8), (8, 2))))
+        reshape_net2(_VirtualDatasetCell(ReshapeNet4(((1, 8), (8, 2)))))
     except ValueError:
         pass
     except TypeError:
@@ -380,19 +368,19 @@ def test_reshape_net4_2():
 def test_reshape_net5_1():
-    reshape_net2(ReshapeNet5(((1, 8), (8, 1))))
+    reshape_net2(_VirtualDatasetCell(ReshapeNet5(((1, 8), (8, 1)))))
 
 
 def test_reshape_net5_2():
-    reshape_net2(ReshapeNet5(((1, 8), (8, 2))))
+    reshape_net2(_VirtualDatasetCell(ReshapeNet5(((1, 8), (8, 2)))))
 
 
 def test_reshape_net6_1():
-    reshape_net2(ReshapeNet6(((1, 8), (8, 1))))
+    reshape_net2(_VirtualDatasetCell(ReshapeNet6(((1, 8), (8, 1)))))
 
 
 def test_reshape_net6_2():
-    reshape_net2(ReshapeNet6(((1, 8), (8, 2))))
+    reshape_net2(_VirtualDatasetCell(ReshapeNet6(((1, 8), (8, 2)))))
 
 
 class TrainOneStepCell(nn.Cell):
@@ -453,39 +441,37 @@ def reshape_common2(parallel_mode, net):
 def test_reshape_common2_0():
-    reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, ReshapeNet1(((1, 8), (8, 1))))
+    reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, _VirtualDatasetCell(ReshapeNet1(((1, 8), (8, 1)))))
 
 
 def test_reshape_common2_1():
-    reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, ReshapeNet1(((1, 8), (8, 2))))
+    reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, _VirtualDatasetCell(ReshapeNet1(((1, 8), (8, 2)))))
 
 
 def test_reshape_common2_2():
-    reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, ReshapeNet2(((1, 8), (8, 1))))
+    reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, _VirtualDatasetCell(ReshapeNet2(((1, 8), (8, 1)))))
 
 
 def test_reshape_common2_3():
-    reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, ReshapeNet2(((1, 8), (8, 2))))
+    reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, _VirtualDatasetCell(ReshapeNet2(((1, 8), (8, 2)))))
 
 
 def test_reshape_common2_4():
-    reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, ReshapeNet3(((1, 8), (8, 1))))
+    reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, _VirtualDatasetCell(ReshapeNet3(((1, 8), (8, 1)))))
 
 
 def test_reshape_common2_5():
-    reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, ReshapeNet3(((1, 8), (8, 2))))
+    reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, _VirtualDatasetCell(ReshapeNet3(((1, 8), (8, 2)))))
 
 
 class BatchNormReshapeNet(nn.Cell):
     def __init__(self):
         super(BatchNormReshapeNet, self).__init__()
-        self.vd = P._VirtualDataset()
         self.batch_norm = nn.BatchNorm1d(512, affine=False)
         self.reshape = P.Reshape()
         self.prelu = nn.PReLU(channel=256)
 
     def construct(self, x):
-        x = self.vd(x)
         x = self.batch_norm(x)
         x = self.reshape(x, (512, 256))
         x = self.prelu(x)
@@ -499,7 +485,7 @@ def test_batchnorm_reshape_train():
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
     input_ = Tensor(np.ones([batch_size * device_num, 512]).astype(np.float32) * 0.01)
-    net = GradWrap(NetWithLoss(BatchNormReshapeNet()))
+    net = GradWrap(NetWithLoss(_VirtualDatasetCell(BatchNormReshapeNet())))
 
     compile_net(net, input_)