forked from mindspore-Ecosystem/mindspore
!20993 FCN8S add auto parallel mode
Merge pull request !20993 from yao_yf/FCN8S_strategy_set
This commit is contained in:
commit
615b33e970
|
@ -30,9 +30,10 @@ Status RedistributionLayoutTransfer::CheckValidTransfer() {
|
|||
bool not_all_repeat = std::any_of(from_map.begin(), from_map.end(), [](int64_t i) { return i != -1; }) ||
|
||||
std::any_of(to_map.begin(), to_map.end(), [](int64_t i) { return i != -1; });
|
||||
if (from_in_ != to_in_ && not_all_repeat) {
|
||||
MS_LOG(ERROR) << "In dynamic shape scene, the from_tensor_shape should be equal to to_tensor_shape";
|
||||
MS_LOG(ERROR) << "from_in layout" << from_in_.ToString();
|
||||
MS_LOG(ERROR) << "to_in layout" << to_in_.ToString();
|
||||
MS_LOG(ERROR) << "In dynamic shape scene, the layout between the neighboring node should be equal "
|
||||
"in avoid to insert redistribution operators";
|
||||
MS_LOG(ERROR) << "from layout" << from_in_.ToString();
|
||||
MS_LOG(ERROR) << "to layout" << to_in_.ToString();
|
||||
return Status::FAILED;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@ image_std: [57.375, 57.120, 58.395]
|
|||
ignore_label: 255
|
||||
num_classes: 21
|
||||
model: "FCN8s"
|
||||
parallel_mode: "data_parallel"
|
||||
|
||||
# ======================================================================================
|
||||
# Training options
|
||||
|
|
|
@ -22,161 +22,165 @@ class FCN8s(nn.Cell):
|
|||
super().__init__()
|
||||
self.n_class = n_class
|
||||
self.conv1 = nn.SequentialCell(
|
||||
nn.Conv2d(in_channels=3,
|
||||
out_channels=64,
|
||||
kernel_size=3,
|
||||
weight_init='xavier_uniform'),
|
||||
nn.Conv2d(in_channels=3, out_channels=64,
|
||||
kernel_size=3, weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(64),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(in_channels=64,
|
||||
out_channels=64,
|
||||
kernel_size=3,
|
||||
weight_init='xavier_uniform'),
|
||||
nn.Conv2d(in_channels=64, out_channels=64,
|
||||
kernel_size=3, weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(64),
|
||||
nn.ReLU()
|
||||
)
|
||||
|
||||
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
|
||||
self.conv2 = nn.SequentialCell(
|
||||
nn.Conv2d(in_channels=64,
|
||||
out_channels=128,
|
||||
kernel_size=3,
|
||||
weight_init='xavier_uniform'),
|
||||
nn.Conv2d(in_channels=64, out_channels=128,
|
||||
kernel_size=3, weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(128),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(in_channels=128,
|
||||
out_channels=128,
|
||||
kernel_size=3,
|
||||
weight_init='xavier_uniform'),
|
||||
nn.Conv2d(in_channels=128, out_channels=128,
|
||||
kernel_size=3, weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(128),
|
||||
nn.ReLU()
|
||||
)
|
||||
|
||||
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
|
||||
self.conv3 = nn.SequentialCell(
|
||||
nn.Conv2d(in_channels=128,
|
||||
out_channels=256,
|
||||
kernel_size=3,
|
||||
weight_init='xavier_uniform'),
|
||||
nn.Conv2d(in_channels=128, out_channels=256,
|
||||
kernel_size=3, weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(256),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(in_channels=256,
|
||||
out_channels=256,
|
||||
kernel_size=3,
|
||||
weight_init='xavier_uniform'),
|
||||
nn.Conv2d(in_channels=256, out_channels=256,
|
||||
kernel_size=3, weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(256),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(in_channels=256,
|
||||
out_channels=256,
|
||||
kernel_size=3,
|
||||
weight_init='xavier_uniform'),
|
||||
nn.Conv2d(in_channels=256, out_channels=256,
|
||||
kernel_size=3, weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(256),
|
||||
nn.ReLU()
|
||||
)
|
||||
|
||||
self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
|
||||
self.conv4 = nn.SequentialCell(
|
||||
nn.Conv2d(in_channels=256,
|
||||
out_channels=512,
|
||||
kernel_size=3,
|
||||
weight_init='xavier_uniform'),
|
||||
nn.Conv2d(in_channels=256, out_channels=512,
|
||||
kernel_size=3, weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(512),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(in_channels=512,
|
||||
out_channels=512,
|
||||
kernel_size=3,
|
||||
weight_init='xavier_uniform'),
|
||||
nn.Conv2d(in_channels=512, out_channels=512,
|
||||
kernel_size=3, weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(512),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(in_channels=512,
|
||||
out_channels=512,
|
||||
kernel_size=3,
|
||||
weight_init='xavier_uniform'),
|
||||
nn.Conv2d(in_channels=512, out_channels=512,
|
||||
kernel_size=3, weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(512),
|
||||
nn.ReLU()
|
||||
)
|
||||
|
||||
self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
|
||||
self.conv5 = nn.SequentialCell(
|
||||
nn.Conv2d(in_channels=512,
|
||||
out_channels=512,
|
||||
kernel_size=3,
|
||||
weight_init='xavier_uniform'),
|
||||
nn.Conv2d(in_channels=512, out_channels=512,
|
||||
kernel_size=3, weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(512),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(in_channels=512,
|
||||
out_channels=512,
|
||||
kernel_size=3,
|
||||
weight_init='xavier_uniform'),
|
||||
nn.Conv2d(in_channels=512, out_channels=512,
|
||||
kernel_size=3, weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(512),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(in_channels=512,
|
||||
out_channels=512,
|
||||
kernel_size=3,
|
||||
weight_init='xavier_uniform'),
|
||||
nn.Conv2d(in_channels=512, out_channels=512,
|
||||
kernel_size=3, weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(512),
|
||||
nn.ReLU()
|
||||
)
|
||||
|
||||
self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
|
||||
self.conv6 = nn.SequentialCell(
|
||||
nn.Conv2d(in_channels=512,
|
||||
out_channels=4096,
|
||||
kernel_size=7,
|
||||
weight_init='xavier_uniform'),
|
||||
nn.Conv2d(in_channels=512, out_channels=4096,
|
||||
kernel_size=7, weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(4096),
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
self.conv7 = nn.SequentialCell(
|
||||
nn.Conv2d(in_channels=4096,
|
||||
out_channels=4096,
|
||||
kernel_size=1,
|
||||
weight_init='xavier_uniform'),
|
||||
nn.Conv2d(in_channels=4096, out_channels=4096,
|
||||
kernel_size=1, weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(4096),
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
self.score_fr = nn.Conv2d(in_channels=4096,
|
||||
out_channels=self.n_class,
|
||||
kernel_size=1,
|
||||
weight_init='xavier_uniform')
|
||||
|
||||
self.upscore2 = nn.Conv2dTranspose(in_channels=self.n_class,
|
||||
out_channels=self.n_class,
|
||||
kernel_size=4,
|
||||
stride=2,
|
||||
weight_init='xavier_uniform')
|
||||
|
||||
self.score_pool4 = nn.Conv2d(in_channels=512,
|
||||
out_channels=self.n_class,
|
||||
kernel_size=1,
|
||||
weight_init='xavier_uniform')
|
||||
|
||||
self.upscore_pool4 = nn.Conv2dTranspose(in_channels=self.n_class,
|
||||
out_channels=self.n_class,
|
||||
kernel_size=4,
|
||||
stride=2,
|
||||
weight_init='xavier_uniform')
|
||||
|
||||
self.score_pool3 = nn.Conv2d(in_channels=256,
|
||||
out_channels=self.n_class,
|
||||
kernel_size=1,
|
||||
weight_init='xavier_uniform')
|
||||
|
||||
self.upscore8 = nn.Conv2dTranspose(in_channels=self.n_class,
|
||||
out_channels=self.n_class,
|
||||
kernel_size=16,
|
||||
stride=8,
|
||||
weight_init='xavier_uniform')
|
||||
self.score_fr = nn.Conv2d(in_channels=4096, out_channels=self.n_class,
|
||||
kernel_size=1, weight_init='xavier_uniform')
|
||||
self.upscore2 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class,
|
||||
kernel_size=4, stride=2, weight_init='xavier_uniform')
|
||||
self.score_pool4 = nn.Conv2d(in_channels=512, out_channels=self.n_class,
|
||||
kernel_size=1, weight_init='xavier_uniform')
|
||||
self.upscore_pool4 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class,
|
||||
kernel_size=4, stride=2, weight_init='xavier_uniform')
|
||||
self.score_pool3 = nn.Conv2d(in_channels=256, out_channels=self.n_class,
|
||||
kernel_size=1, weight_init='xavier_uniform')
|
||||
self.upscore8 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class,
|
||||
kernel_size=16, stride=8, weight_init='xavier_uniform')
|
||||
self.shape = P.Shape()
|
||||
self.cast = P.Cast()
|
||||
|
||||
def set_model_parallel_shard_strategy(self, device_num):
    """Attach a shard strategy to every operator of the network.

    Each 4-entry activation strategy is ``(1, 1, 1, device_num)``: only the
    fourth entry carries the split factor ``device_num``. The convolution
    weight strategy ``(1, 1, 1, 1)`` and the four 1-entry BatchNorm operand
    strategies ``(1,)`` keep a factor of 1.
    NOTE(review): the fourth entry is presumably the width axis of NCHW
    tensors -- confirm against the data layout used by the caller.

    The five strategy tuples are stored on ``self`` (``conv2d_strategy``,
    ``bn_strategy``, ``relu_strategy``, ``maxpool_strategy``,
    ``add_strategy``) before being passed to the operators' ``shard``
    methods.

    The method expects ``self.add1`` and ``self.add2`` to exist and to
    expose ``shard``.
    NOTE(review): their creation is not visible in the part of the
    constructor shown next to this method -- confirm they are defined.

    Args:
        device_num: split factor placed in the fourth position of every
            activation strategy.
    """
    split_last = (1, 1, 1, device_num)
    self.conv2d_strategy = (split_last, (1, 1, 1, 1))
    self.bn_strategy = (split_last, (1,), (1,), (1,), (1,))
    self.relu_strategy = (split_last,)
    self.maxpool_strategy = (split_last,)
    self.add_strategy = (split_last, split_last)

    def shard_stage(stage, unit_count):
        """Shard ``unit_count`` consecutive (conv, bn, relu) triples of ``stage``."""
        cells = stage.cell_list
        # The triple count is passed explicitly instead of being derived
        # from len(cells), so exactly the indices 0 .. 3 * unit_count - 1
        # are touched and nothing else.
        for base in range(0, 3 * unit_count, 3):
            cells[base].conv2d.shard(self.conv2d_strategy)
            cells[base + 1].bn_train.shard(self.bn_strategy)
            cells[base + 2].relu.shard(self.relu_strategy)

    # Backbone stages, each followed by its pooling layer. Calls are issued
    # in layer order: conv1, pool1, conv2, pool2, ...
    backbone = (
        (self.conv1, 2, self.pool1),
        (self.conv2, 2, self.pool2),
        (self.conv3, 3, self.pool3),
        (self.conv4, 3, self.pool4),
        (self.conv5, 3, self.pool5),
    )
    for stage, unit_count, pool in backbone:
        shard_stage(stage, unit_count)
        # pool5 uses the shared max-pool strategy like the other pools; the
        # value is the same ((1, 1, 1, device_num),) tuple as before.
        pool.max_pool.shard(self.maxpool_strategy)

    # conv6 and conv7 hold a single (conv, bn, relu) triple and have no pool.
    shard_stage(self.conv6, 1)
    shard_stage(self.conv7, 1)

    # Score / upsampling heads all reuse the convolution strategy.
    self.score_fr.conv2d.shard(self.conv2d_strategy)
    self.upscore2.conv2d_transpose.shard(self.conv2d_strategy)
    self.score_pool4.conv2d.shard(self.conv2d_strategy)
    self.upscore_pool4.conv2d_transpose.shard(self.conv2d_strategy)
    self.score_pool3.conv2d.shard(self.conv2d_strategy)
    self.upscore8.conv2d_transpose.shard(self.conv2d_strategy)

    self.add1.shard(self.add_strategy)
    self.add2.shard(self.add_strategy)
|
||||
|
||||
def construct(self, x):
|
||||
x1 = self.conv1(x)
|
||||
p1 = self.pool1(x1)
|
||||
|
|
|
@ -51,6 +51,8 @@ def train():
|
|||
config.group_size = 1
|
||||
if device_num > 1:
|
||||
parallel_mode = ParallelMode.DATA_PARALLEL
|
||||
if config.parallel_mode in ParallelMode.MODE_LIST:
|
||||
parallel_mode = config.parallel_mode
|
||||
context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=device_num)
|
||||
init()
|
||||
config.rank = get_rank()
|
||||
|
@ -73,6 +75,9 @@ def train():
|
|||
dataset = dataset.get_dataset(repeat=1)
|
||||
|
||||
net = FCN8s(n_class=config.num_classes)
|
||||
if context.get_auto_parallel_context("parallel_mode") in [ParallelMode.SEMI_AUTO_PARALLEL,
|
||||
ParallelMode.AUTO_PARALLEL]:
|
||||
net.set_model_parallel_shard_strategy(device_num)
|
||||
loss_ = loss.SoftmaxCrossEntropyLoss(config.num_classes, config.ignore_label)
|
||||
|
||||
# load pretrained vgg16 parameters to init FCN8s
|
||||
|
|
Loading…
Reference in New Issue