From df10f0d9b30025611a6fe9268a77a70e8264108a Mon Sep 17 00:00:00 2001
From: yao_yf
Date: Wed, 28 Jul 2021 11:36:55 +0800
Subject: [PATCH] fcn8s network support auto parallel

---
 .../redistribution_layout_transfer.cc         |   7 +-
 .../official/cv/FCN8s/default_config.yaml     |   1 +
 model_zoo/official/cv/FCN8s/src/nets/FCN8s.py | 212 +++++++++---------
 model_zoo/official/cv/FCN8s/train.py          |   5 +
 4 files changed, 118 insertions(+), 107 deletions(-)

diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_layout_transfer.cc b/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_layout_transfer.cc
index 293d968c414..d8d09d95228 100644
--- a/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_layout_transfer.cc
+++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_layout_transfer.cc
@@ -30,9 +30,10 @@ Status RedistributionLayoutTransfer::CheckValidTransfer() {
   bool not_all_repeat = std::any_of(from_map.begin(), from_map.end(), [](int64_t i) { return i != -1; }) ||
                         std::any_of(to_map.begin(), to_map.end(), [](int64_t i) { return i != -1; });
   if (from_in_ != to_in_ && not_all_repeat) {
-    MS_LOG(ERROR) << "In dynamic shape scene, the from_tensor_shape should be equal to to_tensor_shape";
-    MS_LOG(ERROR) << "from_in layout" << from_in_.ToString();
-    MS_LOG(ERROR) << "to_in layout" << to_in_.ToString();
+    MS_LOG(ERROR) << "In the dynamic shape scenario, the layouts of neighboring nodes should be equal "
+                     "to avoid inserting redistribution operators";
+    MS_LOG(ERROR) << "from layout: " << from_in_.ToString();
+    MS_LOG(ERROR) << "to layout: " << to_in_.ToString();
     return Status::FAILED;
   }
 }
diff --git a/model_zoo/official/cv/FCN8s/default_config.yaml b/model_zoo/official/cv/FCN8s/default_config.yaml
index 592c76d6a6b..8ad0814f343 100644
--- a/model_zoo/official/cv/FCN8s/default_config.yaml
+++ b/model_zoo/official/cv/FCN8s/default_config.yaml
@@ -20,6 +20,7 @@ image_std: [57.375, 57.120, 58.395]
 ignore_label: 255
 num_classes: 21
 model: "FCN8s"
+parallel_mode: "data_parallel"

 # ======================================================================================
 # Training options
diff --git a/model_zoo/official/cv/FCN8s/src/nets/FCN8s.py b/model_zoo/official/cv/FCN8s/src/nets/FCN8s.py
index 033775260fb..31b9a252afc 100644
--- a/model_zoo/official/cv/FCN8s/src/nets/FCN8s.py
+++ b/model_zoo/official/cv/FCN8s/src/nets/FCN8s.py
@@ -22,161 +22,165 @@ class FCN8s(nn.Cell):
         super().__init__()
         self.n_class = n_class
         self.conv1 = nn.SequentialCell(
-            nn.Conv2d(in_channels=3,
-                      out_channels=64,
-                      kernel_size=3,
-                      weight_init='xavier_uniform'),
+            nn.Conv2d(in_channels=3, out_channels=64,
+                      kernel_size=3, weight_init='xavier_uniform'),
             nn.BatchNorm2d(64),
             nn.ReLU(),
-            nn.Conv2d(in_channels=64,
-                      out_channels=64,
-                      kernel_size=3,
-                      weight_init='xavier_uniform'),
+            nn.Conv2d(in_channels=64, out_channels=64,
+                      kernel_size=3, weight_init='xavier_uniform'),
             nn.BatchNorm2d(64),
             nn.ReLU()
         )
-
         self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
-
         self.conv2 = nn.SequentialCell(
-            nn.Conv2d(in_channels=64,
-                      out_channels=128,
-                      kernel_size=3,
-                      weight_init='xavier_uniform'),
+            nn.Conv2d(in_channels=64, out_channels=128,
+                      kernel_size=3, weight_init='xavier_uniform'),
             nn.BatchNorm2d(128),
             nn.ReLU(),
-            nn.Conv2d(in_channels=128,
-                      out_channels=128,
-                      kernel_size=3,
-                      weight_init='xavier_uniform'),
+            nn.Conv2d(in_channels=128, out_channels=128,
+                      kernel_size=3, weight_init='xavier_uniform'),
             nn.BatchNorm2d(128),
             nn.ReLU()
         )
-
         self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
-
         self.conv3 = nn.SequentialCell(
-            nn.Conv2d(in_channels=128,
-                      out_channels=256,
-                      kernel_size=3,
-                      weight_init='xavier_uniform'),
+            nn.Conv2d(in_channels=128, out_channels=256,
+                      kernel_size=3, weight_init='xavier_uniform'),
             nn.BatchNorm2d(256),
             nn.ReLU(),
-            nn.Conv2d(in_channels=256,
-                      out_channels=256,
-                      kernel_size=3,
-                      weight_init='xavier_uniform'),
+            nn.Conv2d(in_channels=256, out_channels=256,
+                      kernel_size=3, weight_init='xavier_uniform'),
             nn.BatchNorm2d(256),
             nn.ReLU(),
-            nn.Conv2d(in_channels=256,
-                      out_channels=256,
-                      kernel_size=3,
-                      weight_init='xavier_uniform'),
+            nn.Conv2d(in_channels=256, out_channels=256,
+                      kernel_size=3, weight_init='xavier_uniform'),
             nn.BatchNorm2d(256),
             nn.ReLU()
         )
-
         self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
-
         self.conv4 = nn.SequentialCell(
-            nn.Conv2d(in_channels=256,
-                      out_channels=512,
-                      kernel_size=3,
-                      weight_init='xavier_uniform'),
+            nn.Conv2d(in_channels=256, out_channels=512,
+                      kernel_size=3, weight_init='xavier_uniform'),
             nn.BatchNorm2d(512),
             nn.ReLU(),
-            nn.Conv2d(in_channels=512,
-                      out_channels=512,
-                      kernel_size=3,
-                      weight_init='xavier_uniform'),
+            nn.Conv2d(in_channels=512, out_channels=512,
+                      kernel_size=3, weight_init='xavier_uniform'),
             nn.BatchNorm2d(512),
             nn.ReLU(),
-            nn.Conv2d(in_channels=512,
-                      out_channels=512,
-                      kernel_size=3,
-                      weight_init='xavier_uniform'),
+            nn.Conv2d(in_channels=512, out_channels=512,
+                      kernel_size=3, weight_init='xavier_uniform'),
             nn.BatchNorm2d(512),
             nn.ReLU()
         )
-
         self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
-
         self.conv5 = nn.SequentialCell(
-            nn.Conv2d(in_channels=512,
-                      out_channels=512,
-                      kernel_size=3,
-                      weight_init='xavier_uniform'),
+            nn.Conv2d(in_channels=512, out_channels=512,
+                      kernel_size=3, weight_init='xavier_uniform'),
             nn.BatchNorm2d(512),
             nn.ReLU(),
-            nn.Conv2d(in_channels=512,
-                      out_channels=512,
-                      kernel_size=3,
-                      weight_init='xavier_uniform'),
+            nn.Conv2d(in_channels=512, out_channels=512,
+                      kernel_size=3, weight_init='xavier_uniform'),
             nn.BatchNorm2d(512),
             nn.ReLU(),
-            nn.Conv2d(in_channels=512,
-                      out_channels=512,
-                      kernel_size=3,
-                      weight_init='xavier_uniform'),
+            nn.Conv2d(in_channels=512, out_channels=512,
+                      kernel_size=3, weight_init='xavier_uniform'),
             nn.BatchNorm2d(512),
             nn.ReLU()
         )
-
         self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)
-
         self.conv6 = nn.SequentialCell(
-            nn.Conv2d(in_channels=512,
-                      out_channels=4096,
-                      kernel_size=7,
-                      weight_init='xavier_uniform'),
+            nn.Conv2d(in_channels=512, out_channels=4096,
+                      kernel_size=7, weight_init='xavier_uniform'),
             nn.BatchNorm2d(4096),
             nn.ReLU(),
         )
-
         self.conv7 = nn.SequentialCell(
-            nn.Conv2d(in_channels=4096,
-                      out_channels=4096,
-                      kernel_size=1,
-                      weight_init='xavier_uniform'),
+            nn.Conv2d(in_channels=4096, out_channels=4096,
+                      kernel_size=1, weight_init='xavier_uniform'),
             nn.BatchNorm2d(4096),
             nn.ReLU(),
         )
-
-        self.score_fr = nn.Conv2d(in_channels=4096,
-                                  out_channels=self.n_class,
-                                  kernel_size=1,
-                                  weight_init='xavier_uniform')
-
-        self.upscore2 = nn.Conv2dTranspose(in_channels=self.n_class,
-                                           out_channels=self.n_class,
-                                           kernel_size=4,
-                                           stride=2,
-                                           weight_init='xavier_uniform')
-
-        self.score_pool4 = nn.Conv2d(in_channels=512,
-                                     out_channels=self.n_class,
-                                     kernel_size=1,
-                                     weight_init='xavier_uniform')
-
-        self.upscore_pool4 = nn.Conv2dTranspose(in_channels=self.n_class,
-                                                out_channels=self.n_class,
-                                                kernel_size=4,
-                                                stride=2,
-                                                weight_init='xavier_uniform')
-
-        self.score_pool3 = nn.Conv2d(in_channels=256,
-                                     out_channels=self.n_class,
-                                     kernel_size=1,
-                                     weight_init='xavier_uniform')
-
-        self.upscore8 = nn.Conv2dTranspose(in_channels=self.n_class,
-                                           out_channels=self.n_class,
-                                           kernel_size=16,
-                                           stride=8,
-                                           weight_init='xavier_uniform')
+        self.score_fr = nn.Conv2d(in_channels=4096, out_channels=self.n_class,
+                                  kernel_size=1, weight_init='xavier_uniform')
+        self.upscore2 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class,
+                                           kernel_size=4, stride=2, weight_init='xavier_uniform')
+        self.score_pool4 = nn.Conv2d(in_channels=512, out_channels=self.n_class,
+                                     kernel_size=1, weight_init='xavier_uniform')
+        self.upscore_pool4 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class,
+                                                kernel_size=4, stride=2, weight_init='xavier_uniform')
+        self.score_pool3 = nn.Conv2d(in_channels=256, out_channels=self.n_class,
+                                     kernel_size=1, weight_init='xavier_uniform')
+        self.upscore8 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class,
+                                           kernel_size=16, stride=8, weight_init='xavier_uniform')
         self.shape = P.Shape()
         self.cast = P.Cast()

+    def set_model_parallel_shard_strategy(self, device_num):
+        self.conv2d_strategy = ((1, 1, 1, device_num), (1, 1, 1, 1))
+        self.bn_strategy = ((1, 1, 1, device_num), (1,), (1,), (1,), (1,))
+        self.relu_strategy = ((1, 1, 1, device_num),)
+        self.maxpool_strategy = ((1, 1, 1, device_num),)
+        self.add_strategy = ((1, 1, 1, device_num), (1, 1, 1, device_num))
+
+        self.conv1.cell_list[0].conv2d.shard(self.conv2d_strategy)
+        self.conv1.cell_list[1].bn_train.shard(self.bn_strategy)
+        self.conv1.cell_list[2].relu.shard(self.relu_strategy)
+        self.conv1.cell_list[3].conv2d.shard(self.conv2d_strategy)
+        self.conv1.cell_list[4].bn_train.shard(self.bn_strategy)
+        self.conv1.cell_list[5].relu.shard(self.relu_strategy)
+        self.pool1.max_pool.shard(self.maxpool_strategy)
+        self.conv2.cell_list[0].conv2d.shard(self.conv2d_strategy)
+        self.conv2.cell_list[1].bn_train.shard(self.bn_strategy)
+        self.conv2.cell_list[2].relu.shard(self.relu_strategy)
+        self.conv2.cell_list[3].conv2d.shard(self.conv2d_strategy)
+        self.conv2.cell_list[4].bn_train.shard(self.bn_strategy)
+        self.conv2.cell_list[5].relu.shard(self.relu_strategy)
+        self.pool2.max_pool.shard(self.maxpool_strategy)
+        self.conv3.cell_list[0].conv2d.shard(self.conv2d_strategy)
+        self.conv3.cell_list[1].bn_train.shard(self.bn_strategy)
+        self.conv3.cell_list[2].relu.shard(self.relu_strategy)
+        self.conv3.cell_list[3].conv2d.shard(self.conv2d_strategy)
+        self.conv3.cell_list[4].bn_train.shard(self.bn_strategy)
+        self.conv3.cell_list[5].relu.shard(self.relu_strategy)
+        self.conv3.cell_list[6].conv2d.shard(self.conv2d_strategy)
+        self.conv3.cell_list[7].bn_train.shard(self.bn_strategy)
+        self.conv3.cell_list[8].relu.shard(self.relu_strategy)
+        self.pool3.max_pool.shard(self.maxpool_strategy)
+        self.conv4.cell_list[0].conv2d.shard(self.conv2d_strategy)
+        self.conv4.cell_list[1].bn_train.shard(self.bn_strategy)
+        self.conv4.cell_list[2].relu.shard(self.relu_strategy)
+        self.conv4.cell_list[3].conv2d.shard(self.conv2d_strategy)
+        self.conv4.cell_list[4].bn_train.shard(self.bn_strategy)
+        self.conv4.cell_list[5].relu.shard(self.relu_strategy)
+        self.conv4.cell_list[6].conv2d.shard(self.conv2d_strategy)
+        self.conv4.cell_list[7].bn_train.shard(self.bn_strategy)
+        self.conv4.cell_list[8].relu.shard(self.relu_strategy)
+        self.pool4.max_pool.shard(self.maxpool_strategy)
+        self.conv5.cell_list[0].conv2d.shard(self.conv2d_strategy)
+        self.conv5.cell_list[1].bn_train.shard(self.bn_strategy)
+        self.conv5.cell_list[2].relu.shard(self.relu_strategy)
+        self.conv5.cell_list[3].conv2d.shard(self.conv2d_strategy)
+        self.conv5.cell_list[4].bn_train.shard(self.bn_strategy)
+        self.conv5.cell_list[5].relu.shard(self.relu_strategy)
+        self.conv5.cell_list[6].conv2d.shard(self.conv2d_strategy)
+        self.conv5.cell_list[7].bn_train.shard(self.bn_strategy)
+        self.conv5.cell_list[8].relu.shard(self.relu_strategy)
+        self.pool5.max_pool.shard(((1, 1, 1, device_num),))
+        self.conv6.cell_list[0].conv2d.shard(self.conv2d_strategy)
+        self.conv6.cell_list[1].bn_train.shard(self.bn_strategy)
+        self.conv6.cell_list[2].relu.shard(self.relu_strategy)
+        self.conv7.cell_list[0].conv2d.shard(self.conv2d_strategy)
+        self.conv7.cell_list[1].bn_train.shard(self.bn_strategy)
+        self.conv7.cell_list[2].relu.shard(self.relu_strategy)
+        self.score_fr.conv2d.shard(self.conv2d_strategy)
+        self.upscore2.conv2d_transpose.shard(self.conv2d_strategy)
+        self.score_pool4.conv2d.shard(self.conv2d_strategy)
+        self.upscore_pool4.conv2d_transpose.shard(self.conv2d_strategy)
+        self.score_pool3.conv2d.shard(self.conv2d_strategy)
+        self.upscore8.conv2d_transpose.shard(self.conv2d_strategy)
+        self.add1.shard(self.add_strategy)
+        self.add2.shard(self.add_strategy)
+
     def construct(self, x):
         x1 = self.conv1(x)
         p1 = self.pool1(x1)
diff --git a/model_zoo/official/cv/FCN8s/train.py b/model_zoo/official/cv/FCN8s/train.py
index e92913915b4..072fddda9dc 100644
--- a/model_zoo/official/cv/FCN8s/train.py
+++ b/model_zoo/official/cv/FCN8s/train.py
@@ -51,6 +51,8 @@ def train():
     config.group_size = 1
     if device_num > 1:
         parallel_mode = ParallelMode.DATA_PARALLEL
+        if config.parallel_mode in ParallelMode.MODE_LIST:
+            parallel_mode = config.parallel_mode
         context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=device_num)
         init()
         config.rank = get_rank()
@@ -73,6 +75,9 @@ def train():
     dataset = dataset.get_dataset(repeat=1)

     net = FCN8s(n_class=config.num_classes)
+    if context.get_auto_parallel_context("parallel_mode") in [ParallelMode.SEMI_AUTO_PARALLEL,
+                                                              ParallelMode.AUTO_PARALLEL]:
+        net.set_model_parallel_shard_strategy(device_num)

     loss_ = loss.SoftmaxCrossEntropyLoss(config.num_classes, config.ignore_label)
     # load pretrained vgg16 parameters to init FCN8s
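
Reviewer note: a minimal usage sketch of the new code path, assuming an 8-device run and a default_config.yaml switched from the patched default "data_parallel" to "semi_auto_parallel". The device count, the standalone script, and the import path are illustrative assumptions, not part of the patch; the strategy tuples of the form (1, 1, 1, device_num) split the NCHW feature maps along the W axis across devices while the weights stay replicated.

    # Illustrative sketch only; mirrors the logic this patch adds to train.py.
    from mindspore import context
    from mindspore.context import ParallelMode
    from mindspore.communication.management import init, get_group_size
    from src.nets.FCN8s import FCN8s  # import path assumed from the repo layout in the diff stat

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    init()
    device_num = get_group_size()  # e.g. 8 devices (assumption)

    # DATA_PARALLEL remains the default in train.py; the patch switches the mode
    # only when the config asks for a mode listed in ParallelMode.MODE_LIST.
    context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
                                      gradients_mean=True, device_num=device_num)

    net = FCN8s(n_class=21)
    # Apply the per-operator strategies added by this patch: every conv2d, bn, relu,
    # maxpool and add is sharded as ((1, 1, 1, device_num), ...), i.e. split on the W axis.
    net.set_model_parallel_shard_strategy(device_num)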