forked from mindspore-Ecosystem/mindspore
fix ctpn backbone vgg16 init
This commit is contained in:
parent
20fa7fa276
commit
7f3f3adf4b
|
@ -171,11 +171,23 @@ The `pretrained_path` should be a checkpoint of vgg16 trained on Imagenet2012. T
|
|||
...
|
||||
from src.vgg16 import VGG16
|
||||
...
|
||||
network = VGG16()
|
||||
network = VGG16(num_classes=cfg.num_classes)
|
||||
...
|
||||
|
||||
```
|
||||
|
||||
To train a better model, you can modify some parameter in modelzoo/official/cv/vgg16/src/config.py, here we suggested you modify the "warmup_epochs" just like below, you can also try to adjust other parameter.
|
||||
|
||||
```python
|
||||
|
||||
imagenet_cfg = edict({
|
||||
...
|
||||
"warmup_epochs": 5
|
||||
...
|
||||
})
|
||||
|
||||
```
|
||||
|
||||
Then you can train it with ImageNet2012.
|
||||
> Notes:
|
||||
> RANK_TABLE_FILE can refer to [Link](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/distributed_training_ascend.html) , and the device_ip can be got as [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools). For large models like InceptionV4, it's better to export an external environment variable `export HCCL_CONNECT_TIMEOUT=600` to extend hccl connection checking time from the default 120 seconds to 600 seconds. Otherwise, the connection could be timeout since compiling time increases with the growth of model size.
|
||||
|
|
|
@ -23,36 +23,22 @@ def _weight_variable(shape, factor=0.01):
|
|||
init_value = np.random.randn(*shape).astype(np.float32) * factor
|
||||
return Tensor(init_value)
|
||||
|
||||
def _BatchNorm2dInit(out_chls, momentum=0.1, affine=True, use_batch_statistics=False):
|
||||
"""Batchnorm2D wrapper."""
|
||||
gamma_init = Tensor(np.array(np.ones(out_chls)).astype(np.float32))
|
||||
beta_init = Tensor(np.array(np.ones(out_chls) * 0).astype(np.float32))
|
||||
moving_mean_init = Tensor(np.array(np.ones(out_chls) * 0).astype(np.float32))
|
||||
moving_var_init = Tensor(np.array(np.ones(out_chls)).astype(np.float32))
|
||||
|
||||
return nn.BatchNorm2d(out_chls, momentum=momentum, affine=affine, gamma_init=gamma_init,
|
||||
beta_init=beta_init, moving_mean_init=moving_mean_init,
|
||||
moving_var_init=moving_var_init, use_batch_statistics=use_batch_statistics)
|
||||
|
||||
def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad', weights_update=True):
|
||||
"""Conv2D wrapper."""
|
||||
weights = 'ones'
|
||||
layers = []
|
||||
conv = nn.Conv2d(in_channels, out_channels,
|
||||
kernel_size=kernel_size, stride=stride, padding=padding,
|
||||
pad_mode=pad_mode, weight_init=weights, has_bias=False)
|
||||
pad_mode=pad_mode, has_bias=False)
|
||||
if not weights_update:
|
||||
conv.weight.requires_grad = False
|
||||
layers += [conv]
|
||||
layers += [_BatchNorm2dInit(out_channels)]
|
||||
layers += [nn.BatchNorm2d(out_channels)]
|
||||
return nn.SequentialCell(layers)
|
||||
|
||||
|
||||
def _fc(in_channels, out_channels):
|
||||
'''full connection layer'''
|
||||
weight = _weight_variable((out_channels, in_channels))
|
||||
bias = _weight_variable((out_channels,))
|
||||
return nn.Dense(in_channels, out_channels, weight, bias)
|
||||
return nn.Dense(in_channels, out_channels)
|
||||
|
||||
|
||||
class VGG16FeatureExtraction(nn.Cell):
|
||||
|
@ -141,36 +127,38 @@ class VGG16Classfier(nn.Cell):
|
|||
self.relu = nn.ReLU()
|
||||
self.fc1 = _fc(in_channels=7*7*512, out_channels=4096)
|
||||
self.fc2 = _fc(in_channels=4096, out_channels=4096)
|
||||
self.batch_size = 32
|
||||
self.reshape = P.Reshape()
|
||||
self.dropout = nn.Dropout(0.5)
|
||||
|
||||
def construct(self, x):
|
||||
"""
|
||||
:param x: shape=(B, 512, 7, 7)
|
||||
:return:
|
||||
"""
|
||||
x = self.reshape(x, (self.batch_size, 7*7*512))
|
||||
x = self.reshape(x, (-1, 7*7*512))
|
||||
x = self.fc1(x)
|
||||
x = self.relu(x)
|
||||
x = self.dropout(x)
|
||||
x = self.fc2(x)
|
||||
x = self.relu(x)
|
||||
x = self.dropout(x)
|
||||
return x
|
||||
|
||||
class VGG16(nn.Cell):
|
||||
def __init__(self):
|
||||
def __init__(self, num_classes):
|
||||
"""VGG16 construct for training backbone"""
|
||||
super(VGG16, self).__init__()
|
||||
self.feature_extraction = VGG16FeatureExtraction(weights_update=True)
|
||||
self.vgg16_feature_extractor = VGG16FeatureExtraction(weights_update=True)
|
||||
self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.classifier = VGG16Classfier()
|
||||
self.fc3 = _fc(in_channels=4096, out_channels=1000)
|
||||
self.fc3 = _fc(in_channels=4096, out_channels=num_classes)
|
||||
|
||||
def construct(self, x):
|
||||
"""
|
||||
:param x: shape=(B, 3, 224, 224)
|
||||
:return: logits, shape=(B, 1000)
|
||||
"""
|
||||
feature_maps = self.feature_extraction(x)
|
||||
feature_maps = self.vgg16_feature_extractor(x)
|
||||
x = self.max_pool(feature_maps)
|
||||
x = self.classifier(x)
|
||||
x = self.fc3(x)
|
||||
|
|
Loading…
Reference in New Issue