forked from mindspore-Ecosystem/mindspore
!8983 support variable length virtual_dataset in auto parallel
From: @yao_yf Reviewed-by: Signed-off-by:
This commit is contained in:
commit
a146f982bc
|
@ -90,10 +90,6 @@ class VirtualDatasetEliminater : public AnfVisitor {
|
|||
|
||||
std::vector<AnfNodePtr> args;
|
||||
(void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args));
|
||||
if (args.size() == 1) {
|
||||
return args.front();
|
||||
}
|
||||
|
||||
(void)args.insert(args.begin(), NewValueNode(prim::kPrimMakeTuple));
|
||||
|
||||
return node->func_graph()->NewCNode(args);
|
||||
|
|
|
@ -304,9 +304,9 @@ class _VirtualDatasetCell(Cell):
|
|||
self._backbone = backbone
|
||||
self._virtual_dataset = _VirtualDataset()
|
||||
|
||||
def construct(self, data, label):
|
||||
data_, label_ = self._virtual_dataset(data, label)
|
||||
return self._backbone(data_, label_)
|
||||
def construct(self, *inputs):
|
||||
output = self._virtual_dataset(*inputs)
|
||||
return self._backbone(*output)
|
||||
|
||||
|
||||
class VirtualDatasetCellTriple(Cell):
|
||||
|
|
|
@ -689,13 +689,9 @@ class _VirtualDataset(PrimitiveWithInfer):
|
|||
"""init"""
|
||||
|
||||
def infer_shape(self, *args):
|
||||
if len(args) == 1:
|
||||
return args[0]
|
||||
return args
|
||||
|
||||
def infer_dtype(self, *args):
|
||||
if len(args) == 1:
|
||||
return args[0]
|
||||
return args
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,35 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore.communication.management import init
|
||||
from mindspore.parallel import set_algo_parameters
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.context import ParallelMode
|
||||
from .test_auto_parallel_resnet import resnet50
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
context.set_context(device_id=0)
|
||||
init()
|
||||
|
||||
def test_train_32k_8p(batch_size=32, num_classes=32768):
|
||||
dev_num = 8
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num)
|
||||
set_algo_parameters(elementwise_op_strategy_follow=True)
|
||||
np.random.seed(6)
|
||||
input_np = Tensor(np.ones([batch_size, 3, 224, 224]).astype(np.float32))
|
||||
net = resnet50(num_classes)
|
||||
model = Model(net)
|
||||
model.predict(input_np)
|
|
@ -21,7 +21,7 @@ from mindspore import context
|
|||
from mindspore.common.api import _executor
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.operations.comm_ops import _VirtualDataset
|
||||
from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
@ -32,7 +32,6 @@ grad_all = C.GradOperation(get_all=True)
|
|||
class NetWithLoss(nn.Cell):
|
||||
def __init__(self, network, strategy3, strategy4, axis):
|
||||
super(NetWithLoss, self).__init__()
|
||||
self.virtual_dataset = _VirtualDataset()
|
||||
self.one_hot = P.OneHot(axis=axis).shard(strategy3)
|
||||
self.on_value = Tensor(2.0, ms.float32)
|
||||
self.off_value = Tensor(1.0, ms.float32)
|
||||
|
@ -40,9 +39,8 @@ class NetWithLoss(nn.Cell):
|
|||
self.network = network
|
||||
|
||||
def construct(self, x, y, b):
|
||||
b_virtual = self.virtual_dataset(b)
|
||||
predict = self.network(x, y)
|
||||
label = self.one_hot(b_virtual, 64, self.on_value, self.off_value)
|
||||
label = self.one_hot(b, 64, self.on_value, self.off_value)
|
||||
return self.loss(predict, label)[0]
|
||||
|
||||
|
||||
|
@ -68,7 +66,7 @@ class Net(nn.Cell):
|
|||
|
||||
|
||||
def compile_graph(strategy1, strategy2, strategy3, strategy4, auto=False, onthot_axis=-1):
|
||||
net = GradWrap(NetWithLoss(Net(strategy1, strategy2), strategy3, strategy4, axis=onthot_axis))
|
||||
net = GradWrap(_VirtualDatasetCell(NetWithLoss(Net(strategy1, strategy2), strategy3, strategy4, axis=onthot_axis)))
|
||||
net.set_auto_parallel()
|
||||
if auto:
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel")
|
||||
|
|
|
@ -26,7 +26,7 @@ from mindspore.nn.optim.momentum import Momentum
|
|||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.operations.comm_ops import _VirtualDataset
|
||||
from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
|
||||
from mindspore.parallel import set_algo_parameters
|
||||
from mindspore.train import Model
|
||||
from mindspore.context import ParallelMode
|
||||
|
@ -204,14 +204,12 @@ class GradWrap(nn.Cell):
|
|||
class ReshapeNet1(nn.Cell):
|
||||
def __init__(self, strategy0):
|
||||
super(ReshapeNet1, self).__init__()
|
||||
self.virtual_dataset = _VirtualDataset()
|
||||
self.reshape = P.Reshape()
|
||||
self.matmul = P.MatMul().shard(strategy0)
|
||||
self.matmul_weight = Parameter(Tensor(np.ones([25088, 256]), dtype=ms.float32), name="weight")
|
||||
self.reshape2 = P.Reshape()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.virtual_dataset(x)
|
||||
x = self.reshape(x, (256, 25088))
|
||||
x = self.matmul(x, self.matmul_weight)
|
||||
x = self.reshape2(x, (256 * 256,))
|
||||
|
@ -221,7 +219,6 @@ class ReshapeNet1(nn.Cell):
|
|||
class ReshapeNet2(nn.Cell):
|
||||
def __init__(self, strategy0):
|
||||
super(ReshapeNet2, self).__init__()
|
||||
self.virtual_dataset = _VirtualDataset()
|
||||
self.reshape = P.Reshape()
|
||||
self.matmul = P.MatMul().shard(strategy0)
|
||||
self.matmul_weight = Parameter(Tensor(np.ones([25088, 256]), dtype=ms.float32), name="weight")
|
||||
|
@ -230,7 +227,6 @@ class ReshapeNet2(nn.Cell):
|
|||
self.reshape3 = P.Reshape()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.virtual_dataset(x)
|
||||
x = self.reshape(x, (256, 25088))
|
||||
x = self.matmul(x, self.matmul_weight)
|
||||
x = self.reshape2(x, (256 * 256,))
|
||||
|
@ -242,7 +238,6 @@ class ReshapeNet2(nn.Cell):
|
|||
class ReshapeNet3(nn.Cell):
|
||||
def __init__(self, strategy0):
|
||||
super(ReshapeNet3, self).__init__()
|
||||
self.virtual_dataset = _VirtualDataset()
|
||||
self.reshape = P.Reshape()
|
||||
self.matmul = P.MatMul().shard(strategy0)
|
||||
self.matmul_weight = Parameter(Tensor(np.ones([25088, 256]), dtype=ms.float32), name="weight")
|
||||
|
@ -251,7 +246,6 @@ class ReshapeNet3(nn.Cell):
|
|||
self.reshape3 = P.Reshape()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.virtual_dataset(x)
|
||||
x = self.reshape(x, (256, 25088))
|
||||
x = self.matmul(x, self.matmul_weight)
|
||||
x = self.reshape2(x, (256 * 256,))
|
||||
|
@ -263,14 +257,12 @@ class ReshapeNet3(nn.Cell):
|
|||
class ReshapeNet4(nn.Cell):
|
||||
def __init__(self, strategy0):
|
||||
super(ReshapeNet4, self).__init__()
|
||||
self.virtual_dataset = _VirtualDataset()
|
||||
self.reshape = P.Reshape()
|
||||
self.reshape2 = P.Reshape()
|
||||
self.matmul = P.MatMul().shard(strategy0)
|
||||
self.matmul_weight = Parameter(Tensor(np.ones([25088, 256]), dtype=ms.float32), name="weight")
|
||||
|
||||
def construct(self, x):
|
||||
x = self.virtual_dataset(x)
|
||||
x = self.reshape(x, (256, 25088))
|
||||
w = self.reshape2(self.matmul_weight, (25088, 256))
|
||||
x = self.matmul(x, w)
|
||||
|
@ -280,14 +272,12 @@ class ReshapeNet4(nn.Cell):
|
|||
class ReshapeNet5(nn.Cell):
|
||||
def __init__(self, strategy0):
|
||||
super(ReshapeNet5, self).__init__()
|
||||
self.virtual_dataset = _VirtualDataset()
|
||||
self.reshape = P.Reshape()
|
||||
self.matmul1 = P.MatMul().shard(strategy0)
|
||||
self.matmul1_weight = Parameter(Tensor(np.ones([25088, 256]), dtype=ms.float32), name="weight")
|
||||
self.matmul2 = P.MatMul().shard(strategy0)
|
||||
|
||||
def construct(self, x):
|
||||
x = self.virtual_dataset(x)
|
||||
x = self.reshape(x, (256, 25088))
|
||||
matmul1_o = self.matmul1(x, self.matmul1_weight)
|
||||
matmul2_o = self.matmul2(matmul1_o, x)
|
||||
|
@ -297,7 +287,6 @@ class ReshapeNet5(nn.Cell):
|
|||
class ReshapeNet6(nn.Cell):
|
||||
def __init__(self, strategy0):
|
||||
super(ReshapeNet6, self).__init__()
|
||||
self.virtual_dataset = _VirtualDataset()
|
||||
self.reshape = P.Reshape()
|
||||
self.matmul1_1 = P.MatMul().shard(strategy0)
|
||||
self.matmul1_2 = P.MatMul().shard(strategy0)
|
||||
|
@ -306,7 +295,6 @@ class ReshapeNet6(nn.Cell):
|
|||
self.add = P.TensorAdd()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.virtual_dataset(x)
|
||||
x = self.reshape(x, (256, 25088))
|
||||
matmul1_1_o = self.matmul1_1(x, self.matmul1_weight)
|
||||
matmul1_2_o = self.matmul1_2(x, self.matmul1_weight)
|
||||
|
@ -334,32 +322,32 @@ def reshape_net2(backbone):
|
|||
|
||||
|
||||
def test_reshape_net1_1():
|
||||
reshape_net2(ReshapeNet1(((1, 8), (8, 1))))
|
||||
reshape_net2(_VirtualDatasetCell(ReshapeNet1(((1, 8), (8, 1)))))
|
||||
|
||||
|
||||
def test_reshape_net1_2():
|
||||
reshape_net2(ReshapeNet1(((1, 8), (8, 2))))
|
||||
reshape_net2(_VirtualDatasetCell(ReshapeNet1(((1, 8), (8, 2)))))
|
||||
|
||||
|
||||
def test_reshape_net2_1():
|
||||
reshape_net2(ReshapeNet2(((1, 8), (8, 1))))
|
||||
reshape_net2(_VirtualDatasetCell(ReshapeNet2(((1, 8), (8, 1)))))
|
||||
|
||||
|
||||
def test_reshape_net2_2():
|
||||
reshape_net2(ReshapeNet2(((1, 8), (8, 2))))
|
||||
reshape_net2(_VirtualDatasetCell(ReshapeNet2(((1, 8), (8, 2)))))
|
||||
|
||||
|
||||
def test_reshape_net3_1():
|
||||
reshape_net2(ReshapeNet3(((1, 8), (8, 1))))
|
||||
reshape_net2(_VirtualDatasetCell(ReshapeNet3(((1, 8), (8, 1)))))
|
||||
|
||||
|
||||
def test_reshape_net3_2():
|
||||
reshape_net2(ReshapeNet3(((1, 8), (8, 2))))
|
||||
reshape_net2(_VirtualDatasetCell(ReshapeNet3(((1, 8), (8, 2)))))
|
||||
|
||||
|
||||
def test_reshape_net4_1():
|
||||
try:
|
||||
reshape_net2(ReshapeNet4(((1, 8), (8, 1))))
|
||||
reshape_net2(_VirtualDatasetCell(ReshapeNet4(((1, 8), (8, 1)))))
|
||||
except ValueError:
|
||||
pass
|
||||
except TypeError:
|
||||
|
@ -370,7 +358,7 @@ def test_reshape_net4_1():
|
|||
|
||||
def test_reshape_net4_2():
|
||||
try:
|
||||
reshape_net2(ReshapeNet4(((1, 8), (8, 2))))
|
||||
reshape_net2(_VirtualDatasetCell(ReshapeNet4(((1, 8), (8, 2)))))
|
||||
except ValueError:
|
||||
pass
|
||||
except TypeError:
|
||||
|
@ -380,19 +368,19 @@ def test_reshape_net4_2():
|
|||
|
||||
|
||||
def test_reshape_net5_1():
|
||||
reshape_net2(ReshapeNet5(((1, 8), (8, 1))))
|
||||
reshape_net2(_VirtualDatasetCell(ReshapeNet5(((1, 8), (8, 1)))))
|
||||
|
||||
|
||||
def test_reshape_net5_2():
|
||||
reshape_net2(ReshapeNet5(((1, 8), (8, 2))))
|
||||
reshape_net2(_VirtualDatasetCell(ReshapeNet5(((1, 8), (8, 2)))))
|
||||
|
||||
|
||||
def test_reshape_net6_1():
|
||||
reshape_net2(ReshapeNet6(((1, 8), (8, 1))))
|
||||
reshape_net2(_VirtualDatasetCell(ReshapeNet6(((1, 8), (8, 1)))))
|
||||
|
||||
|
||||
def test_reshape_net6_2():
|
||||
reshape_net2(ReshapeNet6(((1, 8), (8, 2))))
|
||||
reshape_net2(_VirtualDatasetCell(ReshapeNet6(((1, 8), (8, 2)))))
|
||||
|
||||
|
||||
class TrainOneStepCell(nn.Cell):
|
||||
|
@ -453,39 +441,37 @@ def reshape_common2(parallel_mode, net):
|
|||
|
||||
|
||||
def test_reshape_common2_0():
|
||||
reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, ReshapeNet1(((1, 8), (8, 1))))
|
||||
reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, _VirtualDatasetCell(ReshapeNet1(((1, 8), (8, 1)))))
|
||||
|
||||
|
||||
def test_reshape_common2_1():
|
||||
reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, ReshapeNet1(((1, 8), (8, 2))))
|
||||
reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, _VirtualDatasetCell(ReshapeNet1(((1, 8), (8, 2)))))
|
||||
|
||||
|
||||
def test_reshape_common2_2():
|
||||
reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, ReshapeNet2(((1, 8), (8, 1))))
|
||||
reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, _VirtualDatasetCell(ReshapeNet2(((1, 8), (8, 1)))))
|
||||
|
||||
|
||||
def test_reshape_common2_3():
|
||||
reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, ReshapeNet2(((1, 8), (8, 2))))
|
||||
reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, _VirtualDatasetCell(ReshapeNet2(((1, 8), (8, 2)))))
|
||||
|
||||
|
||||
def test_reshape_common2_4():
|
||||
reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, ReshapeNet3(((1, 8), (8, 1))))
|
||||
reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, _VirtualDatasetCell(ReshapeNet3(((1, 8), (8, 1)))))
|
||||
|
||||
|
||||
def test_reshape_common2_5():
|
||||
reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, ReshapeNet3(((1, 8), (8, 2))))
|
||||
reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, _VirtualDatasetCell(ReshapeNet3(((1, 8), (8, 2)))))
|
||||
|
||||
|
||||
class BatchNormReshapeNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(BatchNormReshapeNet, self).__init__()
|
||||
self.vd = P._VirtualDataset()
|
||||
self.batch_norm = nn.BatchNorm1d(512, affine=False)
|
||||
self.reshape = P.Reshape()
|
||||
self.prelu = nn.PReLU(channel=256)
|
||||
|
||||
def construct(self, x):
|
||||
x = self.vd(x)
|
||||
x = self.batch_norm(x)
|
||||
x = self.reshape(x, (512, 256))
|
||||
x = self.prelu(x)
|
||||
|
@ -499,7 +485,7 @@ def test_batchnorm_reshape_train():
|
|||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
||||
input_ = Tensor(np.ones([batch_size * device_num, 512]).astype(np.float32) * 0.01)
|
||||
|
||||
net = GradWrap(NetWithLoss(BatchNormReshapeNet()))
|
||||
net = GradWrap(NetWithLoss(_VirtualDatasetCell(BatchNormReshapeNet())))
|
||||
|
||||
compile_net(net, input_)
|
||||
|
||||
|
|
Loading…
Reference in New Issue