!20993 FCN8S add auto parallel mode

Merge pull request !20993 from yao_yf/FCN8S_strategy_set
This commit is contained in:
i-robot 2021-07-29 01:59:01 +00:00 committed by Gitee
commit 615b33e970
4 changed files with 118 additions and 107 deletions

View File

@ -30,9 +30,10 @@ Status RedistributionLayoutTransfer::CheckValidTransfer() {
bool not_all_repeat = std::any_of(from_map.begin(), from_map.end(), [](int64_t i) { return i != -1; }) ||
std::any_of(to_map.begin(), to_map.end(), [](int64_t i) { return i != -1; });
if (from_in_ != to_in_ && not_all_repeat) {
MS_LOG(ERROR) << "In dynamic shape scene, the from_tensor_shape should be equal to to_tensor_shape";
MS_LOG(ERROR) << "from_in layout" << from_in_.ToString();
MS_LOG(ERROR) << "to_in layout" << to_in_.ToString();
MS_LOG(ERROR) << "In dynamic shape scene, the layout between the neighboring node should be equal "
"in avoid to insert redistribution operators";
MS_LOG(ERROR) << "from layout" << from_in_.ToString();
MS_LOG(ERROR) << "to layout" << to_in_.ToString();
return Status::FAILED;
}
}

View File

@ -20,6 +20,7 @@ image_std: [57.375, 57.120, 58.395]
ignore_label: 255
num_classes: 21
model: "FCN8s"
parallel_mode: "data_parallel"
# ======================================================================================
# Training options

View File

@ -22,161 +22,165 @@ class FCN8s(nn.Cell):
super().__init__()
self.n_class = n_class
self.conv1 = nn.SequentialCell(
nn.Conv2d(in_channels=3,
out_channels=64,
kernel_size=3,
weight_init='xavier_uniform'),
nn.Conv2d(in_channels=3, out_channels=64,
kernel_size=3, weight_init='xavier_uniform'),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.Conv2d(in_channels=64,
out_channels=64,
kernel_size=3,
weight_init='xavier_uniform'),
nn.Conv2d(in_channels=64, out_channels=64,
kernel_size=3, weight_init='xavier_uniform'),
nn.BatchNorm2d(64),
nn.ReLU()
)
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv2 = nn.SequentialCell(
nn.Conv2d(in_channels=64,
out_channels=128,
kernel_size=3,
weight_init='xavier_uniform'),
nn.Conv2d(in_channels=64, out_channels=128,
kernel_size=3, weight_init='xavier_uniform'),
nn.BatchNorm2d(128),
nn.ReLU(),
nn.Conv2d(in_channels=128,
out_channels=128,
kernel_size=3,
weight_init='xavier_uniform'),
nn.Conv2d(in_channels=128, out_channels=128,
kernel_size=3, weight_init='xavier_uniform'),
nn.BatchNorm2d(128),
nn.ReLU()
)
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv3 = nn.SequentialCell(
nn.Conv2d(in_channels=128,
out_channels=256,
kernel_size=3,
weight_init='xavier_uniform'),
nn.Conv2d(in_channels=128, out_channels=256,
kernel_size=3, weight_init='xavier_uniform'),
nn.BatchNorm2d(256),
nn.ReLU(),
nn.Conv2d(in_channels=256,
out_channels=256,
kernel_size=3,
weight_init='xavier_uniform'),
nn.Conv2d(in_channels=256, out_channels=256,
kernel_size=3, weight_init='xavier_uniform'),
nn.BatchNorm2d(256),
nn.ReLU(),
nn.Conv2d(in_channels=256,
out_channels=256,
kernel_size=3,
weight_init='xavier_uniform'),
nn.Conv2d(in_channels=256, out_channels=256,
kernel_size=3, weight_init='xavier_uniform'),
nn.BatchNorm2d(256),
nn.ReLU()
)
self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv4 = nn.SequentialCell(
nn.Conv2d(in_channels=256,
out_channels=512,
kernel_size=3,
weight_init='xavier_uniform'),
nn.Conv2d(in_channels=256, out_channels=512,
kernel_size=3, weight_init='xavier_uniform'),
nn.BatchNorm2d(512),
nn.ReLU(),
nn.Conv2d(in_channels=512,
out_channels=512,
kernel_size=3,
weight_init='xavier_uniform'),
nn.Conv2d(in_channels=512, out_channels=512,
kernel_size=3, weight_init='xavier_uniform'),
nn.BatchNorm2d(512),
nn.ReLU(),
nn.Conv2d(in_channels=512,
out_channels=512,
kernel_size=3,
weight_init='xavier_uniform'),
nn.Conv2d(in_channels=512, out_channels=512,
kernel_size=3, weight_init='xavier_uniform'),
nn.BatchNorm2d(512),
nn.ReLU()
)
self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv5 = nn.SequentialCell(
nn.Conv2d(in_channels=512,
out_channels=512,
kernel_size=3,
weight_init='xavier_uniform'),
nn.Conv2d(in_channels=512, out_channels=512,
kernel_size=3, weight_init='xavier_uniform'),
nn.BatchNorm2d(512),
nn.ReLU(),
nn.Conv2d(in_channels=512,
out_channels=512,
kernel_size=3,
weight_init='xavier_uniform'),
nn.Conv2d(in_channels=512, out_channels=512,
kernel_size=3, weight_init='xavier_uniform'),
nn.BatchNorm2d(512),
nn.ReLU(),
nn.Conv2d(in_channels=512,
out_channels=512,
kernel_size=3,
weight_init='xavier_uniform'),
nn.Conv2d(in_channels=512, out_channels=512,
kernel_size=3, weight_init='xavier_uniform'),
nn.BatchNorm2d(512),
nn.ReLU()
)
self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv6 = nn.SequentialCell(
nn.Conv2d(in_channels=512,
out_channels=4096,
kernel_size=7,
weight_init='xavier_uniform'),
nn.Conv2d(in_channels=512, out_channels=4096,
kernel_size=7, weight_init='xavier_uniform'),
nn.BatchNorm2d(4096),
nn.ReLU(),
)
self.conv7 = nn.SequentialCell(
nn.Conv2d(in_channels=4096,
out_channels=4096,
kernel_size=1,
weight_init='xavier_uniform'),
nn.Conv2d(in_channels=4096, out_channels=4096,
kernel_size=1, weight_init='xavier_uniform'),
nn.BatchNorm2d(4096),
nn.ReLU(),
)
self.score_fr = nn.Conv2d(in_channels=4096,
out_channels=self.n_class,
kernel_size=1,
weight_init='xavier_uniform')
self.upscore2 = nn.Conv2dTranspose(in_channels=self.n_class,
out_channels=self.n_class,
kernel_size=4,
stride=2,
weight_init='xavier_uniform')
self.score_pool4 = nn.Conv2d(in_channels=512,
out_channels=self.n_class,
kernel_size=1,
weight_init='xavier_uniform')
self.upscore_pool4 = nn.Conv2dTranspose(in_channels=self.n_class,
out_channels=self.n_class,
kernel_size=4,
stride=2,
weight_init='xavier_uniform')
self.score_pool3 = nn.Conv2d(in_channels=256,
out_channels=self.n_class,
kernel_size=1,
weight_init='xavier_uniform')
self.upscore8 = nn.Conv2dTranspose(in_channels=self.n_class,
out_channels=self.n_class,
kernel_size=16,
stride=8,
weight_init='xavier_uniform')
self.score_fr = nn.Conv2d(in_channels=4096, out_channels=self.n_class,
kernel_size=1, weight_init='xavier_uniform')
self.upscore2 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class,
kernel_size=4, stride=2, weight_init='xavier_uniform')
self.score_pool4 = nn.Conv2d(in_channels=512, out_channels=self.n_class,
kernel_size=1, weight_init='xavier_uniform')
self.upscore_pool4 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class,
kernel_size=4, stride=2, weight_init='xavier_uniform')
self.score_pool3 = nn.Conv2d(in_channels=256, out_channels=self.n_class,
kernel_size=1, weight_init='xavier_uniform')
self.upscore8 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class,
kernel_size=16, stride=8, weight_init='xavier_uniform')
self.shape = P.Shape()
self.cast = P.Cast()
def set_model_parallel_shard_strategy(self, device_num):
    """Attach a shard strategy to every primitive used by the network.

    Every strategy leaves the first three axes of the operator's
    feature-map input unsplit (factor 1) and splits its fourth axis
    ``device_num`` ways.  Convolution weights and the four 1-D BatchNorm
    parameter inputs are not split.
    NOTE(review): the fourth axis is presumably the width of an NCHW
    tensor -- confirm against the data layout used for training.

    The strategies are also stored on ``self`` (``conv2d_strategy``,
    ``bn_strategy``, ``relu_strategy``, ``maxpool_strategy``,
    ``add_strategy``).

    Args:
        device_num: split factor applied to the fourth axis of each
            feature-map input.
    """
    feature_split = (1, 1, 1, device_num)
    self.conv2d_strategy = (feature_split, (1, 1, 1, 1))
    self.bn_strategy = (feature_split, (1,), (1,), (1,), (1,))
    self.relu_strategy = (feature_split,)
    self.maxpool_strategy = (feature_split,)
    self.add_strategy = (feature_split, feature_split)

    # conv1..conv7 are SequentialCells built from repeated
    # (Conv2d, BatchNorm2d, ReLU) triples, so the position of a cell
    # modulo 3 tells which inner primitive has to be sharded.
    triple_layout = (
        ("conv2d", self.conv2d_strategy),
        ("bn_train", self.bn_strategy),
        ("relu", self.relu_strategy),
    )
    conv_blocks = (self.conv1, self.conv2, self.conv3, self.conv4,
                   self.conv5, self.conv6, self.conv7)
    for block in conv_blocks:
        for position, cell in enumerate(block.cell_list):
            primitive_name, strategy = triple_layout[position % len(triple_layout)]
            getattr(cell, primitive_name).shard(strategy)

    # All five pooling layers share one strategy (pool5 previously spelled
    # the same tuple out inline instead of using the shared attribute).
    for pool in (self.pool1, self.pool2, self.pool3, self.pool4, self.pool5):
        pool.max_pool.shard(self.maxpool_strategy)

    # Score heads are plain convolutions, upsampling heads are transposed
    # convolutions; both take the convolution strategy.
    for score_head in (self.score_fr, self.score_pool4, self.score_pool3):
        score_head.conv2d.shard(self.conv2d_strategy)
    for upsample_head in (self.upscore2, self.upscore_pool4, self.upscore8):
        upsample_head.conv2d_transpose.shard(self.conv2d_strategy)

    # The two element-wise additions take two feature-map inputs, both
    # split the same way.
    self.add1.shard(self.add_strategy)
    self.add2.shard(self.add_strategy)
def construct(self, x):
x1 = self.conv1(x)
p1 = self.pool1(x1)

View File

@ -51,6 +51,8 @@ def train():
config.group_size = 1
if device_num > 1:
parallel_mode = ParallelMode.DATA_PARALLEL
if config.parallel_mode in ParallelMode.MODE_LIST:
parallel_mode = config.parallel_mode
context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=device_num)
init()
config.rank = get_rank()
@ -73,6 +75,9 @@ def train():
dataset = dataset.get_dataset(repeat=1)
net = FCN8s(n_class=config.num_classes)
if context.get_auto_parallel_context("parallel_mode") in [ParallelMode.SEMI_AUTO_PARALLEL,
ParallelMode.AUTO_PARALLEL]:
net.set_model_parallel_shard_strategy(device_num)
loss_ = loss.SoftmaxCrossEntropyLoss(config.num_classes, config.ignore_label)
# load pretrained vgg16 parameters to init FCN8s