fix ctpn backbone vgg16 init

This commit is contained in:
qujianwei 2021-04-08 19:41:04 +08:00
parent 20fa7fa276
commit 7f3f3adf4b
2 changed files with 24 additions and 24 deletions

View File

@ -171,11 +171,23 @@ The `pretrained_path` should be a checkpoint of vgg16 trained on Imagenet2012. T
...
from src.vgg16 import VGG16
...
network = VGG16()
network = VGG16(num_classes=cfg.num_classes)
...
```
To train a better model, you can modify some parameters in modelzoo/official/cv/vgg16/src/config.py. We suggest changing "warmup_epochs" as shown below; you can also try adjusting other parameters.
```python
imagenet_cfg = edict({
...
"warmup_epochs": 5
...
})
```
Then you can train it with ImageNet2012.
> Notes:
> For RANK_TABLE_FILE, refer to this [Link](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/distributed_training_ascend.html); the device_ip can be obtained with the tool at this [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools). For large models like InceptionV4, it's better to set the environment variable `export HCCL_CONNECT_TIMEOUT=600` to extend the HCCL connection checking time from the default 120 seconds to 600 seconds. Otherwise, the connection could time out, since compile time increases with model size.

View File

@ -23,36 +23,22 @@ def _weight_variable(shape, factor=0.01):
init_value = np.random.randn(*shape).astype(np.float32) * factor
return Tensor(init_value)
def _BatchNorm2dInit(out_chls, momentum=0.1, affine=True, use_batch_statistics=False):
    """Batchnorm2D wrapper with explicit initial values.

    Gamma and the moving variance start as ones; beta and the moving mean
    start as zeros. All four are float32 vectors of length ``out_chls``.

    Args:
        out_chls: Number of channels; also the length of each init vector.
        momentum: Forwarded unchanged to ``nn.BatchNorm2d``.
        affine: Forwarded unchanged to ``nn.BatchNorm2d``.
        use_batch_statistics: Forwarded unchanged to ``nn.BatchNorm2d``.

    Returns:
        The configured ``nn.BatchNorm2d`` cell.
    """
    # np.zeros / np.ones with an explicit dtype give the same values as the
    # former ``np.array(np.ones(n) * 0).astype(np.float32)`` round trip.
    ones = np.ones(out_chls, dtype=np.float32)
    zeros = np.zeros(out_chls, dtype=np.float32)
    return nn.BatchNorm2d(out_chls, momentum=momentum, affine=affine,
                          gamma_init=Tensor(ones),
                          beta_init=Tensor(zeros),
                          moving_mean_init=Tensor(zeros),
                          moving_var_init=Tensor(ones),
                          use_batch_statistics=use_batch_statistics)
def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad', weights_update=True):
    """Conv2D wrapper: a bias-free convolution followed by batch normalization.

    Args:
        in_channels: Number of input channels of the convolution.
        out_channels: Number of output channels; also the BatchNorm width.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Padding amount passed to ``nn.Conv2d``.
        pad_mode: Padding mode passed to ``nn.Conv2d``.
        weights_update: When False, the convolution weight is frozen by
            setting ``requires_grad = False``.

    Returns:
        ``nn.SequentialCell`` holding the convolution and ``nn.BatchNorm2d``.
    """
    # No explicit weight_init: the previous 'ones' initializer gave every
    # kernel identical weights, so the framework default is used instead.
    conv = nn.Conv2d(in_channels, out_channels,
                     kernel_size=kernel_size, stride=stride, padding=padding,
                     pad_mode=pad_mode, has_bias=False)
    if not weights_update:
        conv.weight.requires_grad = False
    return nn.SequentialCell([conv, nn.BatchNorm2d(out_channels)])
def _fc(in_channels, out_channels):
    """Full connection layer.

    Args:
        in_channels: Size of the input feature vector.
        out_channels: Size of the output feature vector.

    Returns:
        An ``nn.Dense`` cell built with the framework's default weight and
        bias initializers.
    """
    # Single return path: the earlier variant built weight and bias with
    # _weight_variable and left a second, unreachable return after it.
    return nn.Dense(in_channels, out_channels)
class VGG16FeatureExtraction(nn.Cell):
@ -141,36 +127,38 @@ class VGG16Classfier(nn.Cell):
self.relu = nn.ReLU()
self.fc1 = _fc(in_channels=7*7*512, out_channels=4096)
self.fc2 = _fc(in_channels=4096, out_channels=4096)
self.batch_size = 32
self.reshape = P.Reshape()
self.dropout = nn.Dropout(0.5)
def construct(self, x):
    """
    Flatten the feature map and apply two fc -> relu -> dropout stages.

    :param x: shape=(B, 512, 7, 7)
    :return: output of the second fc/relu/dropout stage
    """
    # -1 lets the batch dimension follow the input instead of relying on a
    # fixed self.batch_size, which broke for any batch size other than 32.
    x = self.reshape(x, (-1, 7*7*512))
    x = self.fc1(x)
    x = self.relu(x)
    x = self.dropout(x)
    x = self.fc2(x)
    x = self.relu(x)
    x = self.dropout(x)
    return x
class VGG16(nn.Cell):
def __init__(self, num_classes=1000):
    """VGG16 construct for training backbone.

    Args:
        num_classes: Output width of the final fully connected layer.
            Defaults to 1000, the value that used to be hard-coded, so
            both ``VGG16()`` and ``VGG16(num_classes=...)`` are accepted.
    """
    super(VGG16, self).__init__()
    # NOTE(review): the attribute name becomes part of the parameter names;
    # presumably it must match what the downstream checkpoint loader
    # expects -- confirm before renaming.
    self.vgg16_feature_extractor = VGG16FeatureExtraction(weights_update=True)
    self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
    self.classifier = VGG16Classfier()
    self.fc3 = _fc(in_channels=4096, out_channels=num_classes)
def construct(self, x):
"""
:param x: shape=(B, 3, 224, 224)
:return: logits, shape=(B, 1000)
"""
feature_maps = self.feature_extraction(x)
feature_maps = self.vgg16_feature_extractor(x)
x = self.max_pool(feature_maps)
x = self.classifier(x)
x = self.fc3(x)